Diffstat (limited to 'llvm/lib'): 253 files changed, 7187 insertions, 4148 deletions
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp index 88533f2..031d675 100644 --- a/llvm/lib/Analysis/ConstantFolding.cpp +++ b/llvm/lib/Analysis/ConstantFolding.cpp @@ -45,8 +45,10 @@ #include "llvm/IR/IntrinsicsAArch64.h" #include "llvm/IR/IntrinsicsAMDGPU.h" #include "llvm/IR/IntrinsicsARM.h" +#include "llvm/IR/IntrinsicsNVPTX.h" #include "llvm/IR/IntrinsicsWebAssembly.h" #include "llvm/IR/IntrinsicsX86.h" +#include "llvm/IR/NVVMIntrinsicUtils.h" #include "llvm/IR/Operator.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" @@ -1687,6 +1689,58 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) { case Intrinsic::x86_avx512_cvttsd2usi64: return !Call->isStrictFP(); + // NVVM float/double to int32/uint32 conversion intrinsics + case Intrinsic::nvvm_f2i_rm: + case Intrinsic::nvvm_f2i_rn: + case Intrinsic::nvvm_f2i_rp: + case Intrinsic::nvvm_f2i_rz: + case Intrinsic::nvvm_f2i_rm_ftz: + case Intrinsic::nvvm_f2i_rn_ftz: + case Intrinsic::nvvm_f2i_rp_ftz: + case Intrinsic::nvvm_f2i_rz_ftz: + case Intrinsic::nvvm_f2ui_rm: + case Intrinsic::nvvm_f2ui_rn: + case Intrinsic::nvvm_f2ui_rp: + case Intrinsic::nvvm_f2ui_rz: + case Intrinsic::nvvm_f2ui_rm_ftz: + case Intrinsic::nvvm_f2ui_rn_ftz: + case Intrinsic::nvvm_f2ui_rp_ftz: + case Intrinsic::nvvm_f2ui_rz_ftz: + case Intrinsic::nvvm_d2i_rm: + case Intrinsic::nvvm_d2i_rn: + case Intrinsic::nvvm_d2i_rp: + case Intrinsic::nvvm_d2i_rz: + case Intrinsic::nvvm_d2ui_rm: + case Intrinsic::nvvm_d2ui_rn: + case Intrinsic::nvvm_d2ui_rp: + case Intrinsic::nvvm_d2ui_rz: + + // NVVM float/double to int64/uint64 conversion intrinsics + case Intrinsic::nvvm_f2ll_rm: + case Intrinsic::nvvm_f2ll_rn: + case Intrinsic::nvvm_f2ll_rp: + case Intrinsic::nvvm_f2ll_rz: + case Intrinsic::nvvm_f2ll_rm_ftz: + case Intrinsic::nvvm_f2ll_rn_ftz: + case Intrinsic::nvvm_f2ll_rp_ftz: + case Intrinsic::nvvm_f2ll_rz_ftz: + case Intrinsic::nvvm_f2ull_rm: + case Intrinsic::nvvm_f2ull_rn: + case Intrinsic::nvvm_f2ull_rp: + case Intrinsic::nvvm_f2ull_rz: + case Intrinsic::nvvm_f2ull_rm_ftz: + case Intrinsic::nvvm_f2ull_rn_ftz: + case Intrinsic::nvvm_f2ull_rp_ftz: + case Intrinsic::nvvm_f2ull_rz_ftz: + case Intrinsic::nvvm_d2ll_rm: + case Intrinsic::nvvm_d2ll_rn: + case Intrinsic::nvvm_d2ll_rp: + case Intrinsic::nvvm_d2ll_rz: + case Intrinsic::nvvm_d2ull_rm: + case Intrinsic::nvvm_d2ull_rn: + case Intrinsic::nvvm_d2ull_rp: + case Intrinsic::nvvm_d2ull_rz: + // Sign operations are actually bitwise operations, they do not raise // exceptions even for SNANs. 
case Intrinsic::fabs: @@ -1849,6 +1903,12 @@ inline bool llvm_fenv_testexcept() { return false; } +static const APFloat FTZPreserveSign(const APFloat &V) { + if (V.isDenormal()) + return APFloat::getZero(V.getSemantics(), V.isNegative()); + return V; +} + Constant *ConstantFoldFP(double (*NativeFP)(double), const APFloat &V, Type *Ty) { llvm_fenv_clearexcept(); @@ -2309,6 +2369,85 @@ static Constant *ConstantFoldScalarCall1(StringRef Name, return ConstantFP::get(Ty->getContext(), U); } + // NVVM float/double to signed/unsigned int32/int64 conversions: + switch (IntrinsicID) { + // f2i + case Intrinsic::nvvm_f2i_rm: + case Intrinsic::nvvm_f2i_rn: + case Intrinsic::nvvm_f2i_rp: + case Intrinsic::nvvm_f2i_rz: + case Intrinsic::nvvm_f2i_rm_ftz: + case Intrinsic::nvvm_f2i_rn_ftz: + case Intrinsic::nvvm_f2i_rp_ftz: + case Intrinsic::nvvm_f2i_rz_ftz: + // f2ui + case Intrinsic::nvvm_f2ui_rm: + case Intrinsic::nvvm_f2ui_rn: + case Intrinsic::nvvm_f2ui_rp: + case Intrinsic::nvvm_f2ui_rz: + case Intrinsic::nvvm_f2ui_rm_ftz: + case Intrinsic::nvvm_f2ui_rn_ftz: + case Intrinsic::nvvm_f2ui_rp_ftz: + case Intrinsic::nvvm_f2ui_rz_ftz: + // d2i + case Intrinsic::nvvm_d2i_rm: + case Intrinsic::nvvm_d2i_rn: + case Intrinsic::nvvm_d2i_rp: + case Intrinsic::nvvm_d2i_rz: + // d2ui + case Intrinsic::nvvm_d2ui_rm: + case Intrinsic::nvvm_d2ui_rn: + case Intrinsic::nvvm_d2ui_rp: + case Intrinsic::nvvm_d2ui_rz: + // f2ll + case Intrinsic::nvvm_f2ll_rm: + case Intrinsic::nvvm_f2ll_rn: + case Intrinsic::nvvm_f2ll_rp: + case Intrinsic::nvvm_f2ll_rz: + case Intrinsic::nvvm_f2ll_rm_ftz: + case Intrinsic::nvvm_f2ll_rn_ftz: + case Intrinsic::nvvm_f2ll_rp_ftz: + case Intrinsic::nvvm_f2ll_rz_ftz: + // f2ull + case Intrinsic::nvvm_f2ull_rm: + case Intrinsic::nvvm_f2ull_rn: + case Intrinsic::nvvm_f2ull_rp: + case Intrinsic::nvvm_f2ull_rz: + case Intrinsic::nvvm_f2ull_rm_ftz: + case Intrinsic::nvvm_f2ull_rn_ftz: + case Intrinsic::nvvm_f2ull_rp_ftz: + case Intrinsic::nvvm_f2ull_rz_ftz: + // d2ll + case Intrinsic::nvvm_d2ll_rm: + case Intrinsic::nvvm_d2ll_rn: + case Intrinsic::nvvm_d2ll_rp: + case Intrinsic::nvvm_d2ll_rz: + // d2ull + case Intrinsic::nvvm_d2ull_rm: + case Intrinsic::nvvm_d2ull_rn: + case Intrinsic::nvvm_d2ull_rp: + case Intrinsic::nvvm_d2ull_rz: { + // In float-to-integer conversion, NaN inputs are converted to 0. + if (U.isNaN()) + return ConstantInt::get(Ty, 0); + + APFloat::roundingMode RMode = nvvm::IntrinsicGetRoundingMode(IntrinsicID); + bool IsFTZ = nvvm::IntrinsicShouldFTZ(IntrinsicID); + bool IsSigned = nvvm::IntrinsicConvertsToSignedInteger(IntrinsicID); + + APSInt ResInt(Ty->getIntegerBitWidth(), !IsSigned); + auto FloatToRound = IsFTZ ? FTZPreserveSign(U) : U; + + bool IsExact = false; + APFloat::opStatus Status = + FloatToRound.convertToInteger(ResInt, RMode, &IsExact); + + if (Status != APFloat::opInvalidOp) + return ConstantInt::get(Ty, ResInt); + return nullptr; + } + } + /// We only fold functions with finite arguments. Folding NaN and inf is /// likely to be aborted with an exception anyway, and some host libms /// have known errors raising exceptions. 
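Editor's note on the ConstantFolding.cpp hunks above: the new folding maps NaN inputs to 0, flushes denormal inputs to a sign-preserving zero for the _ftz variants, converts with the intrinsic's rounding mode, and declines to fold only when APFloat reports opInvalidOp. Below is a minimal standalone sketch of those semantics for one shape of intrinsic (float to i32, round toward zero, with FTZ, i.e. roughly llvm.nvvm.f2i.rz.ftz); the helper name foldNVVMF2IRZFTZ and the hard-coded width, rounding mode, and FTZ choice are illustrative only, not part of the patch.

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APSInt.h"
#include <cstdint>
#include <optional>
using namespace llvm;

// Sketch only: mirrors the fold added above for a float -> i32 conversion
// that rounds toward zero and flushes denormal inputs to zero.
static std::optional<int32_t> foldNVVMF2IRZFTZ(float F) {
  APFloat V(F);
  if (V.isNaN())
    return 0; // NaN converts to 0, as in the patch.
  if (V.isDenormal()) // FTZ: flush to zero, preserving the sign.
    V = APFloat::getZero(V.getSemantics(), V.isNegative());
  APSInt Res(32, /*isUnsigned=*/false);
  bool IsExact = false;
  APFloat::opStatus St =
      V.convertToInteger(Res, APFloat::rmTowardZero, &IsExact);
  if (St == APFloat::opInvalidOp)
    return std::nullopt; // The patch declines to fold in this case.
  return static_cast<int32_t>(Res.getSExtValue());
}

For example, foldNVVMF2IRZFTZ(3.9f) yields 3 (round toward zero) and a NaN input yields 0, matching the "NaN inputs are converted to 0" comment in the hunk above.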
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp index 8567a05..999386c 100644 --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -4275,25 +4275,27 @@ Value *llvm::simplifyFCmpInst(CmpPredicate Predicate, Value *LHS, Value *RHS, return ::simplifyFCmpInst(Predicate, LHS, RHS, FMF, Q, RecursionLimit); } -static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, - const SimplifyQuery &Q, - bool AllowRefinement, - SmallVectorImpl<Instruction *> *DropFlags, - unsigned MaxRecurse) { +static Value *simplifyWithOpsReplaced(Value *V, + ArrayRef<std::pair<Value *, Value *>> Ops, + const SimplifyQuery &Q, + bool AllowRefinement, + SmallVectorImpl<Instruction *> *DropFlags, + unsigned MaxRecurse) { assert((AllowRefinement || !Q.CanUseUndef) && "If AllowRefinement=false then CanUseUndef=false"); + for (const auto &OpAndRepOp : Ops) { + // We cannot replace a constant, and shouldn't even try. + if (isa<Constant>(OpAndRepOp.first)) + return nullptr; - // Trivial replacement. - if (V == Op) - return RepOp; + // Trivial replacement. + if (V == OpAndRepOp.first) + return OpAndRepOp.second; + } if (!MaxRecurse--) return nullptr; - // We cannot replace a constant, and shouldn't even try. - if (isa<Constant>(Op)) - return nullptr; - auto *I = dyn_cast<Instruction>(V); if (!I) return nullptr; @@ -4303,11 +4305,6 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, if (isa<PHINode>(I)) return nullptr; - // For vector types, the simplification must hold per-lane, so forbid - // potentially cross-lane operations like shufflevector. - if (Op->getType()->isVectorTy() && !isNotCrossLaneOperation(I)) - return nullptr; - // Don't fold away llvm.is.constant checks based on assumptions. if (match(I, m_Intrinsic<Intrinsic::is_constant>())) return nullptr; @@ -4316,12 +4313,20 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, if (isa<FreezeInst>(I)) return nullptr; + for (const auto &OpAndRepOp : Ops) { + // For vector types, the simplification must hold per-lane, so forbid + // potentially cross-lane operations like shufflevector. + if (OpAndRepOp.first->getType()->isVectorTy() && + !isNotCrossLaneOperation(I)) + return nullptr; + } + // Replace Op with RepOp in instruction operands. SmallVector<Value *, 8> NewOps; bool AnyReplaced = false; for (Value *InstOp : I->operands()) { - if (Value *NewInstOp = simplifyWithOpReplaced( - InstOp, Op, RepOp, Q, AllowRefinement, DropFlags, MaxRecurse)) { + if (Value *NewInstOp = simplifyWithOpsReplaced( + InstOp, Ops, Q, AllowRefinement, DropFlags, MaxRecurse)) { NewOps.push_back(NewInstOp); AnyReplaced = InstOp != NewInstOp; } else { @@ -4372,7 +4377,8 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, // by assumption and this case never wraps, so nowrap flags can be // ignored. if ((Opcode == Instruction::Sub || Opcode == Instruction::Xor) && - NewOps[0] == RepOp && NewOps[1] == RepOp) + NewOps[0] == NewOps[1] && + any_of(Ops, [=](const auto &Rep) { return NewOps[0] == Rep.second; })) return Constant::getNullValue(I->getType()); // If we are substituting an absorber constant into a binop and extra @@ -4382,10 +4388,10 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, // (Op == 0) ? 0 : (Op & -Op) --> Op & -Op // (Op == 0) ? 0 : (Op * (binop Op, C)) --> Op * (binop Op, C) // (Op == -1) ? 
-1 : (Op | (binop C, Op) --> Op | (binop C, Op) - Constant *Absorber = - ConstantExpr::getBinOpAbsorber(Opcode, I->getType()); + Constant *Absorber = ConstantExpr::getBinOpAbsorber(Opcode, I->getType()); if ((NewOps[0] == Absorber || NewOps[1] == Absorber) && - impliesPoison(BO, Op)) + any_of(Ops, + [=](const auto &Rep) { return impliesPoison(BO, Rep.first); })) return Absorber; } @@ -4453,6 +4459,15 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, /*AllowNonDeterministic=*/false); } +static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, + const SimplifyQuery &Q, + bool AllowRefinement, + SmallVectorImpl<Instruction *> *DropFlags, + unsigned MaxRecurse) { + return simplifyWithOpsReplaced(V, {{Op, RepOp}}, Q, AllowRefinement, + DropFlags, MaxRecurse); +} + Value *llvm::simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, const SimplifyQuery &Q, bool AllowRefinement, @@ -4595,17 +4610,24 @@ static Value *simplifySelectWithFakeICmpEq(Value *CmpLHS, Value *CmpRHS, /// Try to simplify a select instruction when its condition operand is an /// integer equality or floating-point equivalence comparison. -static Value *simplifySelectWithEquivalence(Value *CmpLHS, Value *CmpRHS, - Value *TrueVal, Value *FalseVal, - const SimplifyQuery &Q, - unsigned MaxRecurse) { - if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q.getWithoutUndef(), - /* AllowRefinement */ false, - /* DropFlags */ nullptr, MaxRecurse) == TrueVal) - return FalseVal; - if (simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, - /* AllowRefinement */ true, - /* DropFlags */ nullptr, MaxRecurse) == FalseVal) +static Value *simplifySelectWithEquivalence( + ArrayRef<std::pair<Value *, Value *>> Replacements, Value *TrueVal, + Value *FalseVal, const SimplifyQuery &Q, unsigned MaxRecurse) { + Value *SimplifiedFalseVal = + simplifyWithOpsReplaced(FalseVal, Replacements, Q.getWithoutUndef(), + /* AllowRefinement */ false, + /* DropFlags */ nullptr, MaxRecurse); + if (!SimplifiedFalseVal) + SimplifiedFalseVal = FalseVal; + + Value *SimplifiedTrueVal = + simplifyWithOpsReplaced(TrueVal, Replacements, Q, + /* AllowRefinement */ true, + /* DropFlags */ nullptr, MaxRecurse); + if (!SimplifiedTrueVal) + SimplifiedTrueVal = TrueVal; + + if (SimplifiedFalseVal == SimplifiedTrueVal) return FalseVal; return nullptr; @@ -4699,10 +4721,10 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal, // the arms of the select. See if substituting this value into the arm and // simplifying the result yields the same value as the other arm. if (Pred == ICmpInst::ICMP_EQ) { - if (Value *V = simplifySelectWithEquivalence(CmpLHS, CmpRHS, TrueVal, + if (Value *V = simplifySelectWithEquivalence({{CmpLHS, CmpRHS}}, TrueVal, FalseVal, Q, MaxRecurse)) return V; - if (Value *V = simplifySelectWithEquivalence(CmpRHS, CmpLHS, TrueVal, + if (Value *V = simplifySelectWithEquivalence({{CmpRHS, CmpLHS}}, TrueVal, FalseVal, Q, MaxRecurse)) return V; @@ -4712,11 +4734,8 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal, if (match(CmpLHS, m_Or(m_Value(X), m_Value(Y))) && match(CmpRHS, m_Zero())) { // (X | Y) == 0 implies X == 0 and Y == 0. 
- if (Value *V = simplifySelectWithEquivalence(X, CmpRHS, TrueVal, FalseVal, - Q, MaxRecurse)) - return V; - if (Value *V = simplifySelectWithEquivalence(Y, CmpRHS, TrueVal, FalseVal, - Q, MaxRecurse)) + if (Value *V = simplifySelectWithEquivalence( + {{X, CmpRHS}, {Y, CmpRHS}}, TrueVal, FalseVal, Q, MaxRecurse)) return V; } @@ -4724,11 +4743,8 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal, if (match(CmpLHS, m_And(m_Value(X), m_Value(Y))) && match(CmpRHS, m_AllOnes())) { // (X & Y) == -1 implies X == -1 and Y == -1. - if (Value *V = simplifySelectWithEquivalence(X, CmpRHS, TrueVal, FalseVal, - Q, MaxRecurse)) - return V; - if (Value *V = simplifySelectWithEquivalence(Y, CmpRHS, TrueVal, FalseVal, - Q, MaxRecurse)) + if (Value *V = simplifySelectWithEquivalence( + {{X, CmpRHS}, {Y, CmpRHS}}, TrueVal, FalseVal, Q, MaxRecurse)) return V; } } @@ -4757,11 +4773,11 @@ static Value *simplifySelectWithFCmp(Value *Cond, Value *T, Value *F, // This transforms is safe if at least one operand is known to not be zero. // Otherwise, the select can change the sign of a zero operand. if (IsEquiv) { - if (Value *V = - simplifySelectWithEquivalence(CmpLHS, CmpRHS, T, F, Q, MaxRecurse)) + if (Value *V = simplifySelectWithEquivalence({{CmpLHS, CmpRHS}}, T, F, Q, + MaxRecurse)) return V; - if (Value *V = - simplifySelectWithEquivalence(CmpRHS, CmpLHS, T, F, Q, MaxRecurse)) + if (Value *V = simplifySelectWithEquivalence({{CmpRHS, CmpLHS}}, T, F, Q, + MaxRecurse)) return V; } diff --git a/llvm/lib/Analysis/Lint.cpp b/llvm/lib/Analysis/Lint.cpp index 4689451..e9d96a0c 100644 --- a/llvm/lib/Analysis/Lint.cpp +++ b/llvm/lib/Analysis/Lint.cpp @@ -266,6 +266,30 @@ void Lint::visitCallBase(CallBase &I) { visitMemoryReference(I, Loc, DL->getABITypeAlign(Ty), Ty, MemRef::Read | MemRef::Write); } + + // Check that ABI attributes for the function and call-site match. 
+ unsigned ArgNo = AI->getOperandNo(); + Attribute::AttrKind ABIAttributes[] = { + Attribute::ZExt, Attribute::SExt, Attribute::InReg, + Attribute::ByVal, Attribute::ByRef, Attribute::InAlloca, + Attribute::Preallocated, Attribute::StructRet}; + AttributeList CallAttrs = I.getAttributes(); + for (Attribute::AttrKind Attr : ABIAttributes) { + Attribute CallAttr = CallAttrs.getParamAttr(ArgNo, Attr); + Attribute FnAttr = F->getParamAttribute(ArgNo, Attr); + Check(CallAttr.isValid() == FnAttr.isValid(), + Twine("Undefined behavior: ABI attribute ") + + Attribute::getNameFromAttrKind(Attr) + + " not present on both function and call-site", + &I); + if (CallAttr.isValid() && FnAttr.isValid()) { + Check(CallAttr == FnAttr, + Twine("Undefined behavior: ABI attribute ") + + Attribute::getNameFromAttrKind(Attr) + + " does not have same argument for function and call-site", + &I); + } + } } } } diff --git a/llvm/lib/Analysis/Loads.cpp b/llvm/lib/Analysis/Loads.cpp index 54b9521..bc03e40 100644 --- a/llvm/lib/Analysis/Loads.cpp +++ b/llvm/lib/Analysis/Loads.cpp @@ -25,10 +25,9 @@ using namespace llvm; -static bool isAligned(const Value *Base, const APInt &Offset, Align Alignment, +static bool isAligned(const Value *Base, Align Alignment, const DataLayout &DL) { - Align BA = Base->getPointerAlignment(DL); - return BA >= Alignment && Offset.isAligned(BA); + return Base->getPointerAlignment(DL) >= Alignment; } /// Test if V is always a pointer to allocated and suitably aligned memory for @@ -118,8 +117,7 @@ static bool isDereferenceableAndAlignedPointer( // As we recursed through GEPs to get here, we've incrementally checked // that each step advanced by a multiple of the alignment. If our base is // properly aligned, then the original offset accessed must also be. - APInt Offset(DL.getTypeStoreSizeInBits(V->getType()), 0); - return isAligned(V, Offset, Alignment, DL); + return isAligned(V, Alignment, DL); } /// TODO refactor this function to be able to search independently for @@ -154,8 +152,7 @@ static bool isDereferenceableAndAlignedPointer( // checked that each step advanced by a multiple of the alignment. If // our base is properly aligned, then the original offset accessed // must also be. - APInt Offset(DL.getTypeStoreSizeInBits(V->getType()), 0); - return isAligned(V, Offset, Alignment, DL); + return isAligned(V, Alignment, DL); } } } diff --git a/llvm/lib/Analysis/MemoryProfileInfo.cpp b/llvm/lib/Analysis/MemoryProfileInfo.cpp index 1c3f589..2f3c87a 100644 --- a/llvm/lib/Analysis/MemoryProfileInfo.cpp +++ b/llvm/lib/Analysis/MemoryProfileInfo.cpp @@ -347,3 +347,20 @@ template <> uint64_t CallStack<MDNode, MDNode::op_iterator>::back() const { return mdconst::dyn_extract<ConstantInt>(N->operands().back()) ->getZExtValue(); } + +MDNode *MDNode::getMergedMemProfMetadata(MDNode *A, MDNode *B) { + // TODO: Support more sophisticated merging, such as selecting the one with + // more bytes allocated, or implement support for carrying multiple allocation + // leaf contexts. For now, keep the first one. + if (A) + return A; + return B; +} + +MDNode *MDNode::getMergedCallsiteMetadata(MDNode *A, MDNode *B) { + // TODO: Support more sophisticated merging, which will require support for + // carrying multiple contexts. For now, keep the first one. 
+ if (A) + return A; + return B; +} diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 8ab5602..7e18f7c 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -226,7 +226,7 @@ static cl::opt<unsigned> RangeIterThreshold( static cl::opt<unsigned> MaxLoopGuardCollectionDepth( "scalar-evolution-max-loop-guard-collection-depth", cl::Hidden, - cl::desc("Maximum depth for recrusive loop guard collection"), cl::init(1)); + cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1)); static cl::opt<bool> ClassifyExpressions("scalar-evolution-classify-expressions", @@ -15765,6 +15765,7 @@ void ScalarEvolution::LoopGuards::collectFromBlock( // original header. // TODO: share this logic with isLoopEntryGuardedByCond. unsigned NumCollectedConditions = 0; + VisitedBlocks.insert(Block); std::pair<const BasicBlock *, const BasicBlock *> Pair(Pred, Block); for (; Pair.first; Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) { diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp index 78fec25..0eb43dd 100644 --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -1119,7 +1119,8 @@ static void unionWithMinMaxIntrinsicClamp(const IntrinsicInst *II, KnownBits &Known) { const APInt *CLow, *CHigh; if (isSignedMinMaxIntrinsicClamp(II, CLow, CHigh)) - Known = Known.unionWith(ConstantRange(*CLow, *CHigh + 1).toKnownBits()); + Known = Known.unionWith( + ConstantRange::getNonEmpty(*CLow, *CHigh + 1).toKnownBits()); } static void computeKnownBitsFromOperator(const Operator *I, @@ -8640,6 +8641,80 @@ SelectPatternResult llvm::getSelectPattern(CmpInst::Predicate Pred, } } +std::optional<std::pair<CmpPredicate, Constant *>> +llvm::getFlippedStrictnessPredicateAndConstant(CmpPredicate Pred, Constant *C) { + assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) && + "Only for relational integer predicates."); + if (isa<UndefValue>(C)) + return std::nullopt; + + Type *Type = C->getType(); + bool IsSigned = ICmpInst::isSigned(Pred); + + CmpInst::Predicate UnsignedPred = ICmpInst::getUnsignedPredicate(Pred); + bool WillIncrement = + UnsignedPred == ICmpInst::ICMP_ULE || UnsignedPred == ICmpInst::ICMP_UGT; + + // Check if the constant operand can be safely incremented/decremented + // without overflowing/underflowing. + auto ConstantIsOk = [WillIncrement, IsSigned](ConstantInt *C) { + return WillIncrement ? !C->isMaxValue(IsSigned) : !C->isMinValue(IsSigned); + }; + + Constant *SafeReplacementConstant = nullptr; + if (auto *CI = dyn_cast<ConstantInt>(C)) { + // Bail out if the constant can't be safely incremented/decremented. + if (!ConstantIsOk(CI)) + return std::nullopt; + } else if (auto *FVTy = dyn_cast<FixedVectorType>(Type)) { + unsigned NumElts = FVTy->getNumElements(); + for (unsigned i = 0; i != NumElts; ++i) { + Constant *Elt = C->getAggregateElement(i); + if (!Elt) + return std::nullopt; + + if (isa<UndefValue>(Elt)) + continue; + + // Bail out if we can't determine if this constant is min/max or if we + // know that this constant is min/max. 
+ auto *CI = dyn_cast<ConstantInt>(Elt); + if (!CI || !ConstantIsOk(CI)) + return std::nullopt; + + if (!SafeReplacementConstant) + SafeReplacementConstant = CI; + } + } else if (isa<VectorType>(C->getType())) { + // Handle scalable splat + Value *SplatC = C->getSplatValue(); + auto *CI = dyn_cast_or_null<ConstantInt>(SplatC); + // Bail out if the constant can't be safely incremented/decremented. + if (!CI || !ConstantIsOk(CI)) + return std::nullopt; + } else { + // ConstantExpr? + return std::nullopt; + } + + // It may not be safe to change a compare predicate in the presence of + // undefined elements, so replace those elements with the first safe constant + // that we found. + // TODO: in case of poison, it is safe; let's replace undefs only. + if (C->containsUndefOrPoisonElement()) { + assert(SafeReplacementConstant && "Replacement constant not set"); + C = Constant::replaceUndefsWith(C, SafeReplacementConstant); + } + + CmpInst::Predicate NewPred = CmpInst::getFlippedStrictnessPredicate(Pred); + + // Increment or decrement the constant. + Constant *OneOrNegOne = ConstantInt::get(Type, WillIncrement ? 1 : -1, true); + Constant *NewC = ConstantExpr::getAdd(C, OneOrNegOne); + + return std::make_pair(NewPred, NewC); +} + static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred, FastMathFlags FMF, Value *CmpLHS, Value *CmpRHS, diff --git a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp index 7bd3fb3..3ba4590 100644 --- a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp +++ b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp @@ -3914,21 +3914,22 @@ static void emitGlobalConstantImpl(const DataLayout &DL, const Constant *CV, if (isa<ConstantAggregateZero>(CV)) { StructType *structType; if (AliasList && (structType = llvm::dyn_cast<StructType>(CV->getType()))) { - // Handle cases of aliases to direct struct elements - const StructLayout *Layout = DL.getStructLayout(structType); - uint64_t SizeSoFar = 0; - for (unsigned int i = 0, n = structType->getNumElements(); i < n - 1; - ++i) { - uint64_t GapToNext = Layout->getElementOffset(i + 1) - SizeSoFar; - AP.OutStreamer->emitZeros(GapToNext); - SizeSoFar += GapToNext; - emitGlobalAliasInline(AP, Offset + SizeSoFar, AliasList); + unsigned numElements = {structType->getNumElements()}; + if (numElements != 0) { + // Handle cases of aliases to direct struct elements + const StructLayout *Layout = DL.getStructLayout(structType); + uint64_t SizeSoFar = 0; + for (unsigned int i = 0; i < numElements - 1; ++i) { + uint64_t GapToNext = Layout->getElementOffset(i + 1) - SizeSoFar; + AP.OutStreamer->emitZeros(GapToNext); + SizeSoFar += GapToNext; + emitGlobalAliasInline(AP, Offset + SizeSoFar, AliasList); + } + AP.OutStreamer->emitZeros(Size - SizeSoFar); + return; } - AP.OutStreamer->emitZeros(Size - SizeSoFar); - return; - } else { - return AP.OutStreamer->emitZeros(Size); } + return AP.OutStreamer->emitZeros(Size); } if (isa<UndefValue>(CV)) diff --git a/llvm/lib/CodeGen/AsmPrinter/DwarfDebug.cpp b/llvm/lib/CodeGen/AsmPrinter/DwarfDebug.cpp index e1291e2..11de4b6 100644 --- a/llvm/lib/CodeGen/AsmPrinter/DwarfDebug.cpp +++ b/llvm/lib/CodeGen/AsmPrinter/DwarfDebug.cpp @@ -3789,6 +3789,7 @@ void DwarfDebug::addDwarfTypeUnitType(DwarfCompileUnit &CU, // they depend on addresses, throwing them out and rebuilding them. 
setCurrentDWARF5AccelTable(DWARF5AccelTableKind::CU); CU.constructTypeDIE(RefDie, cast<DICompositeType>(CTy)); + CU.updateAcceleratorTables(CTy->getScope(), CTy, RefDie); return; } diff --git a/llvm/lib/CodeGen/AsmPrinter/DwarfUnit.h b/llvm/lib/CodeGen/AsmPrinter/DwarfUnit.h index 0225654..1632053 100644 --- a/llvm/lib/CodeGen/AsmPrinter/DwarfUnit.h +++ b/llvm/lib/CodeGen/AsmPrinter/DwarfUnit.h @@ -315,6 +315,11 @@ public: /// Get context owner's DIE. DIE *createTypeDIE(const DICompositeType *Ty); + /// If this is a named finished type then include it in the list of types for + /// the accelerator tables. + void updateAcceleratorTables(const DIScope *Context, const DIType *Ty, + const DIE &TyDIE); + protected: ~DwarfUnit(); @@ -357,11 +362,6 @@ private: virtual void finishNonUnitTypeDIE(DIE& D, const DICompositeType *CTy) = 0; - /// If this is a named finished type then include it in the list of types for - /// the accelerator tables. - void updateAcceleratorTables(const DIScope *Context, const DIType *Ty, - const DIE &TyDIE); - virtual bool isDwoUnit() const = 0; const MCSymbol *getCrossSectionRelativeBaseAddress() const override; diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp index 5c712e4..ba1b10e 100644 --- a/llvm/lib/CodeGen/CodeGenPrepare.cpp +++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp @@ -152,7 +152,7 @@ static cl::opt<bool> static cl::opt<bool> EnableAndCmpSinking("enable-andcmp-sinking", cl::Hidden, cl::init(true), - cl::desc("Enable sinkinig and/cmp into branches.")); + cl::desc("Enable sinking and/cmp into branches.")); static cl::opt<bool> DisableStoreExtract( "disable-cgp-store-extract", cl::Hidden, cl::init(false), diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp index f3f7ea9..aec8df9 100644 --- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp +++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp @@ -108,6 +108,13 @@ static bool isNeg(Value *V); static Value *getNegOperand(Value *V); namespace { +template <typename T, typename IterT> +std::optional<T> findCommonBetweenCollections(IterT A, IterT B) { + auto Common = llvm::find_if(A, [B](T I) { return llvm::is_contained(B, I); }); + if (Common != A.end()) + return std::make_optional(*Common); + return std::nullopt; +} class ComplexDeinterleavingLegacyPass : public FunctionPass { public: @@ -144,6 +151,7 @@ private: friend class ComplexDeinterleavingGraph; using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>; using RawNodePtr = ComplexDeinterleavingCompositeNode *; + bool OperandsValid = true; public: ComplexDeinterleavingOperation Operation; @@ -160,7 +168,11 @@ public: SmallVector<RawNodePtr> Operands; Value *ReplacementNode = nullptr; - void addOperand(NodePtr Node) { Operands.push_back(Node.get()); } + void addOperand(NodePtr Node) { + if (!Node || !Node.get()) + OperandsValid = false; + Operands.push_back(Node.get()); + } void dump() { dump(dbgs()); } void dump(raw_ostream &OS) { @@ -194,6 +206,8 @@ public: PrintNodeRef(Op); } } + + bool areOperandsValid() { return OperandsValid; } }; class ComplexDeinterleavingGraph { @@ -293,7 +307,7 @@ private: NodePtr submitCompositeNode(NodePtr Node) { CompositeNodes.push_back(Node); - if (Node->Real && Node->Imag) + if (Node->Real) CachedResult[{Node->Real, Node->Imag}] = Node; return Node; } @@ -327,6 +341,8 @@ private: /// i: ai - br NodePtr identifyAdd(Instruction *Real, Instruction *Imag); NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag); + 
NodePtr identifyPartialReduction(Value *R, Value *I); + NodePtr identifyDotProduct(Value *Inst); NodePtr identifyNode(Value *R, Value *I); @@ -396,6 +412,7 @@ private: /// * Deinterleave the final value outside of the loop and repurpose original /// reduction users void processReductionOperation(Value *OperationReplacement, RawNodePtr Node); + void processReductionSingle(Value *OperationReplacement, RawNodePtr Node); public: void dump() { dump(dbgs()); } @@ -891,17 +908,163 @@ ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real, } ComplexDeinterleavingGraph::NodePtr -ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) { - LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n"); - assert(R->getType() == I->getType() && - "Real and imaginary parts should not have different types"); +ComplexDeinterleavingGraph::identifyDotProduct(Value *V) { + + if (!TL->isComplexDeinterleavingOperationSupported( + ComplexDeinterleavingOperation::CDot, V->getType())) { + LLVM_DEBUG(dbgs() << "Target doesn't support complex deinterleaving " + "operation CDot with the type " + << *V->getType() << "\n"); + return nullptr; + } + + auto *Inst = cast<Instruction>(V); + auto *RealUser = cast<Instruction>(*Inst->user_begin()); + + NodePtr CN = + prepareCompositeNode(ComplexDeinterleavingOperation::CDot, Inst, nullptr); + + NodePtr ANode; + + const Intrinsic::ID PartialReduceInt = + Intrinsic::experimental_vector_partial_reduce_add; + + Value *AReal = nullptr; + Value *AImag = nullptr; + Value *BReal = nullptr; + Value *BImag = nullptr; + Value *Phi = nullptr; + + auto UnwrapCast = [](Value *V) -> Value * { + if (auto *CI = dyn_cast<CastInst>(V)) + return CI->getOperand(0); + return V; + }; + + auto PatternRot0 = m_Intrinsic<PartialReduceInt>( + m_Intrinsic<PartialReduceInt>(m_Value(Phi), + m_Mul(m_Value(BReal), m_Value(AReal))), + m_Neg(m_Mul(m_Value(BImag), m_Value(AImag)))); + + auto PatternRot270 = m_Intrinsic<PartialReduceInt>( + m_Intrinsic<PartialReduceInt>( + m_Value(Phi), m_Neg(m_Mul(m_Value(BReal), m_Value(AImag)))), + m_Mul(m_Value(BImag), m_Value(AReal))); + + if (match(Inst, PatternRot0)) { + CN->Rotation = ComplexDeinterleavingRotation::Rotation_0; + } else if (match(Inst, PatternRot270)) { + CN->Rotation = ComplexDeinterleavingRotation::Rotation_270; + } else { + Value *A0, *A1; + // The rotations 90 and 180 share the same operation pattern, so inspect the + // order of the operands, identifying where the real and imaginary + // components of A go, to discern between the aforementioned rotations. 
+ auto PatternRot90Rot180 = m_Intrinsic<PartialReduceInt>( + m_Intrinsic<PartialReduceInt>(m_Value(Phi), + m_Mul(m_Value(BReal), m_Value(A0))), + m_Mul(m_Value(BImag), m_Value(A1))); + + if (!match(Inst, PatternRot90Rot180)) + return nullptr; + + A0 = UnwrapCast(A0); + A1 = UnwrapCast(A1); + + // Test if A0 is real/A1 is imag + ANode = identifyNode(A0, A1); + if (!ANode) { + // Test if A0 is imag/A1 is real + ANode = identifyNode(A1, A0); + // Unable to identify operand components, thus unable to identify rotation + if (!ANode) + return nullptr; + CN->Rotation = ComplexDeinterleavingRotation::Rotation_90; + AReal = A1; + AImag = A0; + } else { + AReal = A0; + AImag = A1; + CN->Rotation = ComplexDeinterleavingRotation::Rotation_180; + } + } + + AReal = UnwrapCast(AReal); + AImag = UnwrapCast(AImag); + BReal = UnwrapCast(BReal); + BImag = UnwrapCast(BImag); + + VectorType *VTy = cast<VectorType>(V->getType()); + Type *ExpectedOperandTy = VectorType::getSubdividedVectorType(VTy, 2); + if (AReal->getType() != ExpectedOperandTy) + return nullptr; + if (AImag->getType() != ExpectedOperandTy) + return nullptr; + if (BReal->getType() != ExpectedOperandTy) + return nullptr; + if (BImag->getType() != ExpectedOperandTy) + return nullptr; + + if (Phi->getType() != VTy && RealUser->getType() != VTy) + return nullptr; + + NodePtr Node = identifyNode(AReal, AImag); + + // In the case that a node was identified to figure out the rotation, ensure + // that trying to identify a node with AReal and AImag post-unwrap results in + // the same node + if (ANode && Node != ANode) { + LLVM_DEBUG( + dbgs() + << "Identified node is different from previously identified node. " + "Unable to confidently generate a complex operation node\n"); + return nullptr; + } + + CN->addOperand(Node); + CN->addOperand(identifyNode(BReal, BImag)); + CN->addOperand(identifyNode(Phi, RealUser)); + + return submitCompositeNode(CN); +} + +ComplexDeinterleavingGraph::NodePtr +ComplexDeinterleavingGraph::identifyPartialReduction(Value *R, Value *I) { + // Partial reductions don't support non-vector types, so check these first + if (!isa<VectorType>(R->getType()) || !isa<VectorType>(I->getType())) + return nullptr; + + auto CommonUser = + findCommonBetweenCollections<Value *>(R->users(), I->users()); + if (!CommonUser) + return nullptr; + + auto *IInst = dyn_cast<IntrinsicInst>(*CommonUser); + if (!IInst || IInst->getIntrinsicID() != + Intrinsic::experimental_vector_partial_reduce_add) + return nullptr; + + if (NodePtr CN = identifyDotProduct(IInst)) + return CN; + + return nullptr; +} +ComplexDeinterleavingGraph::NodePtr +ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) { auto It = CachedResult.find({R, I}); if (It != CachedResult.end()) { LLVM_DEBUG(dbgs() << " - Folding to existing node\n"); return It->second; } + if (NodePtr CN = identifyPartialReduction(R, I)) + return CN; + + bool IsReduction = RealPHI == R && (!ImagPHI || ImagPHI == I); + if (!IsReduction && R->getType() != I->getType()) + return nullptr; + if (NodePtr CN = identifySplat(R, I)) return CN; @@ -1427,12 +1590,20 @@ bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) { if (It != RootToNode.end()) { auto RootNode = It->second; assert(RootNode->Operation == - ComplexDeinterleavingOperation::ReductionOperation); + ComplexDeinterleavingOperation::ReductionOperation || + RootNode->Operation == + ComplexDeinterleavingOperation::ReductionSingle); // Find out which part, Real or Imag, comes later, and only if we come to // the latest part, add it to 
OrderedRoots. auto *R = cast<Instruction>(RootNode->Real); - auto *I = cast<Instruction>(RootNode->Imag); - auto *ReplacementAnchor = R->comesBefore(I) ? I : R; + auto *I = RootNode->Imag ? cast<Instruction>(RootNode->Imag) : nullptr; + + Instruction *ReplacementAnchor; + if (I) + ReplacementAnchor = R->comesBefore(I) ? I : R; + else + ReplacementAnchor = R; + if (ReplacementAnchor != RootI) return false; OrderedRoots.push_back(RootI); @@ -1523,7 +1694,6 @@ void ComplexDeinterleavingGraph::identifyReductionNodes() { for (size_t j = i + 1; j < OperationInstruction.size(); ++j) { if (Processed[j]) continue; - auto *Real = OperationInstruction[i]; auto *Imag = OperationInstruction[j]; if (Real->getType() != Imag->getType()) @@ -1556,6 +1726,28 @@ void ComplexDeinterleavingGraph::identifyReductionNodes() { break; } } + + auto *Real = OperationInstruction[i]; + // We want to check that we have 2 operands, but the function attributes + // being counted as operands bloats this value. + if (Real->getNumOperands() < 2) + continue; + + RealPHI = ReductionInfo[Real].first; + ImagPHI = nullptr; + PHIsFound = false; + auto Node = identifyNode(Real->getOperand(0), Real->getOperand(1)); + if (Node && PHIsFound) { + LLVM_DEBUG( + dbgs() << "Identified single reduction starting from instruction: " + << *Real << "/" << *ReductionInfo[Real].second << "\n"); + Processed[i] = true; + auto RootNode = prepareCompositeNode( + ComplexDeinterleavingOperation::ReductionSingle, Real, nullptr); + RootNode->addOperand(Node); + RootToNode[Real] = RootNode; + submitCompositeNode(RootNode); + } } RealPHI = nullptr; @@ -1563,6 +1755,24 @@ void ComplexDeinterleavingGraph::identifyReductionNodes() { } bool ComplexDeinterleavingGraph::checkNodes() { + + bool FoundDeinterleaveNode = false; + for (NodePtr N : CompositeNodes) { + if (!N->areOperandsValid()) + return false; + if (N->Operation == ComplexDeinterleavingOperation::Deinterleave) + FoundDeinterleaveNode = true; + } + + // We need a deinterleave node in order to guarantee that we're working with + // complex numbers. 
+ if (!FoundDeinterleaveNode) { + LLVM_DEBUG( + dbgs() << "Couldn't find a deinterleave node within the graph, cannot " + "guarantee safety during graph transformation.\n"); + return false; + } + // Collect all instructions from roots to leaves SmallPtrSet<Instruction *, 16> AllInstructions; SmallVector<Instruction *, 8> Worklist; @@ -1831,7 +2041,7 @@ ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) { ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real, Instruction *Imag) { - if (Real != RealPHI || Imag != ImagPHI) + if (Real != RealPHI || (ImagPHI && Imag != ImagPHI)) return nullptr; PHIsFound = true; @@ -1926,6 +2136,16 @@ Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder, Value *ReplacementNode; switch (Node->Operation) { + case ComplexDeinterleavingOperation::CDot: { + Value *Input0 = ReplaceOperandIfExist(Node, 0); + Value *Input1 = ReplaceOperandIfExist(Node, 1); + Value *Accumulator = ReplaceOperandIfExist(Node, 2); + assert(!Input1 || (Input0->getType() == Input1->getType() && + "Node inputs need to be of the same type")); + ReplacementNode = TL->createComplexDeinterleavingIR( + Builder, Node->Operation, Node->Rotation, Input0, Input1, Accumulator); + break; + } case ComplexDeinterleavingOperation::CAdd: case ComplexDeinterleavingOperation::CMulPartial: case ComplexDeinterleavingOperation::Symmetric: { @@ -1969,13 +2189,18 @@ Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder, case ComplexDeinterleavingOperation::ReductionPHI: { // If Operation is ReductionPHI, a new empty PHINode is created. // It is filled later when the ReductionOperation is processed. + auto *OldPHI = cast<PHINode>(Node->Real); auto *VTy = cast<VectorType>(Node->Real->getType()); auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHIIt()); - OldToNewPHI[dyn_cast<PHINode>(Node->Real)] = NewPHI; + OldToNewPHI[OldPHI] = NewPHI; ReplacementNode = NewPHI; break; } + case ComplexDeinterleavingOperation::ReductionSingle: + ReplacementNode = replaceNode(Builder, Node->Operands[0]); + processReductionSingle(ReplacementNode, Node); + break; case ComplexDeinterleavingOperation::ReductionOperation: ReplacementNode = replaceNode(Builder, Node->Operands[0]); processReductionOperation(ReplacementNode, Node); @@ -2000,6 +2225,38 @@ Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder, return ReplacementNode; } +void ComplexDeinterleavingGraph::processReductionSingle( + Value *OperationReplacement, RawNodePtr Node) { + auto *Real = cast<Instruction>(Node->Real); + auto *OldPHI = ReductionInfo[Real].first; + auto *NewPHI = OldToNewPHI[OldPHI]; + auto *VTy = cast<VectorType>(Real->getType()); + auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); + + Value *Init = OldPHI->getIncomingValueForBlock(Incoming); + + IRBuilder<> Builder(Incoming->getTerminator()); + + Value *NewInit = nullptr; + if (auto *C = dyn_cast<Constant>(Init)) { + if (C->isZeroValue()) + NewInit = Constant::getNullValue(NewVTy); + } + + if (!NewInit) + NewInit = Builder.CreateIntrinsic(Intrinsic::vector_interleave2, NewVTy, + {Init, Constant::getNullValue(VTy)}); + + NewPHI->addIncoming(NewInit, Incoming); + NewPHI->addIncoming(OperationReplacement, BackEdge); + + auto *FinalReduction = ReductionInfo[Real].second; + Builder.SetInsertPoint(&*FinalReduction->getParent()->getFirstInsertionPt()); + + auto *AddReduce = Builder.CreateAddReduce(OperationReplacement); + 
FinalReduction->replaceAllUsesWith(AddReduce); +} + void ComplexDeinterleavingGraph::processReductionOperation( Value *OperationReplacement, RawNodePtr Node) { auto *Real = cast<Instruction>(Node->Real); @@ -2059,8 +2316,13 @@ void ComplexDeinterleavingGraph::replaceNodes() { auto *RootImag = cast<Instruction>(RootNode->Imag); ReductionInfo[RootReal].first->removeIncomingValue(BackEdge); ReductionInfo[RootImag].first->removeIncomingValue(BackEdge); - DeadInstrRoots.push_back(cast<Instruction>(RootReal)); - DeadInstrRoots.push_back(cast<Instruction>(RootImag)); + DeadInstrRoots.push_back(RootReal); + DeadInstrRoots.push_back(RootImag); + } else if (RootNode->Operation == + ComplexDeinterleavingOperation::ReductionSingle) { + auto *RootInst = cast<Instruction>(RootNode->Real); + ReductionInfo[RootInst].first->removeIncomingValue(BackEdge); + DeadInstrRoots.push_back(ReductionInfo[RootInst].second); } else { assert(R && "Unable to find replacement for RootInstruction"); DeadInstrRoots.push_back(RootInstruction); diff --git a/llvm/lib/CodeGen/ExpandMemCmp.cpp b/llvm/lib/CodeGen/ExpandMemCmp.cpp index f8ca7e3..74f93e1 100644 --- a/llvm/lib/CodeGen/ExpandMemCmp.cpp +++ b/llvm/lib/CodeGen/ExpandMemCmp.cpp @@ -669,17 +669,25 @@ Value *MemCmpExpansion::getMemCmpOneBlock() { if (CI->hasOneUser()) { auto *UI = cast<Instruction>(*CI->user_begin()); CmpPredicate Pred = ICmpInst::Predicate::BAD_ICMP_PREDICATE; - uint64_t Shift; bool NeedsZExt = false; // This is a special case because instead of checking if the result is less // than zero: // bool result = memcmp(a, b, NBYTES) < 0; // Compiler is clever enough to generate the following code: // bool result = memcmp(a, b, NBYTES) >> 31; - if (match(UI, m_LShr(m_Value(), m_ConstantInt(Shift))) && - Shift == (CI->getType()->getIntegerBitWidth() - 1)) { + if (match(UI, + m_LShr(m_Value(), + m_SpecificInt(CI->getType()->getIntegerBitWidth() - 1)))) { Pred = ICmpInst::ICMP_SLT; NeedsZExt = true; + } else if (match(UI, m_SpecificICmp(ICmpInst::ICMP_SGT, m_Specific(CI), + m_AllOnes()))) { + // Adjust predicate as if it compared with 0. + Pred = ICmpInst::ICMP_SGE; + } else if (match(UI, m_SpecificICmp(ICmpInst::ICMP_SLT, m_Specific(CI), + m_One()))) { + // Adjust predicate as if it compared with 0. + Pred = ICmpInst::ICMP_SLE; } else { // In case of a successful match this call will set `Pred` variable match(UI, m_ICmp(Pred, m_Specific(CI), m_Zero())); @@ -696,17 +704,9 @@ Value *MemCmpExpansion::getMemCmpOneBlock() { } } - // The result of memcmp is negative, zero, or positive, so produce that by - // subtracting 2 extended compare bits: sub (ugt, ult). - // If a target prefers to use selects to get -1/0/1, they should be able - // to transform this later. The inverse transform (going from selects to math) - // may not be possible in the DAG because the selects got converted into - // branches before we got there. - Value *CmpUGT = Builder.CreateICmpUGT(Loads.Lhs, Loads.Rhs); - Value *CmpULT = Builder.CreateICmpULT(Loads.Lhs, Loads.Rhs); - Value *ZextUGT = Builder.CreateZExt(CmpUGT, Builder.getInt32Ty()); - Value *ZextULT = Builder.CreateZExt(CmpULT, Builder.getInt32Ty()); - return Builder.CreateSub(ZextUGT, ZextULT); + // The result of memcmp is negative, zero, or positive. 
+ return Builder.CreateIntrinsic(Builder.getInt32Ty(), Intrinsic::ucmp, + {Loads.Lhs, Loads.Rhs}); } // This function expands the memcmp call into an inline expansion and returns diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp index c20e9d0..4e3aaf5d 100644 --- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp +++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp @@ -6864,6 +6864,23 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select, }; return true; } + + // select Cond, 0, Pow2 --> (zext (!Cond)) << log2(Pow2) + if (FalseValue.isPowerOf2() && TrueValue.isZero()) { + MatchInfo = [=](MachineIRBuilder &B) { + B.setInstrAndDebugLoc(*Select); + Register Not = MRI.createGenericVirtualRegister(CondTy); + B.buildNot(Not, Cond); + Register Inner = MRI.createGenericVirtualRegister(TrueTy); + B.buildZExtOrTrunc(Inner, Not); + // The shift amount must be scalar. + LLT ShiftTy = TrueTy.isVector() ? TrueTy.getElementType() : TrueTy; + auto ShAmtC = B.buildConstant(ShiftTy, FalseValue.exactLogBase2()); + B.buildShl(Dest, Inner, ShAmtC, Flags); + }; + return true; + } + // select Cond, -1, C --> or (sext Cond), C if (TrueValue.isAllOnes()) { MatchInfo = [=](MachineIRBuilder &B) { @@ -7045,6 +7062,34 @@ bool CombinerHelper::matchSelectIMinMax(const MachineOperand &MO, } } +// (neg (min/max x, (neg x))) --> (max/min x, (neg x)) +bool CombinerHelper::matchSimplifyNegMinMax(MachineInstr &MI, + BuildFnTy &MatchInfo) const { + assert(MI.getOpcode() == TargetOpcode::G_SUB); + Register DestReg = MI.getOperand(0).getReg(); + LLT DestTy = MRI.getType(DestReg); + + Register X; + Register Sub0; + auto NegPattern = m_all_of(m_Neg(m_DeferredReg(X)), m_Reg(Sub0)); + if (mi_match(DestReg, MRI, + m_Neg(m_OneUse(m_any_of(m_GSMin(m_Reg(X), NegPattern), + m_GSMax(m_Reg(X), NegPattern), + m_GUMin(m_Reg(X), NegPattern), + m_GUMax(m_Reg(X), NegPattern)))))) { + MachineInstr *MinMaxMI = MRI.getVRegDef(MI.getOperand(2).getReg()); + unsigned NewOpc = getInverseGMinMaxOpcode(MinMaxMI->getOpcode()); + if (isLegal({NewOpc, {DestTy}})) { + MatchInfo = [=](MachineIRBuilder &B) { + B.buildInstr(NewOpc, {DestReg}, {X, Sub0}); + }; + return true; + } + } + + return false; +} + bool CombinerHelper::matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo) const { GSelect *Select = cast<GSelect>(&MI); diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp index e2247f7..d0a6234 100644 --- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp +++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp @@ -22,6 +22,7 @@ #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h" #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" #include "llvm/CodeGen/GlobalISel/Utils.h" +#include "llvm/CodeGen/LowLevelTypeUtils.h" #include "llvm/CodeGen/MachineConstantPool.h" #include "llvm/CodeGen/MachineFrameInfo.h" #include "llvm/CodeGen/MachineRegisterInfo.h" @@ -3022,8 +3023,19 @@ LegalizerHelper::widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) { return UnableToLegalize; LLT Ty = MRI.getType(MI.getOperand(0).getReg()); - if (!Ty.isScalar()) - return UnableToLegalize; + assert(!Ty.isPointerOrPointerVector() && "Can't widen type"); + if (!Ty.isScalar()) { + // We need to widen the vector element type. + Observer.changingInstr(MI); + widenScalarSrc(MI, WideTy, 0, TargetOpcode::G_ANYEXT); + // We also need to adjust the MMO to turn this into a truncating store. 
+ MachineMemOperand &MMO = **MI.memoperands_begin(); + MachineFunction &MF = MIRBuilder.getMF(); + auto *NewMMO = MF.getMachineMemOperand(&MMO, MMO.getPointerInfo(), Ty); + MI.setMemRefs(MF, {NewMMO}); + Observer.changedInstr(MI); + return Legalized; + } Observer.changingInstr(MI); @@ -4106,10 +4118,7 @@ LegalizerHelper::LegalizeResult LegalizerHelper::lowerStore(GStore &StoreMI) { unsigned StoreWidth = MemTy.getSizeInBits(); unsigned StoreSizeInBits = 8 * MemTy.getSizeInBytes(); - if (StoreWidth != StoreSizeInBits) { - if (SrcTy.isVector()) - return UnableToLegalize; - + if (StoreWidth != StoreSizeInBits && !SrcTy.isVector()) { // Promote to a byte-sized store with upper bits zero if not // storing an integral number of bytes. For example, promote // TRUNCSTORE:i1 X -> TRUNCSTORE:i8 (and X, 1) @@ -4131,9 +4140,8 @@ LegalizerHelper::LegalizeResult LegalizerHelper::lowerStore(GStore &StoreMI) { } if (MemTy.isVector()) { - // TODO: Handle vector trunc stores if (MemTy != SrcTy) - return UnableToLegalize; + return scalarizeVectorBooleanStore(StoreMI); // TODO: We can do better than scalarizing the vector and at least split it // in half. @@ -4189,6 +4197,50 @@ LegalizerHelper::LegalizeResult LegalizerHelper::lowerStore(GStore &StoreMI) { } LegalizerHelper::LegalizeResult +LegalizerHelper::scalarizeVectorBooleanStore(GStore &StoreMI) { + Register SrcReg = StoreMI.getValueReg(); + Register PtrReg = StoreMI.getPointerReg(); + LLT SrcTy = MRI.getType(SrcReg); + MachineMemOperand &MMO = **StoreMI.memoperands_begin(); + LLT MemTy = MMO.getMemoryType(); + LLT MemScalarTy = MemTy.getElementType(); + MachineFunction &MF = MIRBuilder.getMF(); + + assert(SrcTy.isVector() && "Expect a vector store type"); + + if (!MemScalarTy.isByteSized()) { + // We need to build an integer scalar of the vector bit pattern. + // It's not legal for us to add padding when storing a vector. + unsigned NumBits = MemTy.getSizeInBits(); + LLT IntTy = LLT::scalar(NumBits); + auto CurrVal = MIRBuilder.buildConstant(IntTy, 0); + LLT IdxTy = getLLTForMVT(TLI.getVectorIdxTy(MF.getDataLayout())); + + for (unsigned I = 0, E = MemTy.getNumElements(); I < E; ++I) { + auto Elt = MIRBuilder.buildExtractVectorElement( + SrcTy.getElementType(), SrcReg, MIRBuilder.buildConstant(IdxTy, I)); + auto Trunc = MIRBuilder.buildTrunc(MemScalarTy, Elt); + auto ZExt = MIRBuilder.buildZExt(IntTy, Trunc); + unsigned ShiftIntoIdx = MF.getDataLayout().isBigEndian() + ? (MemTy.getNumElements() - 1) - I + : I; + auto ShiftAmt = MIRBuilder.buildConstant( + IntTy, ShiftIntoIdx * MemScalarTy.getSizeInBits()); + auto Shifted = MIRBuilder.buildShl(IntTy, ZExt, ShiftAmt); + CurrVal = MIRBuilder.buildOr(IntTy, CurrVal, Shifted); + } + auto PtrInfo = MMO.getPointerInfo(); + auto *NewMMO = MF.getMachineMemOperand(&MMO, PtrInfo, IntTy); + MIRBuilder.buildStore(CurrVal, PtrReg, *NewMMO); + StoreMI.eraseFromParent(); + return Legalized; + } + + // TODO: implement simple scalarization. 
+ return UnableToLegalize; +} + +LegalizerHelper::LegalizeResult LegalizerHelper::bitcast(MachineInstr &MI, unsigned TypeIdx, LLT CastTy) { switch (MI.getOpcode()) { case TargetOpcode::G_LOAD: { @@ -4653,6 +4705,20 @@ LegalizerHelper::createStackTemporary(TypeSize Bytes, Align Alignment, return MIRBuilder.buildFrameIndex(FramePtrTy, FrameIdx); } +MachineInstrBuilder LegalizerHelper::createStackStoreLoad(const DstOp &Res, + const SrcOp &Val) { + LLT SrcTy = Val.getLLTTy(MRI); + Align StackTypeAlign = + std::max(getStackTemporaryAlignment(SrcTy), + getStackTemporaryAlignment(Res.getLLTTy(MRI))); + MachinePointerInfo PtrInfo; + auto StackTemp = + createStackTemporary(SrcTy.getSizeInBytes(), StackTypeAlign, PtrInfo); + + MIRBuilder.buildStore(Val, StackTemp, PtrInfo, StackTypeAlign); + return MIRBuilder.buildLoad(Res, StackTemp, PtrInfo, StackTypeAlign); +} + static Register clampVectorIndex(MachineIRBuilder &B, Register IdxReg, LLT VecTy) { LLT IdxTy = B.getMRI()->getType(IdxReg); diff --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp index 7938293..625d556 100644 --- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp +++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp @@ -276,6 +276,21 @@ void llvm::reportGISelFailure(MachineFunction &MF, const TargetPassConfig &TPC, reportGISelFailure(MF, TPC, MORE, R); } +unsigned llvm::getInverseGMinMaxOpcode(unsigned MinMaxOpc) { + switch (MinMaxOpc) { + case TargetOpcode::G_SMIN: + return TargetOpcode::G_SMAX; + case TargetOpcode::G_SMAX: + return TargetOpcode::G_SMIN; + case TargetOpcode::G_UMIN: + return TargetOpcode::G_UMAX; + case TargetOpcode::G_UMAX: + return TargetOpcode::G_UMIN; + default: + llvm_unreachable("unrecognized opcode"); + } +} + std::optional<APInt> llvm::getIConstantVRegVal(Register VReg, const MachineRegisterInfo &MRI) { std::optional<ValueAndVReg> ValAndVReg = getIConstantVRegValWithLookThrough( diff --git a/llvm/lib/CodeGen/LiveRegMatrix.cpp b/llvm/lib/CodeGen/LiveRegMatrix.cpp index 9744c47..3367171 100644 --- a/llvm/lib/CodeGen/LiveRegMatrix.cpp +++ b/llvm/lib/CodeGen/LiveRegMatrix.cpp @@ -66,7 +66,7 @@ void LiveRegMatrix::init(MachineFunction &MF, LiveIntervals &pLIS, unsigned NumRegUnits = TRI->getNumRegUnits(); if (NumRegUnits != Matrix.size()) Queries.reset(new LiveIntervalUnion::Query[NumRegUnits]); - Matrix.init(LIUAlloc, NumRegUnits); + Matrix.init(*LIUAlloc, NumRegUnits); // Make sure no stale queries get reused. 
invalidateVirtRegs(); diff --git a/llvm/lib/CodeGen/MIRSampleProfile.cpp b/llvm/lib/CodeGen/MIRSampleProfile.cpp index 23db09b..9bba50e8 100644 --- a/llvm/lib/CodeGen/MIRSampleProfile.cpp +++ b/llvm/lib/CodeGen/MIRSampleProfile.cpp @@ -46,8 +46,9 @@ static cl::opt<bool> ShowFSBranchProb( cl::desc("Print setting flow sensitive branch probabilities")); static cl::opt<unsigned> FSProfileDebugProbDiffThreshold( "fs-profile-debug-prob-diff-threshold", cl::init(10), - cl::desc("Only show debug message if the branch probility is greater than " - "this value (in percentage).")); + cl::desc( + "Only show debug message if the branch probability is greater than " + "this value (in percentage).")); static cl::opt<unsigned> FSProfileDebugBWThreshold( "fs-profile-debug-bw-threshold", cl::init(10000), diff --git a/llvm/lib/CodeGen/MachineBlockPlacement.cpp b/llvm/lib/CodeGen/MachineBlockPlacement.cpp index 0f68313..05bc4cf 100644 --- a/llvm/lib/CodeGen/MachineBlockPlacement.cpp +++ b/llvm/lib/CodeGen/MachineBlockPlacement.cpp @@ -149,7 +149,7 @@ static cl::opt<unsigned> JumpInstCost("jump-inst-cost", static cl::opt<bool> TailDupPlacement("tail-dup-placement", cl::desc("Perform tail duplication during placement. " - "Creates more fallthrough opportunites in " + "Creates more fallthrough opportunities in " "outline branches."), cl::init(true), cl::Hidden); diff --git a/llvm/lib/CodeGen/MachineBranchProbabilityInfo.cpp b/llvm/lib/CodeGen/MachineBranchProbabilityInfo.cpp index 56fffff..2e92dd8 100644 --- a/llvm/lib/CodeGen/MachineBranchProbabilityInfo.cpp +++ b/llvm/lib/CodeGen/MachineBranchProbabilityInfo.cpp @@ -29,7 +29,7 @@ namespace llvm { cl::opt<unsigned> StaticLikelyProb("static-likely-prob", cl::desc("branch probability threshold in percentage" - "to be considered very likely"), + " to be considered very likely"), cl::init(80), cl::Hidden); cl::opt<unsigned> ProfileLikelyProb( diff --git a/llvm/lib/CodeGen/MachineOperand.cpp b/llvm/lib/CodeGen/MachineOperand.cpp index 3a9bdde..5c9ca91 100644 --- a/llvm/lib/CodeGen/MachineOperand.cpp +++ b/llvm/lib/CodeGen/MachineOperand.cpp @@ -1170,6 +1170,9 @@ void MachineMemOperand::print(raw_ostream &OS, ModuleSlotTracker &MST, if (getFlags() & MachineMemOperand::MOTargetFlag3) OS << '"' << getTargetMMOFlagName(*TII, MachineMemOperand::MOTargetFlag3) << "\" "; + if (getFlags() & MachineMemOperand::MOTargetFlag4) + OS << '"' << getTargetMMOFlagName(*TII, MachineMemOperand::MOTargetFlag4) + << "\" "; } else { if (getFlags() & MachineMemOperand::MOTargetFlag1) OS << "\"MOTargetFlag1\" "; @@ -1177,6 +1180,8 @@ void MachineMemOperand::print(raw_ostream &OS, ModuleSlotTracker &MST, OS << "\"MOTargetFlag2\" "; if (getFlags() & MachineMemOperand::MOTargetFlag3) OS << "\"MOTargetFlag3\" "; + if (getFlags() & MachineMemOperand::MOTargetFlag4) + OS << "\"MOTargetFlag4\" "; } assert((isLoad() || isStore()) && diff --git a/llvm/lib/CodeGen/MachineRegisterInfo.cpp b/llvm/lib/CodeGen/MachineRegisterInfo.cpp index 6f636a1..394b99b 100644 --- a/llvm/lib/CodeGen/MachineRegisterInfo.cpp +++ b/llvm/lib/CodeGen/MachineRegisterInfo.cpp @@ -407,9 +407,11 @@ void MachineRegisterInfo::replaceRegWith(Register FromReg, Register ToReg) { MachineInstr *MachineRegisterInfo::getVRegDef(Register Reg) const { // Since we are in SSA form, we can use the first definition. def_instr_iterator I = def_instr_begin(Reg); - assert((I.atEnd() || std::next(I) == def_instr_end()) && - "getVRegDef assumes a single definition or no definition"); - return !I.atEnd() ? 
&*I : nullptr; + if (I == def_instr_end()) + return nullptr; + assert(std::next(I) == def_instr_end() && + "getVRegDef assumes at most one definition"); + return &*I; } /// getUniqueVRegDef - Return the unique machine instr that defines the diff --git a/llvm/lib/CodeGen/MachineTraceMetrics.cpp b/llvm/lib/CodeGen/MachineTraceMetrics.cpp index 6576f97..021c1a0 100644 --- a/llvm/lib/CodeGen/MachineTraceMetrics.cpp +++ b/llvm/lib/CodeGen/MachineTraceMetrics.cpp @@ -683,11 +683,10 @@ struct DataDep { DataDep(const MachineRegisterInfo *MRI, unsigned VirtReg, unsigned UseOp) : UseOp(UseOp) { assert(Register::isVirtualRegister(VirtReg)); - MachineRegisterInfo::def_iterator DefI = MRI->def_begin(VirtReg); - assert(!DefI.atEnd() && "Register has no defs"); - DefMI = DefI->getParent(); - DefOp = DefI.getOperandNo(); - assert((++DefI).atEnd() && "Register has multiple defs"); + MachineOperand *DefMO = MRI->getOneDef(VirtReg); + assert(DefMO && "Register does not have unique def"); + DefMI = DefMO->getParent(); + DefOp = DefMO->getOperandNo(); } }; diff --git a/llvm/lib/CodeGen/PostRASchedulerList.cpp b/llvm/lib/CodeGen/PostRASchedulerList.cpp index 2f7cfdd..badfd9a6 100644 --- a/llvm/lib/CodeGen/PostRASchedulerList.cpp +++ b/llvm/lib/CodeGen/PostRASchedulerList.cpp @@ -98,12 +98,6 @@ namespace { } bool runOnMachineFunction(MachineFunction &Fn) override; - - private: - bool enablePostRAScheduler( - const TargetSubtargetInfo &ST, CodeGenOptLevel OptLevel, - TargetSubtargetInfo::AntiDepBreakMode &Mode, - TargetSubtargetInfo::RegClassVector &CriticalPathRCs) const; }; char PostRAScheduler::ID = 0; @@ -259,13 +253,8 @@ LLVM_DUMP_METHOD void SchedulePostRATDList::dumpSchedule() const { } #endif -bool PostRAScheduler::enablePostRAScheduler( - const TargetSubtargetInfo &ST, CodeGenOptLevel OptLevel, - TargetSubtargetInfo::AntiDepBreakMode &Mode, - TargetSubtargetInfo::RegClassVector &CriticalPathRCs) const { - Mode = ST.getAntiDepBreakMode(); - ST.getCriticalPathRCs(CriticalPathRCs); - +static bool enablePostRAScheduler(const TargetSubtargetInfo &ST, + CodeGenOptLevel OptLevel) { // Check for explicit enable/disable of post-ra scheduling. if (EnablePostRAScheduler.getPosition() > 0) return EnablePostRAScheduler; @@ -278,24 +267,17 @@ bool PostRAScheduler::runOnMachineFunction(MachineFunction &Fn) { if (skipFunction(Fn.getFunction())) return false; - TII = Fn.getSubtarget().getInstrInfo(); - MachineLoopInfo &MLI = getAnalysis<MachineLoopInfoWrapperPass>().getLI(); - AliasAnalysis *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); + const auto &Subtarget = Fn.getSubtarget(); TargetPassConfig *PassConfig = &getAnalysis<TargetPassConfig>(); - - RegClassInfo.runOnMachineFunction(Fn); - - TargetSubtargetInfo::AntiDepBreakMode AntiDepMode = - TargetSubtargetInfo::ANTIDEP_NONE; - SmallVector<const TargetRegisterClass*, 4> CriticalPathRCs; - // Check that post-RA scheduling is enabled for this target. - // This may upgrade the AntiDepMode. - if (!enablePostRAScheduler(Fn.getSubtarget(), PassConfig->getOptLevel(), - AntiDepMode, CriticalPathRCs)) + if (!enablePostRAScheduler(Subtarget, PassConfig->getOptLevel())) return false; - // Check for antidep breaking override... 
+ TII = Subtarget.getInstrInfo(); + MachineLoopInfo &MLI = getAnalysis<MachineLoopInfoWrapperPass>().getLI(); + AliasAnalysis *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); + TargetSubtargetInfo::AntiDepBreakMode AntiDepMode = + Subtarget.getAntiDepBreakMode(); if (EnableAntiDepBreaking.getPosition() > 0) { AntiDepMode = (EnableAntiDepBreaking == "all") ? TargetSubtargetInfo::ANTIDEP_ALL @@ -303,6 +285,9 @@ bool PostRAScheduler::runOnMachineFunction(MachineFunction &Fn) { ? TargetSubtargetInfo::ANTIDEP_CRITICAL : TargetSubtargetInfo::ANTIDEP_NONE); } + SmallVector<const TargetRegisterClass *, 4> CriticalPathRCs; + Subtarget.getCriticalPathRCs(CriticalPathRCs); + RegClassInfo.runOnMachineFunction(Fn); LLVM_DEBUG(dbgs() << "PostRAScheduler\n"); diff --git a/llvm/lib/CodeGen/ReachingDefAnalysis.cpp b/llvm/lib/CodeGen/ReachingDefAnalysis.cpp index 79b0fa6..3ab6315 100644 --- a/llvm/lib/CodeGen/ReachingDefAnalysis.cpp +++ b/llvm/lib/CodeGen/ReachingDefAnalysis.cpp @@ -30,22 +30,22 @@ static bool isValidRegUse(const MachineOperand &MO) { return isValidReg(MO) && MO.isUse(); } -static bool isValidRegUseOf(const MachineOperand &MO, MCRegister PhysReg, +static bool isValidRegUseOf(const MachineOperand &MO, MCRegister Reg, const TargetRegisterInfo *TRI) { if (!isValidRegUse(MO)) return false; - return TRI->regsOverlap(MO.getReg(), PhysReg); + return TRI->regsOverlap(MO.getReg(), Reg); } static bool isValidRegDef(const MachineOperand &MO) { return isValidReg(MO) && MO.isDef(); } -static bool isValidRegDefOf(const MachineOperand &MO, MCRegister PhysReg, +static bool isValidRegDefOf(const MachineOperand &MO, MCRegister Reg, const TargetRegisterInfo *TRI) { if (!isValidRegDef(MO)) return false; - return TRI->regsOverlap(MO.getReg(), PhysReg); + return TRI->regsOverlap(MO.getReg(), Reg); } void ReachingDefAnalysis::enterBasicBlock(MachineBasicBlock *MBB) { @@ -261,7 +261,7 @@ void ReachingDefAnalysis::traverse() { } int ReachingDefAnalysis::getReachingDef(MachineInstr *MI, - MCRegister PhysReg) const { + MCRegister Reg) const { assert(InstIds.count(MI) && "Unexpected machine instuction."); int InstId = InstIds.lookup(MI); int DefRes = ReachingDefDefaultVal; @@ -269,7 +269,7 @@ int ReachingDefAnalysis::getReachingDef(MachineInstr *MI, assert(MBBNumber < MBBReachingDefs.numBlockIDs() && "Unexpected basic block number."); int LatestDef = ReachingDefDefaultVal; - for (MCRegUnit Unit : TRI->regunits(PhysReg)) { + for (MCRegUnit Unit : TRI->regunits(Reg)) { for (int Def : MBBReachingDefs.defs(MBBNumber, Unit)) { if (Def >= InstId) break; @@ -280,22 +280,21 @@ int ReachingDefAnalysis::getReachingDef(MachineInstr *MI, return LatestDef; } -MachineInstr * -ReachingDefAnalysis::getReachingLocalMIDef(MachineInstr *MI, - MCRegister PhysReg) const { - return hasLocalDefBefore(MI, PhysReg) - ? getInstFromId(MI->getParent(), getReachingDef(MI, PhysReg)) - : nullptr; +MachineInstr *ReachingDefAnalysis::getReachingLocalMIDef(MachineInstr *MI, + MCRegister Reg) const { + return hasLocalDefBefore(MI, Reg) + ? 
getInstFromId(MI->getParent(), getReachingDef(MI, Reg)) + : nullptr; } bool ReachingDefAnalysis::hasSameReachingDef(MachineInstr *A, MachineInstr *B, - MCRegister PhysReg) const { + MCRegister Reg) const { MachineBasicBlock *ParentA = A->getParent(); MachineBasicBlock *ParentB = B->getParent(); if (ParentA != ParentB) return false; - return getReachingDef(A, PhysReg) == getReachingDef(B, PhysReg); + return getReachingDef(A, Reg) == getReachingDef(B, Reg); } MachineInstr *ReachingDefAnalysis::getInstFromId(MachineBasicBlock *MBB, @@ -318,19 +317,18 @@ MachineInstr *ReachingDefAnalysis::getInstFromId(MachineBasicBlock *MBB, return nullptr; } -int ReachingDefAnalysis::getClearance(MachineInstr *MI, - MCRegister PhysReg) const { +int ReachingDefAnalysis::getClearance(MachineInstr *MI, MCRegister Reg) const { assert(InstIds.count(MI) && "Unexpected machine instuction."); - return InstIds.lookup(MI) - getReachingDef(MI, PhysReg); + return InstIds.lookup(MI) - getReachingDef(MI, Reg); } bool ReachingDefAnalysis::hasLocalDefBefore(MachineInstr *MI, - MCRegister PhysReg) const { - return getReachingDef(MI, PhysReg) >= 0; + MCRegister Reg) const { + return getReachingDef(MI, Reg) >= 0; } void ReachingDefAnalysis::getReachingLocalUses(MachineInstr *Def, - MCRegister PhysReg, + MCRegister Reg, InstSet &Uses) const { MachineBasicBlock *MBB = Def->getParent(); MachineBasicBlock::iterator MI = MachineBasicBlock::iterator(Def); @@ -340,11 +338,11 @@ void ReachingDefAnalysis::getReachingLocalUses(MachineInstr *Def, // If/when we find a new reaching def, we know that there's no more uses // of 'Def'. - if (getReachingLocalMIDef(&*MI, PhysReg) != Def) + if (getReachingLocalMIDef(&*MI, Reg) != Def) return; for (auto &MO : MI->operands()) { - if (!isValidRegUseOf(MO, PhysReg, TRI)) + if (!isValidRegUseOf(MO, Reg, TRI)) continue; Uses.insert(&*MI); @@ -354,15 +352,14 @@ void ReachingDefAnalysis::getReachingLocalUses(MachineInstr *Def, } } -bool ReachingDefAnalysis::getLiveInUses(MachineBasicBlock *MBB, - MCRegister PhysReg, +bool ReachingDefAnalysis::getLiveInUses(MachineBasicBlock *MBB, MCRegister Reg, InstSet &Uses) const { for (MachineInstr &MI : instructionsWithoutDebug(MBB->instr_begin(), MBB->instr_end())) { for (auto &MO : MI.operands()) { - if (!isValidRegUseOf(MO, PhysReg, TRI)) + if (!isValidRegUseOf(MO, Reg, TRI)) continue; - if (getReachingDef(&MI, PhysReg) >= 0) + if (getReachingDef(&MI, Reg) >= 0) return false; Uses.insert(&MI); } @@ -370,18 +367,18 @@ bool ReachingDefAnalysis::getLiveInUses(MachineBasicBlock *MBB, auto Last = MBB->getLastNonDebugInstr(); if (Last == MBB->end()) return true; - return isReachingDefLiveOut(&*Last, PhysReg); + return isReachingDefLiveOut(&*Last, Reg); } -void ReachingDefAnalysis::getGlobalUses(MachineInstr *MI, MCRegister PhysReg, +void ReachingDefAnalysis::getGlobalUses(MachineInstr *MI, MCRegister Reg, InstSet &Uses) const { MachineBasicBlock *MBB = MI->getParent(); // Collect the uses that each def touches within the block. - getReachingLocalUses(MI, PhysReg, Uses); + getReachingLocalUses(MI, Reg, Uses); // Handle live-out values. 
- if (auto *LiveOut = getLocalLiveOutMIDef(MI->getParent(), PhysReg)) { + if (auto *LiveOut = getLocalLiveOutMIDef(MI->getParent(), Reg)) { if (LiveOut != MI) return; @@ -389,9 +386,9 @@ void ReachingDefAnalysis::getGlobalUses(MachineInstr *MI, MCRegister PhysReg, SmallPtrSet<MachineBasicBlock*, 4>Visited; while (!ToVisit.empty()) { MachineBasicBlock *MBB = ToVisit.pop_back_val(); - if (Visited.count(MBB) || !MBB->isLiveIn(PhysReg)) + if (Visited.count(MBB) || !MBB->isLiveIn(Reg)) continue; - if (getLiveInUses(MBB, PhysReg, Uses)) + if (getLiveInUses(MBB, Reg, Uses)) llvm::append_range(ToVisit, MBB->successors()); Visited.insert(MBB); } @@ -399,25 +396,25 @@ void ReachingDefAnalysis::getGlobalUses(MachineInstr *MI, MCRegister PhysReg, } void ReachingDefAnalysis::getGlobalReachingDefs(MachineInstr *MI, - MCRegister PhysReg, + MCRegister Reg, InstSet &Defs) const { - if (auto *Def = getUniqueReachingMIDef(MI, PhysReg)) { + if (auto *Def = getUniqueReachingMIDef(MI, Reg)) { Defs.insert(Def); return; } for (auto *MBB : MI->getParent()->predecessors()) - getLiveOuts(MBB, PhysReg, Defs); + getLiveOuts(MBB, Reg, Defs); } -void ReachingDefAnalysis::getLiveOuts(MachineBasicBlock *MBB, - MCRegister PhysReg, InstSet &Defs) const { +void ReachingDefAnalysis::getLiveOuts(MachineBasicBlock *MBB, MCRegister Reg, + InstSet &Defs) const { SmallPtrSet<MachineBasicBlock*, 2> VisitedBBs; - getLiveOuts(MBB, PhysReg, Defs, VisitedBBs); + getLiveOuts(MBB, Reg, Defs, VisitedBBs); } -void ReachingDefAnalysis::getLiveOuts(MachineBasicBlock *MBB, - MCRegister PhysReg, InstSet &Defs, +void ReachingDefAnalysis::getLiveOuts(MachineBasicBlock *MBB, MCRegister Reg, + InstSet &Defs, BlockSet &VisitedBBs) const { if (VisitedBBs.count(MBB)) return; @@ -425,28 +422,28 @@ void ReachingDefAnalysis::getLiveOuts(MachineBasicBlock *MBB, VisitedBBs.insert(MBB); LiveRegUnits LiveRegs(*TRI); LiveRegs.addLiveOuts(*MBB); - if (LiveRegs.available(PhysReg)) + if (LiveRegs.available(Reg)) return; - if (auto *Def = getLocalLiveOutMIDef(MBB, PhysReg)) + if (auto *Def = getLocalLiveOutMIDef(MBB, Reg)) Defs.insert(Def); else for (auto *Pred : MBB->predecessors()) - getLiveOuts(Pred, PhysReg, Defs, VisitedBBs); + getLiveOuts(Pred, Reg, Defs, VisitedBBs); } MachineInstr * ReachingDefAnalysis::getUniqueReachingMIDef(MachineInstr *MI, - MCRegister PhysReg) const { + MCRegister Reg) const { // If there's a local def before MI, return it. - MachineInstr *LocalDef = getReachingLocalMIDef(MI, PhysReg); + MachineInstr *LocalDef = getReachingLocalMIDef(MI, Reg); if (LocalDef && InstIds.lookup(LocalDef) < InstIds.lookup(MI)) return LocalDef; SmallPtrSet<MachineInstr*, 2> Incoming; MachineBasicBlock *Parent = MI->getParent(); for (auto *Pred : Parent->predecessors()) - getLiveOuts(Pred, PhysReg, Incoming); + getLiveOuts(Pred, Reg, Incoming); // Check that we have a single incoming value and that it does not // come from the same block as MI - since it would mean that the def @@ -469,13 +466,13 @@ MachineInstr *ReachingDefAnalysis::getMIOperand(MachineInstr *MI, } bool ReachingDefAnalysis::isRegUsedAfter(MachineInstr *MI, - MCRegister PhysReg) const { + MCRegister Reg) const { MachineBasicBlock *MBB = MI->getParent(); LiveRegUnits LiveRegs(*TRI); LiveRegs.addLiveOuts(*MBB); // Yes if the register is live out of the basic block. 
- if (!LiveRegs.available(PhysReg)) + if (!LiveRegs.available(Reg)) return true; // Walk backwards through the block to see if the register is live at some @@ -483,62 +480,61 @@ bool ReachingDefAnalysis::isRegUsedAfter(MachineInstr *MI, for (MachineInstr &Last : instructionsWithoutDebug(MBB->instr_rbegin(), MBB->instr_rend())) { LiveRegs.stepBackward(Last); - if (!LiveRegs.available(PhysReg)) + if (!LiveRegs.available(Reg)) return InstIds.lookup(&Last) > InstIds.lookup(MI); } return false; } bool ReachingDefAnalysis::isRegDefinedAfter(MachineInstr *MI, - MCRegister PhysReg) const { + MCRegister Reg) const { MachineBasicBlock *MBB = MI->getParent(); auto Last = MBB->getLastNonDebugInstr(); if (Last != MBB->end() && - getReachingDef(MI, PhysReg) != getReachingDef(&*Last, PhysReg)) + getReachingDef(MI, Reg) != getReachingDef(&*Last, Reg)) return true; - if (auto *Def = getLocalLiveOutMIDef(MBB, PhysReg)) - return Def == getReachingLocalMIDef(MI, PhysReg); + if (auto *Def = getLocalLiveOutMIDef(MBB, Reg)) + return Def == getReachingLocalMIDef(MI, Reg); return false; } bool ReachingDefAnalysis::isReachingDefLiveOut(MachineInstr *MI, - MCRegister PhysReg) const { + MCRegister Reg) const { MachineBasicBlock *MBB = MI->getParent(); LiveRegUnits LiveRegs(*TRI); LiveRegs.addLiveOuts(*MBB); - if (LiveRegs.available(PhysReg)) + if (LiveRegs.available(Reg)) return false; auto Last = MBB->getLastNonDebugInstr(); - int Def = getReachingDef(MI, PhysReg); - if (Last != MBB->end() && getReachingDef(&*Last, PhysReg) != Def) + int Def = getReachingDef(MI, Reg); + if (Last != MBB->end() && getReachingDef(&*Last, Reg) != Def) return false; // Finally check that the last instruction doesn't redefine the register. for (auto &MO : Last->operands()) - if (isValidRegDefOf(MO, PhysReg, TRI)) + if (isValidRegDefOf(MO, Reg, TRI)) return false; return true; } -MachineInstr * -ReachingDefAnalysis::getLocalLiveOutMIDef(MachineBasicBlock *MBB, - MCRegister PhysReg) const { +MachineInstr *ReachingDefAnalysis::getLocalLiveOutMIDef(MachineBasicBlock *MBB, + MCRegister Reg) const { LiveRegUnits LiveRegs(*TRI); LiveRegs.addLiveOuts(*MBB); - if (LiveRegs.available(PhysReg)) + if (LiveRegs.available(Reg)) return nullptr; auto Last = MBB->getLastNonDebugInstr(); if (Last == MBB->end()) return nullptr; - int Def = getReachingDef(&*Last, PhysReg); + int Def = getReachingDef(&*Last, Reg); for (auto &MO : Last->operands()) - if (isValidRegDefOf(MO, PhysReg, TRI)) + if (isValidRegDefOf(MO, Reg, TRI)) return &*Last; return Def < 0 ? 
nullptr : getInstFromId(MBB, Def); @@ -650,7 +646,7 @@ ReachingDefAnalysis::isSafeToRemove(MachineInstr *MI, InstSet &Visited, void ReachingDefAnalysis::collectKilledOperands(MachineInstr *MI, InstSet &Dead) const { Dead.insert(MI); - auto IsDead = [this, &Dead](MachineInstr *Def, MCRegister PhysReg) { + auto IsDead = [this, &Dead](MachineInstr *Def, MCRegister Reg) { if (mayHaveSideEffects(*Def)) return false; @@ -666,7 +662,7 @@ void ReachingDefAnalysis::collectKilledOperands(MachineInstr *MI, return false; SmallPtrSet<MachineInstr*, 4> Uses; - getGlobalUses(Def, PhysReg, Uses); + getGlobalUses(Def, Reg, Uses); return llvm::set_is_subset(Uses, Dead); }; @@ -680,18 +676,18 @@ void ReachingDefAnalysis::collectKilledOperands(MachineInstr *MI, } bool ReachingDefAnalysis::isSafeToDefRegAt(MachineInstr *MI, - MCRegister PhysReg) const { + MCRegister Reg) const { SmallPtrSet<MachineInstr*, 1> Ignore; - return isSafeToDefRegAt(MI, PhysReg, Ignore); + return isSafeToDefRegAt(MI, Reg, Ignore); } -bool ReachingDefAnalysis::isSafeToDefRegAt(MachineInstr *MI, MCRegister PhysReg, +bool ReachingDefAnalysis::isSafeToDefRegAt(MachineInstr *MI, MCRegister Reg, InstSet &Ignore) const { // Check for any uses of the register after MI. - if (isRegUsedAfter(MI, PhysReg)) { - if (auto *Def = getReachingLocalMIDef(MI, PhysReg)) { + if (isRegUsedAfter(MI, Reg)) { + if (auto *Def = getReachingLocalMIDef(MI, Reg)) { SmallPtrSet<MachineInstr*, 2> Uses; - getGlobalUses(Def, PhysReg, Uses); + getGlobalUses(Def, Reg, Uses); if (!llvm::set_is_subset(Uses, Ignore)) return false; } else @@ -700,13 +696,13 @@ bool ReachingDefAnalysis::isSafeToDefRegAt(MachineInstr *MI, MCRegister PhysReg, MachineBasicBlock *MBB = MI->getParent(); // Check for any defs after MI. - if (isRegDefinedAfter(MI, PhysReg)) { + if (isRegDefinedAfter(MI, Reg)) { auto I = MachineBasicBlock::iterator(MI); for (auto E = MBB->end(); I != E; ++I) { if (Ignore.count(&*I)) continue; for (auto &MO : I->operands()) - if (isValidRegDefOf(MO, PhysReg, TRI)) + if (isValidRegDefOf(MO, Reg, TRI)) return false; } } diff --git a/llvm/lib/CodeGen/RegAllocGreedy.cpp b/llvm/lib/CodeGen/RegAllocGreedy.cpp index 4fa2bc7..b94992c 100644 --- a/llvm/lib/CodeGen/RegAllocGreedy.cpp +++ b/llvm/lib/CodeGen/RegAllocGreedy.cpp @@ -140,7 +140,7 @@ static cl::opt<bool> GreedyReverseLocalAssignment( static cl::opt<unsigned> SplitThresholdForRegWithHint( "split-threshold-for-reg-with-hint", cl::desc("The threshold for splitting a virtual register with a hint, in " - "percentate"), + "percentage"), cl::init(75), cl::Hidden); static RegisterRegAlloc greedyRegAlloc("greedy", "greedy register allocator", @@ -376,6 +376,12 @@ unsigned DefaultPriorityAdvisor::getPriority(const LiveInterval &LI) const { return Prio; } +unsigned DummyPriorityAdvisor::getPriority(const LiveInterval &LI) const { + // Prioritize by virtual register number, lowest first. + Register Reg = LI.reg(); + return ~Reg.virtRegIndex(); +} + const LiveInterval *RAGreedy::dequeue() { return dequeue(Queue); } const LiveInterval *RAGreedy::dequeue(PQueue &CurQueue) { @@ -2029,6 +2035,9 @@ unsigned RAGreedy::tryLastChanceRecoloring(const LiveInterval &VirtReg, // available colors. Matrix->assign(VirtReg, PhysReg); + // VirtReg may be deleted during tryRecoloringCandidates, save a copy. + Register ThisVirtReg = VirtReg.reg(); + // Save the current recoloring state. // If we cannot recolor all the interferences, we will have to start again // at this point for the next physical register. 
@@ -2040,8 +2049,16 @@ unsigned RAGreedy::tryLastChanceRecoloring(const LiveInterval &VirtReg, NewVRegs.push_back(NewVReg); // Do not mess up with the global assignment process. // I.e., VirtReg must be unassigned. - Matrix->unassign(VirtReg); - return PhysReg; + if (VRM->hasPhys(ThisVirtReg)) { + Matrix->unassign(VirtReg); + return PhysReg; + } + + // It is possible VirtReg will be deleted during tryRecoloringCandidates. + LLVM_DEBUG(dbgs() << "tryRecoloringCandidates deleted a fixed register " + << printReg(ThisVirtReg) << '\n'); + FixedRegisters.erase(ThisVirtReg); + return 0; } LLVM_DEBUG(dbgs() << "Fail to assign: " << VirtReg << " to " diff --git a/llvm/lib/CodeGen/RegAllocPriorityAdvisor.cpp b/llvm/lib/CodeGen/RegAllocPriorityAdvisor.cpp index 0650aaf..4525b8f 100644 --- a/llvm/lib/CodeGen/RegAllocPriorityAdvisor.cpp +++ b/llvm/lib/CodeGen/RegAllocPriorityAdvisor.cpp @@ -30,7 +30,10 @@ static cl::opt<RegAllocPriorityAdvisorAnalysis::AdvisorMode> Mode( clEnumValN(RegAllocPriorityAdvisorAnalysis::AdvisorMode::Release, "release", "precompiled"), clEnumValN(RegAllocPriorityAdvisorAnalysis::AdvisorMode::Development, - "development", "for training"))); + "development", "for training"), + clEnumValN( + RegAllocPriorityAdvisorAnalysis::AdvisorMode::Dummy, "dummy", + "prioritize low virtual register numbers for test and debug"))); char RegAllocPriorityAdvisorAnalysis::ID = 0; INITIALIZE_PASS(RegAllocPriorityAdvisorAnalysis, "regalloc-priority", @@ -67,6 +70,31 @@ private: } const bool NotAsRequested; }; + +class DummyPriorityAdvisorAnalysis final + : public RegAllocPriorityAdvisorAnalysis { +public: + DummyPriorityAdvisorAnalysis() + : RegAllocPriorityAdvisorAnalysis(AdvisorMode::Dummy) {} + + // support for isa<> and dyn_cast. + static bool classof(const RegAllocPriorityAdvisorAnalysis *R) { + return R->getAdvisorMode() == AdvisorMode::Dummy; + } + +private: + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<SlotIndexesWrapperPass>(); + RegAllocPriorityAdvisorAnalysis::getAnalysisUsage(AU); + } + + std::unique_ptr<RegAllocPriorityAdvisor> + getAdvisor(const MachineFunction &MF, const RAGreedy &RA) override { + return std::make_unique<DummyPriorityAdvisor>( + MF, RA, &getAnalysis<SlotIndexesWrapperPass>().getSI()); + } +}; + } // namespace template <> Pass *llvm::callDefaultCtor<RegAllocPriorityAdvisorAnalysis>() { @@ -75,6 +103,9 @@ template <> Pass *llvm::callDefaultCtor<RegAllocPriorityAdvisorAnalysis>() { case RegAllocPriorityAdvisorAnalysis::AdvisorMode::Default: Ret = new DefaultPriorityAdvisorAnalysis(/*NotAsRequested*/ false); break; + case RegAllocPriorityAdvisorAnalysis::AdvisorMode::Dummy: + Ret = new DummyPriorityAdvisorAnalysis(); + break; case RegAllocPriorityAdvisorAnalysis::AdvisorMode::Development: #if defined(LLVM_HAVE_TFLITE) Ret = createDevelopmentModePriorityAdvisor(); @@ -97,6 +128,8 @@ StringRef RegAllocPriorityAdvisorAnalysis::getPassName() const { return "Release mode Regalloc Priority Advisor"; case AdvisorMode::Development: return "Development mode Regalloc Priority Advisor"; + case AdvisorMode::Dummy: + return "Dummy Regalloc Priority Advisor"; } llvm_unreachable("Unknown advisor kind"); } diff --git a/llvm/lib/CodeGen/RegAllocPriorityAdvisor.h b/llvm/lib/CodeGen/RegAllocPriorityAdvisor.h index 1e9fa96..32e4598 100644 --- a/llvm/lib/CodeGen/RegAllocPriorityAdvisor.h +++ b/llvm/lib/CodeGen/RegAllocPriorityAdvisor.h @@ -56,9 +56,21 @@ private: unsigned getPriority(const LiveInterval &LI) const override; }; +/// Stupid priority advisor which 
just enqueues in virtual register number +/// order, for debug purposes only. +class DummyPriorityAdvisor : public RegAllocPriorityAdvisor { +public: + DummyPriorityAdvisor(const MachineFunction &MF, const RAGreedy &RA, + SlotIndexes *const Indexes) + : RegAllocPriorityAdvisor(MF, RA, Indexes) {} + +private: + unsigned getPriority(const LiveInterval &LI) const override; +}; + class RegAllocPriorityAdvisorAnalysis : public ImmutablePass { public: - enum class AdvisorMode : int { Default, Release, Development }; + enum class AdvisorMode : int { Default, Release, Development, Dummy }; RegAllocPriorityAdvisorAnalysis(AdvisorMode Mode) : ImmutablePass(ID), Mode(Mode){}; diff --git a/llvm/lib/CodeGen/RegisterCoalescer.cpp b/llvm/lib/CodeGen/RegisterCoalescer.cpp index 20ad644..8313927 100644 --- a/llvm/lib/CodeGen/RegisterCoalescer.cpp +++ b/llvm/lib/CodeGen/RegisterCoalescer.cpp @@ -113,7 +113,7 @@ static cl::opt<unsigned> LargeIntervalSizeThreshold( static cl::opt<unsigned> LargeIntervalFreqThreshold( "large-interval-freq-threshold", cl::Hidden, - cl::desc("For a large interval, if it is coalesed with other live " + cl::desc("For a large interval, if it is coalesced with other live " "intervals many times more than the threshold, stop its " "coalescing to control the compile time. "), cl::init(256)); @@ -1325,11 +1325,6 @@ bool RegisterCoalescer::reMaterializeTrivialDef(const CoalescerPair &CP, const MCInstrDesc &MCID = DefMI->getDesc(); if (MCID.getNumDefs() != 1) return false; - // Only support subregister destinations when the def is read-undef. - MachineOperand &DstOperand = CopyMI->getOperand(0); - Register CopyDstReg = DstOperand.getReg(); - if (DstOperand.getSubReg() && !DstOperand.isUndef()) - return false; // If both SrcIdx and DstIdx are set, correct rematerialization would widen // the register substantially (beyond both source and dest size). This is bad @@ -1339,6 +1334,32 @@ bool RegisterCoalescer::reMaterializeTrivialDef(const CoalescerPair &CP, if (SrcIdx && DstIdx) return false; + // Only support subregister destinations when the def is read-undef. + MachineOperand &DstOperand = CopyMI->getOperand(0); + Register CopyDstReg = DstOperand.getReg(); + if (DstOperand.getSubReg() && !DstOperand.isUndef()) + return false; + + // In the physical register case, checking that the def is read-undef is not + // enough. We're widening the def and need to avoid clobbering other live + // values in the unused register pieces. + // + // TODO: Targets may support rewriting the rematerialized instruction to only + // touch relevant lanes, in which case we don't need any liveness check. + if (CopyDstReg.isPhysical() && CP.isPartial()) { + for (MCRegUnit Unit : TRI->regunits(DstReg)) { + // Ignore the register units we are writing anyway. + if (is_contained(TRI->regunits(CopyDstReg), Unit)) + continue; + + // Check if the other lanes we are defining are live at the + // rematerialization point. + LiveRange &LR = LIS->getRegUnit(Unit); + if (LR.liveAt(CopyIdx)) + return false; + } + } + const unsigned DefSubIdx = DefMI->getOperand(0).getSubReg(); const TargetRegisterClass *DefRC = TII->getRegClass(MCID, 0, TRI, *MF); if (!DefMI->isImplicitDef()) { @@ -1375,27 +1396,6 @@ bool RegisterCoalescer::reMaterializeTrivialDef(const CoalescerPair &CP, NewMI.setDebugLoc(DL); // In a situation like the following: - // - // undef %2.subreg:reg = INST %1:reg ; DefMI (rematerializable), - // ; DefSubIdx = subreg - // %3:reg = COPY %2 ; SrcIdx = DstIdx = 0 - // .... 
= SOMEINSTR %3:reg - // - // there are no subranges for %3 so after rematerialization we need - // to explicitly create them. Undefined subranges are removed later on. - if (DstReg.isVirtual() && DefSubIdx && !CP.getSrcIdx() && !CP.getDstIdx() && - MRI->shouldTrackSubRegLiveness(DstReg)) { - LiveInterval &DstInt = LIS->getInterval(DstReg); - if (!DstInt.hasSubRanges()) { - LaneBitmask FullMask = MRI->getMaxLaneMaskForVReg(DstReg); - LaneBitmask UsedLanes = TRI->getSubRegIndexLaneMask(DefSubIdx); - LaneBitmask UnusedLanes = FullMask & ~UsedLanes; - DstInt.createSubRangeFrom(LIS->getVNInfoAllocator(), UsedLanes, DstInt); - DstInt.createSubRangeFrom(LIS->getVNInfoAllocator(), UnusedLanes, DstInt); - } - } - - // In a situation like the following: // %0:subreg = instr ; DefMI, subreg = DstIdx // %1 = copy %0:subreg ; CopyMI, SrcIdx = 0 // instead of widening %1 to the register class of %0 simply do: @@ -1523,6 +1523,27 @@ bool RegisterCoalescer::reMaterializeTrivialDef(const CoalescerPair &CP, // sure that "undef" is not set. if (NewIdx == 0) NewMI.getOperand(0).setIsUndef(false); + + // In a situation like the following: + // + // undef %2.subreg:reg = INST %1:reg ; DefMI (rematerializable), + // ; Defines only some of lanes, + // ; so DefSubIdx = NewIdx = subreg + // %3:reg = COPY %2 ; Copy full reg + // .... = SOMEINSTR %3:reg ; Use full reg + // + // there are no subranges for %3 so after rematerialization we need + // to explicitly create them. Undefined subranges are removed later on. + if (NewIdx && !DstInt.hasSubRanges() && + MRI->shouldTrackSubRegLiveness(DstReg)) { + LaneBitmask FullMask = MRI->getMaxLaneMaskForVReg(DstReg); + LaneBitmask UsedLanes = TRI->getSubRegIndexLaneMask(NewIdx); + LaneBitmask UnusedLanes = FullMask & ~UsedLanes; + VNInfo::Allocator &Alloc = LIS->getVNInfoAllocator(); + DstInt.createSubRangeFrom(Alloc, UsedLanes, DstInt); + DstInt.createSubRangeFrom(Alloc, UnusedLanes, DstInt); + } + // Add dead subregister definitions if we are defining the whole register // but only part of it is live. // This could happen if the rematerialization instruction is rematerializing diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 6cbfef2..da3c834 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -141,7 +141,7 @@ static cl::opt<bool> EnableReduceLoadOpStoreWidth( static cl::opt<bool> ReduceLoadOpStoreWidthForceNarrowingProfitable( "combiner-reduce-load-op-store-width-force-narrowing-profitable", cl::Hidden, cl::init(false), - cl::desc("DAG combiner force override the narrowing profitable check when" + cl::desc("DAG combiner force override the narrowing profitable check when " "reducing the width of load/op/store sequences")); static cl::opt<bool> EnableShrinkLoadReplaceStoreWithStore( @@ -3949,6 +3949,23 @@ SDValue DAGCombiner::visitSUB(SDNode *N) { if (SDValue Result = TLI.expandABS(N1.getNode(), DAG, true)) return Result; + // Similar to the previous rule, but this time targeting an expanded abs. + // (sub 0, (max X, (sub 0, X))) --> (min X, (sub 0, X)) + // as well as + // (sub 0, (min X, (sub 0, X))) --> (max X, (sub 0, X)) + // Note that these two are applicable to both signed and unsigned min/max. 
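// Illustrative sketch, not part of the patch: one concrete instance of the
// fold described above, assuming i8 values.
//   X = 7, (sub 0, X) = -7 signed, 249 unsigned (modulo 256)
//   sub 0, (smax 7, -7)  = sub 0, 7   = -7 = smin(7, -7)
//   sub 0, (umax 7, 249) = sub 0, 249 =  7 = umin(7, 249)
// Negating an expanded abs therefore equals the same operands combined with
// the inverse min/max opcode, which getInverseMinMaxOpcode supplies below.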
+ SDValue X; + SDValue S0; + auto NegPat = m_AllOf(m_Neg(m_Deferred(X)), m_Value(S0)); + if (sd_match(N1, m_OneUse(m_AnyOf(m_SMax(m_Value(X), NegPat), + m_UMax(m_Value(X), NegPat), + m_SMin(m_Value(X), NegPat), + m_UMin(m_Value(X), NegPat))))) { + unsigned NewOpc = ISD::getInverseMinMaxOpcode(N1->getOpcode()); + if (hasOperation(NewOpc, VT)) + return DAG.getNode(NewOpc, DL, VT, X, S0); + } + // Fold neg(splat(neg(x)) -> splat(x) if (VT.isVector()) { SDValue N1S = DAG.getSplatValue(N1, true); @@ -20438,10 +20455,8 @@ SDValue DAGCombiner::TransformFPLoadStorePair(SDNode *N) { Value.hasOneUse()) { LoadSDNode *LD = cast<LoadSDNode>(Value); EVT VT = LD->getMemoryVT(); - if (!VT.isFloatingPoint() || - VT != ST->getMemoryVT() || - LD->isNonTemporal() || - ST->isNonTemporal() || + if (!VT.isSimple() || !VT.isFloatingPoint() || VT != ST->getMemoryVT() || + LD->isNonTemporal() || ST->isNonTemporal() || LD->getPointerInfo().getAddrSpace() != 0 || ST->getPointerInfo().getAddrSpace() != 0) return SDValue(); @@ -23088,8 +23103,11 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) { if (ExtractIndex == BCTruncElt && BCSrc.getValueType().isScalarInteger()) return DAG.getAnyExtOrTrunc(BCSrc, DL, ScalarVT); + // TODO: Add support for SCALAR_TO_VECTOR implicit truncation. if (LegalTypes && BCSrc.getValueType().isInteger() && - BCSrc.getOpcode() == ISD::SCALAR_TO_VECTOR) { + BCSrc.getOpcode() == ISD::SCALAR_TO_VECTOR && + BCSrc.getScalarValueSizeInBits() == + BCSrc.getOperand(0).getScalarValueSizeInBits()) { // ext_elt (bitcast (scalar_to_vec i64 X to v2i64) to v4i32), TruncElt --> // trunc i64 X to i32 SDValue X = BCSrc.getOperand(0); @@ -24288,8 +24306,8 @@ static SDValue combineConcatVectorOfScalars(SDNode *N, SelectionDAG &DAG) { EVT SVT = EVT::getIntegerVT(*DAG.getContext(), OpVT.getSizeInBits()); // Keep track of what we encounter. - bool AnyInteger = false; - bool AnyFP = false; + EVT AnyFPVT; + for (const SDValue &Op : N->ops()) { if (ISD::BITCAST == Op.getOpcode() && !Op.getOperand(0).getValueType().isVector()) @@ -24303,27 +24321,23 @@ static SDValue combineConcatVectorOfScalars(SDNode *N, SelectionDAG &DAG) { // If it's neither, bail out, it could be something weird like x86mmx. EVT LastOpVT = Ops.back().getValueType(); if (LastOpVT.isFloatingPoint()) - AnyFP = true; - else if (LastOpVT.isInteger()) - AnyInteger = true; - else + AnyFPVT = LastOpVT; + else if (!LastOpVT.isInteger()) return SDValue(); } // If any of the operands is a floating point scalar bitcast to a vector, // use floating point types throughout, and bitcast everything. // Replace UNDEFs by another scalar UNDEF node, of the final desired type. 
- if (AnyFP) { - SVT = EVT::getFloatingPointVT(OpVT.getSizeInBits()); - if (AnyInteger) { - for (SDValue &Op : Ops) { - if (Op.getValueType() == SVT) - continue; - if (Op.isUndef()) - Op = DAG.getNode(ISD::UNDEF, DL, SVT); - else - Op = DAG.getBitcast(SVT, Op); - } + if (AnyFPVT != EVT()) { + SVT = AnyFPVT; + for (SDValue &Op : Ops) { + if (Op.getValueType() == SVT) + continue; + if (Op.isUndef()) + Op = DAG.getNode(ISD::UNDEF, DL, SVT); + else + Op = DAG.getBitcast(SVT, Op); } } diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp index db21e70..89a00c5 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp @@ -402,6 +402,8 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) { case ISD::FMAXNUM_IEEE: case ISD::FMINIMUM: case ISD::FMAXIMUM: + case ISD::FMINIMUMNUM: + case ISD::FMAXIMUMNUM: case ISD::FCOPYSIGN: case ISD::FSQRT: case ISD::FSIN: @@ -1081,6 +1083,10 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) { case ISD::FMAXIMUM: Results.push_back(TLI.expandFMINIMUM_FMAXIMUM(Node, DAG)); return; + case ISD::FMINIMUMNUM: + case ISD::FMAXIMUMNUM: + Results.push_back(TLI.expandFMINIMUMNUM_FMAXIMUMNUM(Node, DAG)); + return; case ISD::SMIN: case ISD::SMAX: case ISD::UMIN: @@ -1738,7 +1744,8 @@ void VectorLegalizer::ExpandUINT_TO_FLOAT(SDNode *Node, bool IsStrict = Node->isStrictFPOpcode(); unsigned OpNo = IsStrict ? 1 : 0; SDValue Src = Node->getOperand(OpNo); - EVT VT = Src.getValueType(); + EVT SrcVT = Src.getValueType(); + EVT DstVT = Node->getValueType(0); SDLoc DL(Node); // Attempt to expand using TargetLowering. @@ -1752,11 +1759,11 @@ void VectorLegalizer::ExpandUINT_TO_FLOAT(SDNode *Node, } // Make sure that the SINT_TO_FP and SRL instructions are available. - if (((!IsStrict && TLI.getOperationAction(ISD::SINT_TO_FP, VT) == + if (((!IsStrict && TLI.getOperationAction(ISD::SINT_TO_FP, SrcVT) == TargetLowering::Expand) || - (IsStrict && TLI.getOperationAction(ISD::STRICT_SINT_TO_FP, VT) == + (IsStrict && TLI.getOperationAction(ISD::STRICT_SINT_TO_FP, SrcVT) == TargetLowering::Expand)) || - TLI.getOperationAction(ISD::SRL, VT) == TargetLowering::Expand) { + TLI.getOperationAction(ISD::SRL, SrcVT) == TargetLowering::Expand) { if (IsStrict) { UnrollStrictFPOp(Node, Results); return; @@ -1766,37 +1773,59 @@ void VectorLegalizer::ExpandUINT_TO_FLOAT(SDNode *Node, return; } - unsigned BW = VT.getScalarSizeInBits(); + unsigned BW = SrcVT.getScalarSizeInBits(); assert((BW == 64 || BW == 32) && "Elements in vector-UINT_TO_FP must be 32 or 64 bits wide"); - SDValue HalfWord = DAG.getConstant(BW / 2, DL, VT); + // If STRICT_/FMUL is not supported by the target (in case of f16) replace the + // UINT_TO_FP with a larger float and round to the smaller type + if ((!IsStrict && !TLI.isOperationLegalOrCustom(ISD::FMUL, DstVT)) || + (IsStrict && !TLI.isOperationLegalOrCustom(ISD::STRICT_FMUL, DstVT))) { + EVT FPVT = BW == 32 ? 
MVT::f32 : MVT::f64; + SDValue UIToFP; + SDValue Result; + SDValue TargetZero = DAG.getIntPtrConstant(0, DL, /*isTarget=*/true); + EVT FloatVecVT = SrcVT.changeVectorElementType(FPVT); + if (IsStrict) { + UIToFP = DAG.getNode(ISD::STRICT_UINT_TO_FP, DL, {FloatVecVT, MVT::Other}, + {Node->getOperand(0), Src}); + Result = DAG.getNode(ISD::STRICT_FP_ROUND, DL, {DstVT, MVT::Other}, + {Node->getOperand(0), UIToFP, TargetZero}); + Results.push_back(Result); + Results.push_back(Result.getValue(1)); + } else { + UIToFP = DAG.getNode(ISD::UINT_TO_FP, DL, FloatVecVT, Src); + Result = DAG.getNode(ISD::FP_ROUND, DL, DstVT, UIToFP, TargetZero); + Results.push_back(Result); + } + + return; + } + + SDValue HalfWord = DAG.getConstant(BW / 2, DL, SrcVT); // Constants to clear the upper part of the word. // Notice that we can also use SHL+SHR, but using a constant is slightly // faster on x86. uint64_t HWMask = (BW == 64) ? 0x00000000FFFFFFFF : 0x0000FFFF; - SDValue HalfWordMask = DAG.getConstant(HWMask, DL, VT); + SDValue HalfWordMask = DAG.getConstant(HWMask, DL, SrcVT); // Two to the power of half-word-size. - SDValue TWOHW = - DAG.getConstantFP(1ULL << (BW / 2), DL, Node->getValueType(0)); + SDValue TWOHW = DAG.getConstantFP(1ULL << (BW / 2), DL, DstVT); // Clear upper part of LO, lower HI - SDValue HI = DAG.getNode(ISD::SRL, DL, VT, Src, HalfWord); - SDValue LO = DAG.getNode(ISD::AND, DL, VT, Src, HalfWordMask); + SDValue HI = DAG.getNode(ISD::SRL, DL, SrcVT, Src, HalfWord); + SDValue LO = DAG.getNode(ISD::AND, DL, SrcVT, Src, HalfWordMask); if (IsStrict) { // Convert hi and lo to floats // Convert the hi part back to the upper values // TODO: Can any fast-math-flags be set on these nodes? - SDValue fHI = DAG.getNode(ISD::STRICT_SINT_TO_FP, DL, - {Node->getValueType(0), MVT::Other}, + SDValue fHI = DAG.getNode(ISD::STRICT_SINT_TO_FP, DL, {DstVT, MVT::Other}, {Node->getOperand(0), HI}); - fHI = DAG.getNode(ISD::STRICT_FMUL, DL, {Node->getValueType(0), MVT::Other}, + fHI = DAG.getNode(ISD::STRICT_FMUL, DL, {DstVT, MVT::Other}, {fHI.getValue(1), fHI, TWOHW}); - SDValue fLO = DAG.getNode(ISD::STRICT_SINT_TO_FP, DL, - {Node->getValueType(0), MVT::Other}, + SDValue fLO = DAG.getNode(ISD::STRICT_SINT_TO_FP, DL, {DstVT, MVT::Other}, {Node->getOperand(0), LO}); SDValue TF = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, fHI.getValue(1), @@ -1804,8 +1833,7 @@ void VectorLegalizer::ExpandUINT_TO_FLOAT(SDNode *Node, // Add the two halves SDValue Result = - DAG.getNode(ISD::STRICT_FADD, DL, {Node->getValueType(0), MVT::Other}, - {TF, fHI, fLO}); + DAG.getNode(ISD::STRICT_FADD, DL, {DstVT, MVT::Other}, {TF, fHI, fLO}); Results.push_back(Result); Results.push_back(Result.getValue(1)); @@ -1815,13 +1843,12 @@ void VectorLegalizer::ExpandUINT_TO_FLOAT(SDNode *Node, // Convert hi and lo to floats // Convert the hi part back to the upper values // TODO: Can any fast-math-flags be set on these nodes? 
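// Illustrative sketch, not part of the patch: the half-word split on one
// assumed 64-bit unsigned input u = 2^63 + 5.
//   HI = u >> 32        = 0x80000000 (fits in 32 bits, so sint_to_fp is exact)
//   LO = u & 0xFFFFFFFF = 5
//   result = (double)HI * 2^32 + (double)LO
// which reassembles 2^63 + 5 (up to normal double rounding) without ever
// treating u itself as a signed 64-bit quantity.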
- SDValue fHI = DAG.getNode(ISD::SINT_TO_FP, DL, Node->getValueType(0), HI); - fHI = DAG.getNode(ISD::FMUL, DL, Node->getValueType(0), fHI, TWOHW); - SDValue fLO = DAG.getNode(ISD::SINT_TO_FP, DL, Node->getValueType(0), LO); + SDValue fHI = DAG.getNode(ISD::SINT_TO_FP, DL, DstVT, HI); + fHI = DAG.getNode(ISD::FMUL, DL, DstVT, fHI, TWOHW); + SDValue fLO = DAG.getNode(ISD::SINT_TO_FP, DL, DstVT, LO); // Add the two halves - Results.push_back( - DAG.getNode(ISD::FADD, DL, Node->getValueType(0), fHI, fLO)); + Results.push_back(DAG.getNode(ISD::FADD, DL, DstVT, fHI, fLO)); } SDValue VectorLegalizer::ExpandFNEG(SDNode *Node) { @@ -2246,11 +2273,13 @@ SDValue VectorLegalizer::UnrollVSETCC(SDNode *Node) { DAG.getVectorIdxConstant(i, dl)); SDValue RHSElem = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, TmpEltVT, RHS, DAG.getVectorIdxConstant(i, dl)); + // FIXME: We should use i1 setcc + boolext here, but it causes regressions. Ops[i] = DAG.getNode(ISD::SETCC, dl, TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), TmpEltVT), LHSElem, RHSElem, CC); - Ops[i] = DAG.getSelect(dl, EltVT, Ops[i], DAG.getAllOnesConstant(dl, EltVT), + Ops[i] = DAG.getSelect(dl, EltVT, Ops[i], + DAG.getBoolConstant(true, dl, EltVT, VT), DAG.getConstant(0, dl, EltVT)); } return DAG.getBuildVector(VT, dl, Ops); diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp index 107454a..780eba1 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp @@ -149,6 +149,8 @@ void DAGTypeLegalizer::ScalarizeVectorResult(SDNode *N, unsigned ResNo) { case ISD::FMAXNUM_IEEE: case ISD::FMINIMUM: case ISD::FMAXIMUM: + case ISD::FMINIMUMNUM: + case ISD::FMAXIMUMNUM: case ISD::FLDEXP: case ISD::ABDS: case ISD::ABDU: diff --git a/llvm/lib/CodeGen/SelectionDAG/ScheduleDAGRRList.cpp b/llvm/lib/CodeGen/SelectionDAG/ScheduleDAGRRList.cpp index 9e5867c..51ee3cc 100644 --- a/llvm/lib/CodeGen/SelectionDAG/ScheduleDAGRRList.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/ScheduleDAGRRList.cpp @@ -125,9 +125,9 @@ static cl::opt<int> MaxReorderWindow( cl::desc("Number of instructions to allow ahead of the critical path " "in sched=list-ilp")); -static cl::opt<unsigned> AvgIPC( - "sched-avg-ipc", cl::Hidden, cl::init(1), - cl::desc("Average inst/cycle whan no target itinerary exists.")); +static cl::opt<unsigned> + AvgIPC("sched-avg-ipc", cl::Hidden, cl::init(1), + cl::desc("Average inst/cycle when no target itinerary exists.")); namespace { diff --git a/llvm/lib/CodeGen/SelectionDAG/ScheduleDAGSDNodes.cpp b/llvm/lib/CodeGen/SelectionDAG/ScheduleDAGSDNodes.cpp index 26fc75c..dff7243 100644 --- a/llvm/lib/CodeGen/SelectionDAG/ScheduleDAGSDNodes.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/ScheduleDAGSDNodes.cpp @@ -43,9 +43,9 @@ STATISTIC(LoadsClustered, "Number of loads clustered together"); // without a target itinerary. The choice of number here has more to do with // balancing scheduler heuristics than with the actual machine latency. 
static cl::opt<int> HighLatencyCycles( - "sched-high-latency-cycles", cl::Hidden, cl::init(10), - cl::desc("Roughly estimate the number of cycles that 'long latency'" - "instructions take for targets with no itinerary")); + "sched-high-latency-cycles", cl::Hidden, cl::init(10), + cl::desc("Roughly estimate the number of cycles that 'long latency' " + "instructions take for targets with no itinerary")); ScheduleDAGSDNodes::ScheduleDAGSDNodes(MachineFunction &mf) : ScheduleDAG(mf), InstrItins(mf.getSubtarget().getInstrItineraryData()) {} diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index 10e8ba9..0dfd030 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -430,6 +430,21 @@ bool ISD::matchBinaryPredicate( return true; } +ISD::NodeType ISD::getInverseMinMaxOpcode(unsigned MinMaxOpc) { + switch (MinMaxOpc) { + default: + llvm_unreachable("unrecognized opcode"); + case ISD::UMIN: + return ISD::UMAX; + case ISD::UMAX: + return ISD::UMIN; + case ISD::SMIN: + return ISD::SMAX; + case ISD::SMAX: + return ISD::SMIN; + } +} + ISD::NodeType ISD::getVecReduceBaseOpcode(unsigned VecReduceOpcode) { switch (VecReduceOpcode) { default: diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp index e87d809..9f57884 100644 --- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -8435,7 +8435,6 @@ bool TargetLowering::expandUINT_TO_FP(SDNode *Node, SDValue &Result, return false; SDLoc dl(SDValue(Node, 0)); - EVT ShiftVT = getShiftAmountTy(SrcVT, DAG.getDataLayout()); // Implementation of unsigned i64 to f64 following the algorithm in // __floatundidf in compiler_rt. This implementation performs rounding @@ -8448,7 +8447,7 @@ bool TargetLowering::expandUINT_TO_FP(SDNode *Node, SDValue &Result, llvm::bit_cast<double>(UINT64_C(0x4530000000100000)), dl, DstVT); SDValue TwoP84 = DAG.getConstant(UINT64_C(0x4530000000000000), dl, SrcVT); SDValue LoMask = DAG.getConstant(UINT64_C(0x00000000FFFFFFFF), dl, SrcVT); - SDValue HiShift = DAG.getConstant(32, dl, ShiftVT); + SDValue HiShift = DAG.getShiftAmountConstant(32, SrcVT, dl); SDValue Lo = DAG.getNode(ISD::AND, dl, SrcVT, Src, LoMask); SDValue Hi = DAG.getNode(ISD::SRL, dl, SrcVT, Src, HiShift); diff --git a/llvm/lib/CodeGen/StackMapLivenessAnalysis.cpp b/llvm/lib/CodeGen/StackMapLivenessAnalysis.cpp index 687acd9..8437422 100644 --- a/llvm/lib/CodeGen/StackMapLivenessAnalysis.cpp +++ b/llvm/lib/CodeGen/StackMapLivenessAnalysis.cpp @@ -106,8 +106,6 @@ bool StackMapLiveness::runOnMachineFunction(MachineFunction &MF) { if (!EnablePatchPointLiveness) return false; - LLVM_DEBUG(dbgs() << "********** COMPUTING STACKMAP LIVENESS: " - << MF.getName() << " **********\n"); TRI = MF.getSubtarget().getRegisterInfo(); ++NumStackMapFuncVisited; @@ -121,6 +119,8 @@ bool StackMapLiveness::runOnMachineFunction(MachineFunction &MF) { /// Performs the actual liveness calculation for the function. bool StackMapLiveness::calculateLiveness(MachineFunction &MF) { + LLVM_DEBUG(dbgs() << "********** COMPUTING STACKMAP LIVENESS: " + << MF.getName() << " **********\n"); bool HasChanged = false; // For all basic blocks in the function. 
for (auto &MBB : MF) { diff --git a/llvm/lib/CodeGen/SwiftErrorValueTracking.cpp b/llvm/lib/CodeGen/SwiftErrorValueTracking.cpp index 74a94d6..decffdc 100644 --- a/llvm/lib/CodeGen/SwiftErrorValueTracking.cpp +++ b/llvm/lib/CodeGen/SwiftErrorValueTracking.cpp @@ -259,7 +259,7 @@ void SwiftErrorValueTracking::propagateVRegs() { for (const auto &Use : VRegUpwardsUse) { const MachineBasicBlock *UseBB = Use.first.first; Register VReg = Use.second; - if (!MRI.def_begin(VReg).atEnd()) + if (!MRI.def_empty(VReg)) continue; #ifdef EXPENSIVE_CHECKS diff --git a/llvm/lib/CodeGen/TargetPassConfig.cpp b/llvm/lib/CodeGen/TargetPassConfig.cpp index d407e9f..5c05589 100644 --- a/llvm/lib/CodeGen/TargetPassConfig.cpp +++ b/llvm/lib/CodeGen/TargetPassConfig.cpp @@ -113,8 +113,6 @@ static cl::opt<bool> EnableImplicitNullChecks( static cl::opt<bool> DisableMergeICmps("disable-mergeicmps", cl::desc("Disable MergeICmps Pass"), cl::init(false), cl::Hidden); -static cl::opt<bool> PrintLSR("print-lsr-output", cl::Hidden, - cl::desc("Print LLVM IR produced by the loop-reduce pass")); static cl::opt<bool> PrintISelInput("print-isel-input", cl::Hidden, cl::desc("Print LLVM IR input to isel pass")); @@ -503,7 +501,6 @@ CGPassBuilderOption llvm::getCGPassBuilderOption() { SET_BOOLEAN_OPTION(DisableCGP) SET_BOOLEAN_OPTION(DisablePartialLibcallInlining) SET_BOOLEAN_OPTION(DisableSelectOptimize) - SET_BOOLEAN_OPTION(PrintLSR) SET_BOOLEAN_OPTION(PrintISelInput) SET_BOOLEAN_OPTION(DebugifyAndStripAll) SET_BOOLEAN_OPTION(DebugifyCheckAndStripAll) @@ -836,9 +833,6 @@ void TargetPassConfig::addIRPasses() { addPass(createLoopStrengthReducePass()); if (EnableLoopTermFold) addPass(createLoopTermFoldPass()); - if (PrintLSR) - addPass(createPrintFunctionPass(dbgs(), - "\n\n*** Code after LSR ***\n")); } // The MergeICmpsPass tries to create memcmp calls by grouping sequences of diff --git a/llvm/lib/DebugInfo/GSYM/CallSiteInfo.cpp b/llvm/lib/DebugInfo/GSYM/CallSiteInfo.cpp index 85b41e2..c112c0b 100644 --- a/llvm/lib/DebugInfo/GSYM/CallSiteInfo.cpp +++ b/llvm/lib/DebugInfo/GSYM/CallSiteInfo.cpp @@ -151,7 +151,7 @@ LLVM_YAML_IS_SEQUENCE_VECTOR(FunctionYAML) Error CallSiteInfoLoader::loadYAML(StringRef YAMLFile) { // Step 1: Read YAML file - auto BufferOrError = MemoryBuffer::getFile(YAMLFile); + auto BufferOrError = MemoryBuffer::getFile(YAMLFile, /*IsText=*/true); if (!BufferOrError) return errorCodeToError(BufferOrError.getError()); diff --git a/llvm/lib/DebugInfo/GSYM/FunctionInfo.cpp b/llvm/lib/DebugInfo/GSYM/FunctionInfo.cpp index dd754c70..785a8da 100644 --- a/llvm/lib/DebugInfo/GSYM/FunctionInfo.cpp +++ b/llvm/lib/DebugInfo/GSYM/FunctionInfo.cpp @@ -235,10 +235,10 @@ llvm::Expected<uint64_t> FunctionInfo::encode(FileWriter &Out, return FuncInfoOffset; } -llvm::Expected<LookupResult> FunctionInfo::lookup(DataExtractor &Data, - const GsymReader &GR, - uint64_t FuncAddr, - uint64_t Addr) { +llvm::Expected<LookupResult> +FunctionInfo::lookup(DataExtractor &Data, const GsymReader &GR, + uint64_t FuncAddr, uint64_t Addr, + std::optional<DataExtractor> *MergedFuncsData) { LookupResult LR; LR.LookupAddr = Addr; uint64_t Offset = 0; @@ -289,6 +289,12 @@ llvm::Expected<LookupResult> FunctionInfo::lookup(DataExtractor &Data, return ExpectedLE.takeError(); break; + case InfoType::MergedFunctionsInfo: + // Store the merged functions data for later parsing, if needed. 
+ if (MergedFuncsData) + *MergedFuncsData = InfoData; + break; + case InfoType::InlineInfo: // We will parse the inline info after our line table, but only if // we have a line entry. diff --git a/llvm/lib/DebugInfo/GSYM/GsymReader.cpp b/llvm/lib/DebugInfo/GSYM/GsymReader.cpp index fa5476d..0a5bb7c 100644 --- a/llvm/lib/DebugInfo/GSYM/GsymReader.cpp +++ b/llvm/lib/DebugInfo/GSYM/GsymReader.cpp @@ -334,14 +334,52 @@ GsymReader::getFunctionInfoAtIndex(uint64_t Idx) const { return ExpectedData.takeError(); } -llvm::Expected<LookupResult> GsymReader::lookup(uint64_t Addr) const { +llvm::Expected<LookupResult> +GsymReader::lookup(uint64_t Addr, + std::optional<DataExtractor> *MergedFunctionsData) const { uint64_t FuncStartAddr = 0; if (auto ExpectedData = getFunctionInfoDataForAddress(Addr, FuncStartAddr)) - return FunctionInfo::lookup(*ExpectedData, *this, FuncStartAddr, Addr); + return FunctionInfo::lookup(*ExpectedData, *this, FuncStartAddr, Addr, + MergedFunctionsData); else return ExpectedData.takeError(); } +llvm::Expected<std::vector<LookupResult>> +GsymReader::lookupAll(uint64_t Addr) const { + std::vector<LookupResult> Results; + std::optional<DataExtractor> MergedFunctionsData; + + // First perform a lookup to get the primary function info result. + auto MainResult = lookup(Addr, &MergedFunctionsData); + if (!MainResult) + return MainResult.takeError(); + + // Add the main result as the first entry. + Results.push_back(std::move(*MainResult)); + + // Now process any merged functions data that was found during the lookup. + if (MergedFunctionsData) { + // Get data extractors for each merged function. + auto ExpectedMergedFuncExtractors = + MergedFunctionsInfo::getFuncsDataExtractors(*MergedFunctionsData); + if (!ExpectedMergedFuncExtractors) + return ExpectedMergedFuncExtractors.takeError(); + + // Process each merged function data. + for (DataExtractor &MergedData : *ExpectedMergedFuncExtractors) { + if (auto FI = FunctionInfo::lookup(MergedData, *this, + MainResult->FuncRange.start(), Addr)) { + Results.push_back(std::move(*FI)); + } else { + return FI.takeError(); + } + } + } + + return Results; +} + void GsymReader::dump(raw_ostream &OS) { const auto &Header = getHeader(); // Dump the GSYM header. diff --git a/llvm/lib/DebugInfo/GSYM/MergedFunctionsInfo.cpp b/llvm/lib/DebugInfo/GSYM/MergedFunctionsInfo.cpp index 4efae22..d2c28f3 100644 --- a/llvm/lib/DebugInfo/GSYM/MergedFunctionsInfo.cpp +++ b/llvm/lib/DebugInfo/GSYM/MergedFunctionsInfo.cpp @@ -35,22 +35,59 @@ llvm::Error MergedFunctionsInfo::encode(FileWriter &Out) const { llvm::Expected<MergedFunctionsInfo> MergedFunctionsInfo::decode(DataExtractor &Data, uint64_t BaseAddr) { MergedFunctionsInfo MFI; + auto FuncExtractorsOrError = MFI.getFuncsDataExtractors(Data); + + if (!FuncExtractorsOrError) + return FuncExtractorsOrError.takeError(); + + for (DataExtractor &FuncData : *FuncExtractorsOrError) { + llvm::Expected<FunctionInfo> FI = FunctionInfo::decode(FuncData, BaseAddr); + if (!FI) + return FI.takeError(); + MFI.MergedFunctions.push_back(std::move(*FI)); + } + + return MFI; +} + +llvm::Expected<std::vector<DataExtractor>> +MergedFunctionsInfo::getFuncsDataExtractors(DataExtractor &Data) { + std::vector<DataExtractor> Results; uint64_t Offset = 0; + + // Ensure there is enough data to read the function count. 
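// Byte layout implied by the reads below (an illustrative aside, not patch
// content):
//   u32 Count                     // number of merged FunctionInfo entries
//   repeated Count times:
//     u32 FnSize                  // size in bytes of the next entry
//     u8  Data[FnSize]            // one self-contained FunctionInfo encoding
// Each entry gets its own DataExtractor so callers can decode or look up the
// merged functions individually.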
+ if (!Data.isValidOffsetForDataOfSize(Offset, 4)) + return createStringError( + std::errc::io_error, + "unable to read the function count at offset 0x%8.8" PRIx64, Offset); + uint32_t Count = Data.getU32(&Offset); for (uint32_t i = 0; i < Count; ++i) { + // Ensure there is enough data to read the function size. + if (!Data.isValidOffsetForDataOfSize(Offset, 4)) + return createStringError( + std::errc::io_error, + "unable to read size of function %u at offset 0x%8.8" PRIx64, i, + Offset); + uint32_t FnSize = Data.getU32(&Offset); - DataExtractor FnData(Data.getData().substr(Offset, FnSize), + + // Ensure there is enough data for the function content. + if (!Data.isValidOffsetForDataOfSize(Offset, FnSize)) + return createStringError( + std::errc::io_error, + "function data is truncated for function %u at offset 0x%8.8" PRIx64 + ", expected size %u", + i, Offset, FnSize); + + // Extract the function data. + Results.emplace_back(Data.getData().substr(Offset, FnSize), Data.isLittleEndian(), Data.getAddressSize()); - llvm::Expected<FunctionInfo> FI = - FunctionInfo::decode(FnData, BaseAddr + Offset); - if (!FI) - return FI.takeError(); - MFI.MergedFunctions.push_back(std::move(*FI)); + Offset += FnSize; } - - return MFI; + return Results; } bool operator==(const MergedFunctionsInfo &LHS, diff --git a/llvm/lib/ExecutionEngine/JITLink/ELF_loongarch.cpp b/llvm/lib/ExecutionEngine/JITLink/ELF_loongarch.cpp index 56c32ae..a12e9f3 100644 --- a/llvm/lib/ExecutionEngine/JITLink/ELF_loongarch.cpp +++ b/llvm/lib/ExecutionEngine/JITLink/ELF_loongarch.cpp @@ -58,6 +58,10 @@ private: return Pointer32; case ELF::R_LARCH_32_PCREL: return Delta32; + case ELF::R_LARCH_B16: + return Branch16PCRel; + case ELF::R_LARCH_B21: + return Branch21PCRel; case ELF::R_LARCH_B26: return Branch26PCRel; case ELF::R_LARCH_PCALA_HI20: diff --git a/llvm/lib/ExecutionEngine/JITLink/loongarch.cpp b/llvm/lib/ExecutionEngine/JITLink/loongarch.cpp index 010c0ed..cdb3da0 100644 --- a/llvm/lib/ExecutionEngine/JITLink/loongarch.cpp +++ b/llvm/lib/ExecutionEngine/JITLink/loongarch.cpp @@ -44,6 +44,8 @@ const char *getEdgeKindName(Edge::Kind K) { KIND_NAME_CASE(Delta32) KIND_NAME_CASE(NegDelta32) KIND_NAME_CASE(Delta64) + KIND_NAME_CASE(Branch16PCRel) + KIND_NAME_CASE(Branch21PCRel) KIND_NAME_CASE(Branch26PCRel) KIND_NAME_CASE(Page20) KIND_NAME_CASE(PageOffset12) diff --git a/llvm/lib/ExecutionEngine/Orc/Core.cpp b/llvm/lib/ExecutionEngine/Orc/Core.cpp index 6a9ebb4..d47eb44 100644 --- a/llvm/lib/ExecutionEngine/Orc/Core.cpp +++ b/llvm/lib/ExecutionEngine/Orc/Core.cpp @@ -1576,12 +1576,22 @@ void Platform::lookupInitSymbolsAsync( } } +MaterializationTask::~MaterializationTask() { + // If this task wasn't run then fail materialization. + if (MR) + MR->failMaterialization(); +} + void MaterializationTask::printDescription(raw_ostream &OS) { OS << "Materialization task: " << MU->getName() << " in " << MR->getTargetJITDylib().getName(); } -void MaterializationTask::run() { MU->materialize(std::move(MR)); } +void MaterializationTask::run() { + assert(MU && "MU should not be null"); + assert(MR && "MR should not be null"); + MU->materialize(std::move(MR)); +} void LookupTask::printDescription(raw_ostream &OS) { OS << "Lookup task"; } @@ -1821,17 +1831,10 @@ ExecutionSession::lookup(const JITDylibSearchOrder &SearchOrder, RegisterDependenciesFunction RegisterDependencies) { #if LLVM_ENABLE_THREADS // In the threaded case we use promises to return the results. 
- std::promise<SymbolMap> PromisedResult; - Error ResolutionError = Error::success(); + std::promise<MSVCPExpected<SymbolMap>> PromisedResult; auto NotifyComplete = [&](Expected<SymbolMap> R) { - if (R) - PromisedResult.set_value(std::move(*R)); - else { - ErrorAsOutParameter _(ResolutionError); - ResolutionError = R.takeError(); - PromisedResult.set_value(SymbolMap()); - } + PromisedResult.set_value(std::move(R)); }; #else @@ -1848,18 +1851,11 @@ ExecutionSession::lookup(const JITDylibSearchOrder &SearchOrder, #endif // Perform the asynchronous lookup. - lookup(K, SearchOrder, std::move(Symbols), RequiredState, NotifyComplete, - RegisterDependencies); + lookup(K, SearchOrder, std::move(Symbols), RequiredState, + std::move(NotifyComplete), RegisterDependencies); #if LLVM_ENABLE_THREADS - auto ResultFuture = PromisedResult.get_future(); - auto Result = ResultFuture.get(); - - if (ResolutionError) - return std::move(ResolutionError); - - return std::move(Result); - + return PromisedResult.get_future().get(); #else if (ResolutionError) return std::move(ResolutionError); diff --git a/llvm/lib/ExecutionEngine/Orc/Debugging/DebuggerSupportPlugin.cpp b/llvm/lib/ExecutionEngine/Orc/Debugging/DebuggerSupportPlugin.cpp index c08e52e..0d9a912 100644 --- a/llvm/lib/ExecutionEngine/Orc/Debugging/DebuggerSupportPlugin.cpp +++ b/llvm/lib/ExecutionEngine/Orc/Debugging/DebuggerSupportPlugin.cpp @@ -148,7 +148,7 @@ public: DSec.BuilderSec->align = Log2_64(SR.getFirstBlock()->getAlignment()); StringRef SectionData(SR.getFirstBlock()->getContent().data(), SR.getFirstBlock()->getSize()); - DebugSectionMap[SecName] = + DebugSectionMap[SecName.drop_front(2)] = // drop "__" prefix. MemoryBuffer::getMemBuffer(SectionData, G.getName(), false); if (SecName == "__debug_line") DebugLineSectionData = SectionData; @@ -167,11 +167,10 @@ public: DebugLineSectionData, G.getEndianness() == llvm::endianness::little, G.getPointerSize()); uint64_t Offset = 0; - DWARFDebugLine::LineTable LineTable; + DWARFDebugLine::Prologue P; // Try to parse line data. Consume error on failure. - if (auto Err = LineTable.parse(DebugLineData, &Offset, *DWARFCtx, nullptr, - consumeError)) { + if (auto Err = P.parse(DebugLineData, &Offset, consumeError, *DWARFCtx)) { handleAllErrors(std::move(Err), [&](ErrorInfoBase &EIB) { LLVM_DEBUG({ dbgs() << "Cannot parse line table for \"" << G.getName() << "\": "; @@ -180,15 +179,26 @@ public: }); }); } else { - if (!LineTable.Prologue.FileNames.empty()) - FileName = *dwarf::toString(LineTable.Prologue.FileNames[0].Name); + for (auto &FN : P.FileNames) + if ((FileName = dwarf::toString(FN.Name))) { + LLVM_DEBUG({ + dbgs() << "Using FileName = \"" << *FileName + << "\" from DWARF line table\n"; + }); + break; + } } } // If no line table (or unable to use) then use graph name. // FIXME: There are probably other debug sections we should look in first. - if (!FileName) - FileName = StringRef(G.getName()); + if (!FileName) { + LLVM_DEBUG({ + dbgs() << "Could not find source name from DWARF line table. 
" + "Using FileName = \"\"\n"; + }); + FileName = ""; + } Builder.addSymbol("", MachO::N_SO, 0, 0, 0); Builder.addSymbol(*FileName, MachO::N_SO, 0, 0, 0); diff --git a/llvm/lib/ExecutionEngine/Orc/MachOPlatform.cpp b/llvm/lib/ExecutionEngine/Orc/MachOPlatform.cpp index 0e83497..9f324c7 100644 --- a/llvm/lib/ExecutionEngine/Orc/MachOPlatform.cpp +++ b/llvm/lib/ExecutionEngine/Orc/MachOPlatform.cpp @@ -937,6 +937,12 @@ Error MachOPlatform::MachOPlatformPlugin::bootstrapPipelineEnd( jitlink::LinkGraph &G) { std::lock_guard<std::mutex> Lock(MP.Bootstrap.load()->Mutex); assert(MP.Bootstrap && "DeferredAAs reset before bootstrap completed"); + + // Transfer any allocation actions to DeferredAAs. + std::move(G.allocActions().begin(), G.allocActions().end(), + std::back_inserter(MP.Bootstrap.load()->DeferredAAs)); + G.allocActions().clear(); + --MP.Bootstrap.load()->ActiveGraphs; // Notify Bootstrap->CV while holding the mutex because the mutex is // also keeping Bootstrap->CV alive. @@ -1397,10 +1403,6 @@ Error MachOPlatform::MachOPlatformPlugin::registerObjectPlatformSections( SPSExecutorAddrRange, SPSExecutorAddrRange>>, SPSSequence<SPSTuple<SPSString, SPSExecutorAddrRange>>>; - shared::AllocActions &allocActions = LLVM_LIKELY(!InBootstrapPhase) - ? G.allocActions() - : MP.Bootstrap.load()->DeferredAAs; - ExecutorAddr HeaderAddr; { std::lock_guard<std::mutex> Lock(MP.PlatformMutex); @@ -1410,7 +1412,7 @@ Error MachOPlatform::MachOPlatformPlugin::registerObjectPlatformSections( assert(I->second && "Null header registered for JD"); HeaderAddr = I->second; } - allocActions.push_back( + G.allocActions().push_back( {cantFail( WrapperFunctionCall::Create<SPSRegisterObjectPlatformSectionsArgs>( MP.RegisterObjectPlatformSections.Addr, HeaderAddr, UnwindInfo, diff --git a/llvm/lib/ExecutionEngine/Orc/ObjectLinkingLayer.cpp b/llvm/lib/ExecutionEngine/Orc/ObjectLinkingLayer.cpp index 6688b09..9bc0aa8 100644 --- a/llvm/lib/ExecutionEngine/Orc/ObjectLinkingLayer.cpp +++ b/llvm/lib/ExecutionEngine/Orc/ObjectLinkingLayer.cpp @@ -16,8 +16,6 @@ namespace llvm::orc { char ObjectLinkingLayer::ID; -using BaseObjectLayer = RTTIExtends<ObjectLinkingLayer, ObjectLayer>; - void ObjectLinkingLayer::emit(std::unique_ptr<MaterializationResponsibility> R, std::unique_ptr<MemoryBuffer> O) { assert(O && "Object must not be null"); diff --git a/llvm/lib/ExecutionEngine/Orc/TaskDispatch.cpp b/llvm/lib/ExecutionEngine/Orc/TaskDispatch.cpp index fbe4b09..1af17e8 100644 --- a/llvm/lib/ExecutionEngine/Orc/TaskDispatch.cpp +++ b/llvm/lib/ExecutionEngine/Orc/TaskDispatch.cpp @@ -31,6 +31,10 @@ void DynamicThreadPoolTaskDispatcher::dispatch(std::unique_ptr<Task> T) { { std::lock_guard<std::mutex> Lock(DispatchMutex); + // Reject new tasks if they're dispatched after a call to shutdown. + if (Shutdown) + return; + if (IsMaterializationTask) { // If this is a materialization task and there are too many running @@ -54,6 +58,14 @@ void DynamicThreadPoolTaskDispatcher::dispatch(std::unique_ptr<Task> T) { // Run the task. T->run(); + // Reset the task to free any resources. We need this to happen *before* + // we notify anyone (via Outstanding) that this thread is done to ensure + // that we don't proceed with JIT shutdown while still holding resources. + // (E.g. this was causing "Dangling SymbolStringPtr" assertions). + T.reset(); + + // Check the work queue state and either proceed with the next task or + // end this thread. 
std::lock_guard<std::mutex> Lock(DispatchMutex); if (!MaterializationTaskQueue.empty()) { // If there are any materialization tasks running then steal that work. @@ -64,7 +76,6 @@ void DynamicThreadPoolTaskDispatcher::dispatch(std::unique_ptr<Task> T) { IsMaterializationTask = true; } } else { - // Otherwise decrement work counters. if (IsMaterializationTask) --NumMaterializationThreads; --Outstanding; @@ -78,7 +89,7 @@ void DynamicThreadPoolTaskDispatcher::dispatch(std::unique_ptr<Task> T) { void DynamicThreadPoolTaskDispatcher::shutdown() { std::unique_lock<std::mutex> Lock(DispatchMutex); - Running = false; + Shutdown = true; OutstandingCV.wait(Lock, [this]() { return Outstanding == 0; }); } #endif diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index 0d8dbbe..8dbf2aa 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -5302,10 +5302,11 @@ void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop, Loop *L = LI.getLoopFor(CanonicalLoop->getHeader()); if (AlignedVars.size()) { InsertPointTy IP = Builder.saveIP(); - Builder.SetInsertPoint(CanonicalLoop->getPreheader()->getTerminator()); for (auto &AlignedItem : AlignedVars) { Value *AlignedPtr = AlignedItem.first; Value *Alignment = AlignedItem.second; + Instruction *loadInst = dyn_cast<Instruction>(AlignedPtr); + Builder.SetInsertPoint(loadInst->getNextNode()); Builder.CreateAlignmentAssumption(F->getDataLayout(), AlignedPtr, Alignment); } diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp index d81a292..3566435 100644 --- a/llvm/lib/IR/ConstantRange.cpp +++ b/llvm/lib/IR/ConstantRange.cpp @@ -1520,15 +1520,72 @@ ConstantRange ConstantRange::binaryNot() const { return ConstantRange(APInt::getAllOnes(getBitWidth())).sub(*this); } +/// Estimate the 'bit-masked AND' operation's lower bound. +/// +/// E.g., given two ranges as follows (single quotes are separators and +/// have no meaning here), +/// +/// LHS = [10'00101'1, ; LLo +/// 10'10000'0] ; LHi +/// RHS = [10'11111'0, ; RLo +/// 10'11111'1] ; RHi +/// +/// we know that the higher 2 bits of the result is always 10; and we also +/// notice that RHS[1:6] are always 1, so the result[1:6] cannot be less than +/// LHS[1:6] (i.e., 00101). Thus, the lower bound is 10'00101'0. +/// +/// The algorithm is as follows, +/// 1. we first calculate a mask to find the higher common bits by +/// Mask = ~((LLo ^ LHi) | (RLo ^ RHi) | (LLo ^ RLo)); +/// Mask = clear all non-leading-ones bits in Mask; +/// in the example, the Mask is set to 11'00000'0; +/// 2. calculate a new mask by setting all common leading bits to 1 in RHS, and +/// keeping the longest leading ones (i.e., 11'11111'0 in the example); +/// 3. return (LLo & new mask) as the lower bound; +/// 4. repeat the step 2 and 3 with LHS and RHS swapped, and update the lower +/// bound with the larger one. +static APInt estimateBitMaskedAndLowerBound(const ConstantRange &LHS, + const ConstantRange &RHS) { + auto BitWidth = LHS.getBitWidth(); + // If either is full set or unsigned wrapped, then the range must contain '0' + // which leads the lower bound to 0. + if ((LHS.isFullSet() || RHS.isFullSet()) || + (LHS.isWrappedSet() || RHS.isWrappedSet())) + return APInt::getZero(BitWidth); + + auto LLo = LHS.getLower(); + auto LHi = LHS.getUpper() - 1; + auto RLo = RHS.getLower(); + auto RHi = RHS.getUpper() - 1; + + // Calculate the mask for the higher common bits. 
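To make the bound estimate concrete, here is a small standalone sketch (plain C++20, not LLVM's APInt) that reproduces the 8-bit example from the doc comment above; std::countl_one stands in for countLeadingOnes, and the mask/clear-low-bits steps mirror the code that follows. The same estimate, applied to the complemented ranges, also yields the binaryOr upper bound via the De Morgan identity used later in this file's diff.

    #include <algorithm>
    #include <bit>
    #include <cstdint>
    #include <cstdio>

    static uint8_t clearLowBits(uint8_t V, unsigned N) {
      return N >= 8 ? 0 : uint8_t(V & (0xFFu << N));
    }

    int main() {
      // LHS = [10'00101'1, 10'10000'0], RHS = [10'11111'0, 10'11111'1]
      uint8_t LLo = 0b10001011, LHi = 0b10100000;
      uint8_t RLo = 0b10111110, RHi = 0b10111111;

      uint8_t Mask = uint8_t(~((LLo ^ LHi) | (RLo ^ RHi) | (LLo ^ RLo)));
      Mask = clearLowBits(Mask, 8 - std::countl_one(Mask));   // 0b11000000

      auto bound = [Mask](uint8_t ALo, uint8_t BLo, uint8_t BHi) {
        unsigned Ones = std::countl_one(uint8_t((BLo & BHi) | Mask));
        return clearLowBits(ALo, 8 - Ones);
      };
      uint8_t Lower = std::max(bound(LLo, RLo, RHi), bound(RLo, LLo, LHi));
      std::printf("lower bound = 0x%02x\n", Lower);  // 0x8a == 10'00101'0
    }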
+ auto Mask = ~((LLo ^ LHi) | (RLo ^ RHi) | (LLo ^ RLo)); + unsigned LeadingOnes = Mask.countLeadingOnes(); + Mask.clearLowBits(BitWidth - LeadingOnes); + + auto estimateBound = [BitWidth, &Mask](APInt ALo, const APInt &BLo, + const APInt &BHi) { + unsigned LeadingOnes = ((BLo & BHi) | Mask).countLeadingOnes(); + unsigned StartBit = BitWidth - LeadingOnes; + ALo.clearLowBits(StartBit); + return ALo; + }; + + auto LowerBoundByLHS = estimateBound(LLo, RLo, RHi); + auto LowerBoundByRHS = estimateBound(RLo, LLo, LHi); + + return APIntOps::umax(LowerBoundByLHS, LowerBoundByRHS); +} + ConstantRange ConstantRange::binaryAnd(const ConstantRange &Other) const { if (isEmptySet() || Other.isEmptySet()) return getEmpty(); ConstantRange KnownBitsRange = fromKnownBits(toKnownBits() & Other.toKnownBits(), false); - ConstantRange UMinUMaxRange = - getNonEmpty(APInt::getZero(getBitWidth()), - APIntOps::umin(Other.getUnsignedMax(), getUnsignedMax()) + 1); + auto LowerBound = estimateBitMaskedAndLowerBound(*this, Other); + ConstantRange UMinUMaxRange = getNonEmpty( + LowerBound, APIntOps::umin(Other.getUnsignedMax(), getUnsignedMax()) + 1); return KnownBitsRange.intersectWith(UMinUMaxRange); } @@ -1538,10 +1595,17 @@ ConstantRange ConstantRange::binaryOr(const ConstantRange &Other) const { ConstantRange KnownBitsRange = fromKnownBits(toKnownBits() | Other.toKnownBits(), false); + + // ~a & ~b >= x + // <=> ~(~a & ~b) <= ~x + // <=> a | b <= ~x + // <=> a | b < ~x + 1 = -x + // thus, UpperBound(a | b) == -LowerBound(~a & ~b) + auto UpperBound = + -estimateBitMaskedAndLowerBound(binaryNot(), Other.binaryNot()); // Upper wrapped range. - ConstantRange UMaxUMinRange = - getNonEmpty(APIntOps::umax(getUnsignedMin(), Other.getUnsignedMin()), - APInt::getZero(getBitWidth())); + ConstantRange UMaxUMinRange = getNonEmpty( + APIntOps::umax(getUnsignedMin(), Other.getUnsignedMin()), UpperBound); return KnownBitsRange.intersectWith(UMaxUMinRange); } diff --git a/llvm/lib/IR/ConstantsContext.h b/llvm/lib/IR/ConstantsContext.h index aaaab0b..08bf3f9 100644 --- a/llvm/lib/IR/ConstantsContext.h +++ b/llvm/lib/IR/ConstantsContext.h @@ -491,8 +491,7 @@ public: default: if (Instruction::isCast(Opcode)) return new CastConstantExpr(Opcode, Ops[0], Ty); - if ((Opcode >= Instruction::BinaryOpsBegin && - Opcode < Instruction::BinaryOpsEnd)) + if (Instruction::isBinaryOp(Opcode)) return new BinaryConstantExpr(Opcode, Ops[0], Ops[1], SubclassOptionalData); llvm_unreachable("Invalid ConstantExpr!"); diff --git a/llvm/lib/IR/IRBuilder.cpp b/llvm/lib/IR/IRBuilder.cpp index f340f7a..27b499e 100644 --- a/llvm/lib/IR/IRBuilder.cpp +++ b/llvm/lib/IR/IRBuilder.cpp @@ -78,11 +78,11 @@ void IRBuilderBase::SetInstDebugLocation(Instruction *I) const { CallInst * IRBuilderBase::createCallHelper(Function *Callee, ArrayRef<Value *> Ops, - const Twine &Name, Instruction *FMFSource, + const Twine &Name, FMFSource FMFSource, ArrayRef<OperandBundleDef> OpBundles) { CallInst *CI = CreateCall(Callee, Ops, OpBundles, Name); - if (FMFSource) - CI->copyFastMathFlags(FMFSource); + if (isa<FPMathOperator>(CI)) + CI->setFastMathFlags(FMFSource.get(FMF)); return CI; } @@ -869,7 +869,7 @@ CallInst *IRBuilderBase::CreateGCGetPointerOffset(Value *DerivedPtr, } CallInst *IRBuilderBase::CreateUnaryIntrinsic(Intrinsic::ID ID, Value *V, - Instruction *FMFSource, + FMFSource FMFSource, const Twine &Name) { Module *M = BB->getModule(); Function *Fn = Intrinsic::getOrInsertDeclaration(M, ID, {V->getType()}); @@ -877,12 +877,12 @@ CallInst 
*IRBuilderBase::CreateUnaryIntrinsic(Intrinsic::ID ID, Value *V, } Value *IRBuilderBase::CreateBinaryIntrinsic(Intrinsic::ID ID, Value *LHS, - Value *RHS, Instruction *FMFSource, + Value *RHS, FMFSource FMFSource, const Twine &Name) { Module *M = BB->getModule(); Function *Fn = Intrinsic::getOrInsertDeclaration(M, ID, {LHS->getType()}); if (Value *V = Folder.FoldBinaryIntrinsic(ID, LHS, RHS, Fn->getReturnType(), - FMFSource)) + /*FMFSource=*/nullptr)) return V; return createCallHelper(Fn, {LHS, RHS}, Name, FMFSource); } @@ -890,7 +890,7 @@ Value *IRBuilderBase::CreateBinaryIntrinsic(Intrinsic::ID ID, Value *LHS, CallInst *IRBuilderBase::CreateIntrinsic(Intrinsic::ID ID, ArrayRef<Type *> Types, ArrayRef<Value *> Args, - Instruction *FMFSource, + FMFSource FMFSource, const Twine &Name) { Module *M = BB->getModule(); Function *Fn = Intrinsic::getOrInsertDeclaration(M, ID, Types); @@ -899,7 +899,7 @@ CallInst *IRBuilderBase::CreateIntrinsic(Intrinsic::ID ID, CallInst *IRBuilderBase::CreateIntrinsic(Type *RetTy, Intrinsic::ID ID, ArrayRef<Value *> Args, - Instruction *FMFSource, + FMFSource FMFSource, const Twine &Name) { Module *M = BB->getModule(); @@ -925,16 +925,13 @@ CallInst *IRBuilderBase::CreateIntrinsic(Type *RetTy, Intrinsic::ID ID, } CallInst *IRBuilderBase::CreateConstrainedFPBinOp( - Intrinsic::ID ID, Value *L, Value *R, Instruction *FMFSource, - const Twine &Name, MDNode *FPMathTag, - std::optional<RoundingMode> Rounding, + Intrinsic::ID ID, Value *L, Value *R, FMFSource FMFSource, + const Twine &Name, MDNode *FPMathTag, std::optional<RoundingMode> Rounding, std::optional<fp::ExceptionBehavior> Except) { Value *RoundingV = getConstrainedFPRounding(Rounding); Value *ExceptV = getConstrainedFPExcept(Except); - FastMathFlags UseFMF = FMF; - if (FMFSource) - UseFMF = FMFSource->getFastMathFlags(); + FastMathFlags UseFMF = FMFSource.get(FMF); CallInst *C = CreateIntrinsic(ID, {L->getType()}, {L, R, RoundingV, ExceptV}, nullptr, Name); @@ -944,14 +941,12 @@ CallInst *IRBuilderBase::CreateConstrainedFPBinOp( } CallInst *IRBuilderBase::CreateConstrainedFPUnroundedBinOp( - Intrinsic::ID ID, Value *L, Value *R, Instruction *FMFSource, + Intrinsic::ID ID, Value *L, Value *R, FMFSource FMFSource, const Twine &Name, MDNode *FPMathTag, std::optional<fp::ExceptionBehavior> Except) { Value *ExceptV = getConstrainedFPExcept(Except); - FastMathFlags UseFMF = FMF; - if (FMFSource) - UseFMF = FMFSource->getFastMathFlags(); + FastMathFlags UseFMF = FMFSource.get(FMF); CallInst *C = CreateIntrinsic(ID, {L->getType()}, {L, R, ExceptV}, nullptr, Name); @@ -976,15 +971,12 @@ Value *IRBuilderBase::CreateNAryOp(unsigned Opc, ArrayRef<Value *> Ops, } CallInst *IRBuilderBase::CreateConstrainedFPCast( - Intrinsic::ID ID, Value *V, Type *DestTy, - Instruction *FMFSource, const Twine &Name, MDNode *FPMathTag, - std::optional<RoundingMode> Rounding, + Intrinsic::ID ID, Value *V, Type *DestTy, FMFSource FMFSource, + const Twine &Name, MDNode *FPMathTag, std::optional<RoundingMode> Rounding, std::optional<fp::ExceptionBehavior> Except) { Value *ExceptV = getConstrainedFPExcept(Except); - FastMathFlags UseFMF = FMF; - if (FMFSource) - UseFMF = FMFSource->getFastMathFlags(); + FastMathFlags UseFMF = FMFSource.get(FMF); CallInst *C; if (Intrinsic::hasConstrainedFPRoundingModeOperand(ID)) { @@ -1002,9 +994,10 @@ CallInst *IRBuilderBase::CreateConstrainedFPCast( return C; } -Value *IRBuilderBase::CreateFCmpHelper( - CmpInst::Predicate P, Value *LHS, Value *RHS, const Twine &Name, - MDNode *FPMathTag, bool IsSignaling) { 
+Value *IRBuilderBase::CreateFCmpHelper(CmpInst::Predicate P, Value *LHS, + Value *RHS, const Twine &Name, + MDNode *FPMathTag, FMFSource FMFSource, + bool IsSignaling) { if (IsFPConstrained) { auto ID = IsSignaling ? Intrinsic::experimental_constrained_fcmps : Intrinsic::experimental_constrained_fcmp; @@ -1013,7 +1006,9 @@ Value *IRBuilderBase::CreateFCmpHelper( if (auto *V = Folder.FoldCmp(P, LHS, RHS)) return V; - return Insert(setFPAttrs(new FCmpInst(P, LHS, RHS), FPMathTag, FMF), Name); + return Insert( + setFPAttrs(new FCmpInst(P, LHS, RHS), FPMathTag, FMFSource.get(FMF)), + Name); } CallInst *IRBuilderBase::CreateConstrainedFPCmp( @@ -1047,6 +1042,12 @@ CallInst *IRBuilderBase::CreateConstrainedFPCall( Value *IRBuilderBase::CreateSelect(Value *C, Value *True, Value *False, const Twine &Name, Instruction *MDFrom) { + return CreateSelectFMF(C, True, False, {}, Name, MDFrom); +} + +Value *IRBuilderBase::CreateSelectFMF(Value *C, Value *True, Value *False, + FMFSource FMFSource, const Twine &Name, + Instruction *MDFrom) { if (auto *V = Folder.FoldSelect(C, True, False)) return V; @@ -1057,7 +1058,7 @@ Value *IRBuilderBase::CreateSelect(Value *C, Value *True, Value *False, Sel = addBranchMetadata(Sel, Prof, Unpred); } if (isa<FPMathOperator>(Sel)) - setFPAttrs(Sel, nullptr /* MDNode* */, FMF); + setFPAttrs(Sel, /*MDNode=*/nullptr, FMFSource.get(FMF)); return Insert(Sel, Name); } diff --git a/llvm/lib/IR/SafepointIRVerifier.cpp b/llvm/lib/IR/SafepointIRVerifier.cpp index ed99d05..d32852b 100644 --- a/llvm/lib/IR/SafepointIRVerifier.cpp +++ b/llvm/lib/IR/SafepointIRVerifier.cpp @@ -289,6 +289,7 @@ static void PrintValueSet(raw_ostream &OS, IteratorTy Begin, IteratorTy End) { using AvailableValueSet = DenseSet<const Value *>; +namespace { /// State we compute and track per basic block. struct BasicBlockState { // Set of values available coming in, before the phi nodes @@ -305,6 +306,7 @@ struct BasicBlockState { // contribute to AvailableOut. bool Cleared = false; }; +} // namespace /// A given derived pointer can have multiple base pointers through phi/selects. 
/// This type indicates when the base pointer is exclusively constant diff --git a/llvm/lib/LTO/ThinLTOCodeGenerator.cpp b/llvm/lib/LTO/ThinLTOCodeGenerator.cpp index 4522f4a..189f287 100644 --- a/llvm/lib/LTO/ThinLTOCodeGenerator.cpp +++ b/llvm/lib/LTO/ThinLTOCodeGenerator.cpp @@ -160,8 +160,7 @@ generateModuleMap(std::vector<std::unique_ptr<lto::InputFile>> &Modules) { static void promoteModule(Module &TheModule, const ModuleSummaryIndex &Index, bool ClearDSOLocalOnDeclarations) { - if (renameModuleForThinLTO(TheModule, Index, ClearDSOLocalOnDeclarations)) - report_fatal_error("renameModuleForThinLTO failed"); + renameModuleForThinLTO(TheModule, Index, ClearDSOLocalOnDeclarations); } namespace { diff --git a/llvm/lib/Linker/IRMover.cpp b/llvm/lib/Linker/IRMover.cpp index a0c3f2c..be3535a 100644 --- a/llvm/lib/Linker/IRMover.cpp +++ b/llvm/lib/Linker/IRMover.cpp @@ -1562,10 +1562,6 @@ Error IRLinker::run() { bool EnableDLWarning = true; bool EnableTripleWarning = true; if (SrcTriple.isNVPTX() && DstTriple.isNVPTX()) { - std::string ModuleId = SrcM->getModuleIdentifier(); - StringRef FileName = llvm::sys::path::filename(ModuleId); - bool SrcIsLibDevice = - FileName.starts_with("libdevice") && FileName.ends_with(".10.bc"); bool SrcHasLibDeviceDL = (SrcM->getDataLayoutStr().empty() || SrcM->getDataLayoutStr() == "e-i64:64-v16:16-v32:32-n16:32:64"); @@ -1576,8 +1572,8 @@ Error IRLinker::run() { SrcTriple.getOSName() == "gpulibs") || (SrcTriple.getVendorName() == "unknown" && SrcTriple.getOSName() == "unknown"); - EnableTripleWarning = !(SrcIsLibDevice && SrcHasLibDeviceTriple); - EnableDLWarning = !(SrcIsLibDevice && SrcHasLibDeviceDL); + EnableTripleWarning = !SrcHasLibDeviceTriple; + EnableDLWarning = !(SrcHasLibDeviceTriple && SrcHasLibDeviceDL); } if (EnableDLWarning && (SrcM->getDataLayout() != DstM.getDataLayout())) { diff --git a/llvm/lib/ObjCopy/COFF/COFFObjcopy.cpp b/llvm/lib/ObjCopy/COFF/COFFObjcopy.cpp index 782d5b2..cebcb82 100644 --- a/llvm/lib/ObjCopy/COFF/COFFObjcopy.cpp +++ b/llvm/lib/ObjCopy/COFF/COFFObjcopy.cpp @@ -183,10 +183,18 @@ static Error handleArgs(const CommonConfig &Config, }); if (Config.OnlyKeepDebug) { + const data_directory *DebugDir = + Obj.DataDirectories.size() > DEBUG_DIRECTORY + ? &Obj.DataDirectories[DEBUG_DIRECTORY] + : nullptr; // For --only-keep-debug, we keep all other sections, but remove their // content. The VirtualSize field in the section header is kept intact. - Obj.truncateSections([](const Section &Sec) { + Obj.truncateSections([DebugDir](const Section &Sec) { return !isDebugSection(Sec) && Sec.Name != ".buildid" && + !(DebugDir && DebugDir->Size > 0 && + DebugDir->RelativeVirtualAddress >= Sec.Header.VirtualAddress && + DebugDir->RelativeVirtualAddress < + Sec.Header.VirtualAddress + Sec.Header.SizeOfRawData) && ((Sec.Header.Characteristics & (IMAGE_SCN_CNT_CODE | IMAGE_SCN_CNT_INITIALIZED_DATA)) != 0); }); diff --git a/llvm/lib/ObjCopy/MachO/MachOLayoutBuilder.cpp b/llvm/lib/ObjCopy/MachO/MachOLayoutBuilder.cpp index 93bc663..d4eb6a9b 100644 --- a/llvm/lib/ObjCopy/MachO/MachOLayoutBuilder.cpp +++ b/llvm/lib/ObjCopy/MachO/MachOLayoutBuilder.cpp @@ -116,6 +116,11 @@ uint64_t MachOLayoutBuilder::layoutSegments() { const bool IsObjectFile = O.Header.FileType == MachO::HeaderFileType::MH_OBJECT; uint64_t Offset = IsObjectFile ? (HeaderSize + O.Header.SizeOfCmds) : 0; + if (O.EncryptionInfoCommandIndex) { + // If we are emitting an encryptable binary, our load commands must have a + // separate (non-encrypted) page to themselves. 
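The alignToPowerOf2 call that follows simply rounds the end of the load commands up to the next page boundary so the encrypted region can begin on a fresh page. With hypothetical sizes (not taken from any real binary; 0x4000 models a 16 KiB page):

    #include <cstdint>
    // alignToPowerOf2(X, A) == (X + A - 1) & ~(A - 1) when A is a power of two.
    constexpr uint64_t alignToPowerOf2(uint64_t X, uint64_t A) {
      return (X + A - 1) & ~(A - 1);
    }
    static_assert(alignToPowerOf2(/*HeaderSize + SizeOfCmds=*/0x1234,
                                  /*PageSize=*/0x4000) == 0x4000);
    static_assert(alignToPowerOf2(0x4000, 0x4000) == 0x4000);  // already aligned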
+ Offset = alignToPowerOf2(HeaderSize + O.Header.SizeOfCmds, PageSize); + } for (LoadCommand &LC : O.LoadCommands) { auto &MLC = LC.MachOLoadCommand; StringRef Segname; diff --git a/llvm/lib/ObjCopy/MachO/MachOObject.cpp b/llvm/lib/ObjCopy/MachO/MachOObject.cpp index 8d2c02d..e0819d8 100644 --- a/llvm/lib/ObjCopy/MachO/MachOObject.cpp +++ b/llvm/lib/ObjCopy/MachO/MachOObject.cpp @@ -98,6 +98,10 @@ void Object::updateLoadCommandIndexes() { case MachO::LC_DYLD_EXPORTS_TRIE: ExportsTrieCommandIndex = Index; break; + case MachO::LC_ENCRYPTION_INFO: + case MachO::LC_ENCRYPTION_INFO_64: + EncryptionInfoCommandIndex = Index; + break; } } } diff --git a/llvm/lib/ObjCopy/MachO/MachOObject.h b/llvm/lib/ObjCopy/MachO/MachOObject.h index a454c4f..79eb0133 100644 --- a/llvm/lib/ObjCopy/MachO/MachOObject.h +++ b/llvm/lib/ObjCopy/MachO/MachOObject.h @@ -341,6 +341,9 @@ struct Object { /// The index of the LC_SEGMENT or LC_SEGMENT_64 load command /// corresponding to the __TEXT segment. std::optional<size_t> TextSegmentCommandIndex; + /// The index of the LC_ENCRYPTION_INFO or LC_ENCRYPTION_INFO_64 load command + /// if present. + std::optional<size_t> EncryptionInfoCommandIndex; BumpPtrAllocator Alloc; StringSaver NewSectionsContents; diff --git a/llvm/lib/ObjCopy/MachO/MachOReader.cpp b/llvm/lib/ObjCopy/MachO/MachOReader.cpp index 2b344f3..ef0e026 100644 --- a/llvm/lib/ObjCopy/MachO/MachOReader.cpp +++ b/llvm/lib/ObjCopy/MachO/MachOReader.cpp @@ -184,6 +184,10 @@ Error MachOReader::readLoadCommands(Object &O) const { case MachO::LC_DYLD_CHAINED_FIXUPS: O.ChainedFixupsCommandIndex = O.LoadCommands.size(); break; + case MachO::LC_ENCRYPTION_INFO: + case MachO::LC_ENCRYPTION_INFO_64: + O.EncryptionInfoCommandIndex = O.LoadCommands.size(); + break; } #define HANDLE_LOAD_COMMAND(LCName, LCValue, LCStruct) \ case MachO::LCName: \ diff --git a/llvm/lib/Object/WindowsMachineFlag.cpp b/llvm/lib/Object/WindowsMachineFlag.cpp index b9f8187..caf357e 100644 --- a/llvm/lib/Object/WindowsMachineFlag.cpp +++ b/llvm/lib/Object/WindowsMachineFlag.cpp @@ -21,6 +21,7 @@ using namespace llvm; // Returns /machine's value. COFF::MachineTypes llvm::getMachineType(StringRef S) { + // Flags must be a superset of Microsoft lib.exe /machine flags. 
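With the "mips" case added to the switch below, a caller would see, for example (illustrative usage only):

    // Matching is done on S.lower(), so case is irrelevant.
    llvm::getMachineType("MIPS");           // == COFF::IMAGE_FILE_MACHINE_R4000
    llvm::getMachineType("amd64");          // == COFF::IMAGE_FILE_MACHINE_AMD64
    llvm::getMachineType("not-a-machine");  // == COFF::IMAGE_FILE_MACHINE_UNKNOWN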
return StringSwitch<COFF::MachineTypes>(S.lower()) .Cases("x64", "amd64", COFF::IMAGE_FILE_MACHINE_AMD64) .Cases("x86", "i386", COFF::IMAGE_FILE_MACHINE_I386) @@ -28,6 +29,7 @@ COFF::MachineTypes llvm::getMachineType(StringRef S) { .Case("arm64", COFF::IMAGE_FILE_MACHINE_ARM64) .Case("arm64ec", COFF::IMAGE_FILE_MACHINE_ARM64EC) .Case("arm64x", COFF::IMAGE_FILE_MACHINE_ARM64X) + .Case("mips", COFF::IMAGE_FILE_MACHINE_R4000) .Default(COFF::IMAGE_FILE_MACHINE_UNKNOWN); } diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp index a936f53..30b8d7c 100644 --- a/llvm/lib/Passes/PassBuilder.cpp +++ b/llvm/lib/Passes/PassBuilder.cpp @@ -492,6 +492,9 @@ PassBuilder::PassBuilder(TargetMachine *TM, PipelineTuningOptions PTO, PIC->addClassToPassName(decltype(CREATE_PASS)::name(), NAME); #define MACHINE_FUNCTION_PASS(NAME, CREATE_PASS) \ PIC->addClassToPassName(decltype(CREATE_PASS)::name(), NAME); +#define MACHINE_FUNCTION_PASS_WITH_PARAMS(NAME, CLASS, CREATE_PASS, PARSER, \ + PARAMS) \ + PIC->addClassToPassName(CLASS, NAME); #include "llvm/Passes/MachinePassRegistry.def" }); } diff --git a/llvm/lib/Passes/PassBuilderPipelines.cpp b/llvm/lib/Passes/PassBuilderPipelines.cpp index d737ea5..4ec0fb8 100644 --- a/llvm/lib/Passes/PassBuilderPipelines.cpp +++ b/llvm/lib/Passes/PassBuilderPipelines.cpp @@ -189,9 +189,9 @@ static cl::opt<bool> EnableGlobalAnalyses( "enable-global-analyses", cl::init(true), cl::Hidden, cl::desc("Enable inter-procedural analyses")); -static cl::opt<bool> - RunPartialInlining("enable-partial-inlining", cl::init(false), cl::Hidden, - cl::desc("Run Partial inlinining pass")); +static cl::opt<bool> RunPartialInlining("enable-partial-inlining", + cl::init(false), cl::Hidden, + cl::desc("Run Partial inlining pass")); static cl::opt<bool> ExtraVectorizerPasses( "extra-vectorizer-passes", cl::init(false), cl::Hidden, @@ -264,7 +264,7 @@ static cl::opt<bool> static cl::opt<bool> FlattenedProfileUsed( "flattened-profile-used", cl::init(false), cl::Hidden, cl::desc("Indicate the sample profile being used is flattened, i.e., " - "no inline hierachy exists in the profile")); + "no inline hierarchy exists in the profile")); static cl::opt<bool> EnableOrderFileInstrumentation( "enable-order-file-instrumentation", cl::init(false), cl::Hidden, diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def index 9f0b092..13e192f 100644 --- a/llvm/lib/Passes/PassRegistry.def +++ b/llvm/lib/Passes/PassRegistry.def @@ -156,7 +156,7 @@ MODULE_PASS("strip-nonlinetable-debuginfo", StripNonLineTableDebugInfoPass()) MODULE_PASS("trigger-crash-module", TriggerCrashModulePass()) MODULE_PASS("trigger-verifier-error", TriggerVerifierErrorPass()) MODULE_PASS("tsan-module", ModuleThreadSanitizerPass()) -MODULE_PASS("tysan-module", ModuleTypeSanitizerPass()) +MODULE_PASS("tysan", TypeSanitizerPass()) MODULE_PASS("verify", VerifierPass()) MODULE_PASS("view-callgraph", CallGraphViewerPass()) MODULE_PASS("wholeprogramdevirt", WholeProgramDevirtPass()) @@ -481,7 +481,6 @@ FUNCTION_PASS("transform-warning", WarnMissedTransformationsPass()) FUNCTION_PASS("trigger-crash-function", TriggerCrashFunctionPass()) FUNCTION_PASS("trigger-verifier-error", TriggerVerifierErrorPass()) FUNCTION_PASS("tsan", ThreadSanitizerPass()) -FUNCTION_PASS("tysan", TypeSanitizerPass()) FUNCTION_PASS("typepromotion", TypePromotionPass(TM)) FUNCTION_PASS("unify-loop-exits", UnifyLoopExitsPass()) FUNCTION_PASS("vector-combine", VectorCombinePass()) diff --git 
a/llvm/lib/ProfileData/Coverage/CoverageMapping.cpp b/llvm/lib/ProfileData/Coverage/CoverageMapping.cpp index 5d5dcce..83fe5f0 100644 --- a/llvm/lib/ProfileData/Coverage/CoverageMapping.cpp +++ b/llvm/lib/ProfileData/Coverage/CoverageMapping.cpp @@ -246,6 +246,40 @@ Expected<int64_t> CounterMappingContext::evaluate(const Counter &C) const { return LastPoppedValue; } +// Find an independence pair for each condition: +// - The condition is true in one test and false in the other. +// - The decision outcome is true one test and false in the other. +// - All other conditions' values must be equal or marked as "don't care". +void MCDCRecord::findIndependencePairs() { + if (IndependencePairs) + return; + + IndependencePairs.emplace(); + + unsigned NumTVs = TV.size(); + // Will be replaced to shorter expr. + unsigned TVTrueIdx = std::distance( + TV.begin(), + std::find_if(TV.begin(), TV.end(), + [&](auto I) { return (I.second == MCDCRecord::MCDC_True); }) + + ); + for (unsigned I = TVTrueIdx; I < NumTVs; ++I) { + const auto &[A, ACond] = TV[I]; + assert(ACond == MCDCRecord::MCDC_True); + for (unsigned J = 0; J < TVTrueIdx; ++J) { + const auto &[B, BCond] = TV[J]; + assert(BCond == MCDCRecord::MCDC_False); + // If the two vectors differ in exactly one condition, ignoring DontCare + // conditions, we have found an independence pair. + auto AB = A.getDifferences(B); + if (AB.count() == 1) + IndependencePairs->insert( + {AB.find_first(), std::make_pair(J + 1, I + 1)}); + } + } +} + mcdc::TVIdxBuilder::TVIdxBuilder(const SmallVectorImpl<ConditionIDs> &NextIDs, int Offset) : Indices(NextIDs.size()) { @@ -400,9 +434,6 @@ class MCDCRecordProcessor : NextIDsBuilder, mcdc::TVIdxBuilder { /// ExecutedTestVectorBitmap. MCDCRecord::TestVectors &ExecVectors; - /// Number of False items in ExecVectors - unsigned NumExecVectorsF; - #ifndef NDEBUG DenseSet<unsigned> TVIdxs; #endif @@ -472,34 +503,11 @@ private: // Fill ExecVectors order by False items and True items. // ExecVectors is the alias of ExecVectorsByCond[false], so // Append ExecVectorsByCond[true] on it. - NumExecVectorsF = ExecVectors.size(); auto &ExecVectorsT = ExecVectorsByCond[true]; ExecVectors.append(std::make_move_iterator(ExecVectorsT.begin()), std::make_move_iterator(ExecVectorsT.end())); } - // Find an independence pair for each condition: - // - The condition is true in one test and false in the other. - // - The decision outcome is true one test and false in the other. - // - All other conditions' values must be equal or marked as "don't care". - void findIndependencePairs() { - unsigned NumTVs = ExecVectors.size(); - for (unsigned I = NumExecVectorsF; I < NumTVs; ++I) { - const auto &[A, ACond] = ExecVectors[I]; - assert(ACond == MCDCRecord::MCDC_True); - for (unsigned J = 0; J < NumExecVectorsF; ++J) { - const auto &[B, BCond] = ExecVectors[J]; - assert(BCond == MCDCRecord::MCDC_False); - // If the two vectors differ in exactly one condition, ignoring DontCare - // conditions, we have found an independence pair. - auto AB = A.getDifferences(B); - if (AB.count() == 1) - IndependencePairs.insert( - {AB.find_first(), std::make_pair(J + 1, I + 1)}); - } - } - } - public: /// Process the MC/DC Record in order to produce a result for a boolean /// expression. This process includes tracking the conditions that comprise @@ -535,13 +543,8 @@ public: // Using Profile Bitmap from runtime, mark the executed test vectors. 
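As an illustration of the pairing rule in findIndependencePairs above, here is a standalone sketch (not the real TestVector/getDifferences API) for a decision `a && b`: each executed true vector is compared against each false vector, and a pair is kept when exactly one non-DontCare condition differs, giving one independence pair per condition.

    #include <bitset>
    #include <cstdio>
    #include <vector>

    struct TestVec { std::bitset<2> Val, Care; bool Outcome; }; // bit0 = a, bit1 = b

    int main() {
      std::vector<TestVec> TVs = {{0b00, 0b01, false},  // a=0, b=don't-care -> false
                                  {0b01, 0b11, false},  // a=1, b=0          -> false
                                  {0b11, 0b11, true}};  // a=1, b=1          -> true
      for (size_t I = 0; I < TVs.size(); ++I) {
        if (!TVs[I].Outcome)
          continue;
        for (size_t J = 0; J < TVs.size(); ++J) {
          if (TVs[J].Outcome)
            continue;
          auto Care = TVs[I].Care & TVs[J].Care;           // ignore don't-care bits
          auto Diff = (TVs[I].Val ^ TVs[J].Val) & Care;
          if (Diff.count() == 1)
            std::printf("condition %c: independence pair (TV%zu, TV%zu)\n",
                        Diff[0] ? 'a' : 'b', J, I);
        }
      }
    }

Running this prints a pair for condition a (TV0 vs TV2) and one for condition b (TV1 vs TV2), matching the "differ in exactly one condition, ignoring DontCare" criterion in the code above.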
findExecutedTestVectors(); - // Compare executed test vectors against each other to find an independence - // pairs for each condition. This processing takes the most time. - findIndependencePairs(); - // Record Test vectors, executed vectors, and independence pairs. - return MCDCRecord(Region, std::move(ExecVectors), - std::move(IndependencePairs), std::move(Folded), + return MCDCRecord(Region, std::move(ExecVectors), std::move(Folded), std::move(PosToID), std::move(CondLoc)); } }; diff --git a/llvm/lib/ProfileData/MemProfReader.cpp b/llvm/lib/ProfileData/MemProfReader.cpp index 10c36f2..6a4fecd 100644 --- a/llvm/lib/ProfileData/MemProfReader.cpp +++ b/llvm/lib/ProfileData/MemProfReader.cpp @@ -754,7 +754,7 @@ Error RawMemProfReader::readNextRecord( Expected<std::unique_ptr<YAMLMemProfReader>> YAMLMemProfReader::create(const Twine &Path) { - auto BufferOr = MemoryBuffer::getFileOrSTDIN(Path); + auto BufferOr = MemoryBuffer::getFileOrSTDIN(Path, /*IsText=*/true); if (std::error_code EC = BufferOr.getError()) return report(errorCodeToError(EC), Path.getSingleStringRef()); @@ -770,7 +770,7 @@ YAMLMemProfReader::create(std::unique_ptr<MemoryBuffer> Buffer) { } bool YAMLMemProfReader::hasFormat(const StringRef Path) { - auto BufferOr = MemoryBuffer::getFileOrSTDIN(Path); + auto BufferOr = MemoryBuffer::getFileOrSTDIN(Path, /*IsText=*/true); if (!BufferOr) return false; diff --git a/llvm/lib/Support/Windows/Path.inc b/llvm/lib/Support/Windows/Path.inc index 17db114c..5b311e7 100644 --- a/llvm/lib/Support/Windows/Path.inc +++ b/llvm/lib/Support/Windows/Path.inc @@ -1373,9 +1373,11 @@ std::error_code closeFile(file_t &F) { } std::error_code remove_directories(const Twine &path, bool IgnoreErrors) { + SmallString<128> NativePath; + llvm::sys::path::native(path, NativePath, path::Style::windows_backslash); // Convert to utf-16. SmallVector<wchar_t, 128> Path16; - std::error_code EC = widenPath(path, Path16); + std::error_code EC = widenPath(NativePath, Path16); if (EC && !IgnoreErrors) return EC; diff --git a/llvm/lib/TableGen/TGLexer.cpp b/llvm/lib/TableGen/TGLexer.cpp index eee4251..e23aec6 100644 --- a/llvm/lib/TableGen/TGLexer.cpp +++ b/llvm/lib/TableGen/TGLexer.cpp @@ -81,8 +81,7 @@ TGLexer::TGLexer(SourceMgr &SM, ArrayRef<std::string> Macros) : SrcMgr(SM) { TokStart = nullptr; // Pretend that we enter the "top-level" include file. - PrepIncludeStack.push_back( - std::make_unique<std::vector<PreprocessorControlDesc>>()); + PrepIncludeStack.emplace_back(); // Add all macros defined on the command line to the DefinedMacros set. // Check invalid macro names and print fatal error if we find one. @@ -453,8 +452,7 @@ bool TGLexer::LexInclude() { CurBuf = SrcMgr.getMemoryBuffer(CurBuffer)->getBuffer(); CurPtr = CurBuf.begin(); - PrepIncludeStack.push_back( - std::make_unique<std::vector<PreprocessorControlDesc>>()); + PrepIncludeStack.emplace_back(); return false; } @@ -656,17 +654,13 @@ tgtok::TokKind TGLexer::LexExclaim() { bool TGLexer::prepExitInclude(bool IncludeStackMustBeEmpty) { // Report an error, if preprocessor control stack for the current // file is not empty. - if (!PrepIncludeStack.back()->empty()) { + if (!PrepIncludeStack.back().empty()) { prepReportPreprocessorStackError(); return false; } // Pop the preprocessing controls from the include stack. 
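Referring back to the Windows Path.inc hunk above: remove_directories may now be handed a path with forward slashes, which is normalized to backslashes before being widened for the Win32 call. A small usage sketch with a hypothetical path:

    llvm::SmallString<128> NativePath;
    llvm::sys::path::native("C:/tmp/build/logs", NativePath,
                            llvm::sys::path::Style::windows_backslash);
    // NativePath now holds "C:\tmp\build\logs", which widenPath() can consume.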
- if (PrepIncludeStack.empty()) { - PrintFatalError("preprocessor include stack is empty"); - } - PrepIncludeStack.pop_back(); if (IncludeStackMustBeEmpty) { @@ -761,7 +755,7 @@ tgtok::TokKind TGLexer::lexPreprocessor(tgtok::TokKind Kind, // Regardless of whether we are processing tokens or not, // we put the #ifdef control on stack. // Note that MacroIsDefined has been canonicalized against ifdef. - PrepIncludeStack.back()->push_back( + PrepIncludeStack.back().push_back( {tgtok::Ifdef, MacroIsDefined, SMLoc::getFromPointer(TokStart)}); if (!prepSkipDirectiveEnd()) @@ -789,10 +783,10 @@ tgtok::TokKind TGLexer::lexPreprocessor(tgtok::TokKind Kind, } else if (Kind == tgtok::Else) { // Check if this #else is correct before calling prepSkipDirectiveEnd(), // which will move CurPtr away from the beginning of #else. - if (PrepIncludeStack.back()->empty()) + if (PrepIncludeStack.back().empty()) return ReturnError(TokStart, "#else without #ifdef or #ifndef"); - PreprocessorControlDesc IfdefEntry = PrepIncludeStack.back()->back(); + PreprocessorControlDesc IfdefEntry = PrepIncludeStack.back().back(); if (IfdefEntry.Kind != tgtok::Ifdef) { PrintError(TokStart, "double #else"); @@ -801,9 +795,8 @@ tgtok::TokKind TGLexer::lexPreprocessor(tgtok::TokKind Kind, // Replace the corresponding #ifdef's control with its negation // on the control stack. - PrepIncludeStack.back()->pop_back(); - PrepIncludeStack.back()->push_back( - {Kind, !IfdefEntry.IsDefined, SMLoc::getFromPointer(TokStart)}); + PrepIncludeStack.back().back() = {Kind, !IfdefEntry.IsDefined, + SMLoc::getFromPointer(TokStart)}; if (!prepSkipDirectiveEnd()) return ReturnError(CurPtr, "only comments are supported after #else"); @@ -822,10 +815,10 @@ tgtok::TokKind TGLexer::lexPreprocessor(tgtok::TokKind Kind, } else if (Kind == tgtok::Endif) { // Check if this #endif is correct before calling prepSkipDirectiveEnd(), // which will move CurPtr away from the beginning of #endif. - if (PrepIncludeStack.back()->empty()) + if (PrepIncludeStack.back().empty()) return ReturnError(TokStart, "#endif without #ifdef"); - auto &IfdefOrElseEntry = PrepIncludeStack.back()->back(); + auto &IfdefOrElseEntry = PrepIncludeStack.back().back(); if (IfdefOrElseEntry.Kind != tgtok::Ifdef && IfdefOrElseEntry.Kind != tgtok::Else) { @@ -836,7 +829,7 @@ tgtok::TokKind TGLexer::lexPreprocessor(tgtok::TokKind Kind, if (!prepSkipDirectiveEnd()) return ReturnError(CurPtr, "only comments are supported after #endif"); - PrepIncludeStack.back()->pop_back(); + PrepIncludeStack.back().pop_back(); // If we were processing tokens before this #endif, then // we should continue it. 
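The TGLexer edits in these hunks are mechanical consequences of the container change (see the TGLexer.h hunk below): the per-file preprocessor control stack is no longer a separately heap-allocated vector. A before/after sketch of the shapes involved, where Desc stands for some PreprocessorControlDesc value:

    // Before: every include file owned a separate heap allocation.
    std::vector<std::unique_ptr<std::vector<PreprocessorControlDesc>>> OldStack;
    OldStack.push_back(std::make_unique<std::vector<PreprocessorControlDesc>>());
    OldStack.back()->push_back(Desc);   // note the extra dereference

    // After: nested SmallVectors, default-constructed in place.
    SmallVector<SmallVector<PreprocessorControlDesc>> NewStack;
    NewStack.emplace_back();
    NewStack.back().push_back(Desc);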
@@ -1055,20 +1048,16 @@ bool TGLexer::prepSkipDirectiveEnd() { } bool TGLexer::prepIsProcessingEnabled() { - for (const PreprocessorControlDesc &I : - llvm::reverse(*PrepIncludeStack.back())) - if (!I.IsDefined) - return false; - - return true; + return all_of(PrepIncludeStack.back(), + [](const PreprocessorControlDesc &I) { return I.IsDefined; }); } void TGLexer::prepReportPreprocessorStackError() { - if (PrepIncludeStack.back()->empty()) + if (PrepIncludeStack.back().empty()) PrintFatalError("prepReportPreprocessorStackError() called with " "empty control stack"); - auto &PrepControl = PrepIncludeStack.back()->back(); + auto &PrepControl = PrepIncludeStack.back().back(); PrintError(CurBuf.end(), "reached EOF without matching #endif"); PrintError(PrepControl.SrcPos, "the latest preprocessor control is here"); diff --git a/llvm/lib/TableGen/TGLexer.h b/llvm/lib/TableGen/TGLexer.h index 963d75e..f8b32dc 100644 --- a/llvm/lib/TableGen/TGLexer.h +++ b/llvm/lib/TableGen/TGLexer.h @@ -13,6 +13,7 @@ #ifndef LLVM_LIB_TABLEGEN_TGLEXER_H #define LLVM_LIB_TABLEGEN_TGLEXER_H +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/DataTypes.h" @@ -21,7 +22,6 @@ #include <memory> #include <set> #include <string> -#include <vector> namespace llvm { template <typename T> class ArrayRef; @@ -323,8 +323,7 @@ private: // preprocessing control stacks for the current file and all its // parent files. The back() element is the preprocessing control // stack for the current file. - std::vector<std::unique_ptr<std::vector<PreprocessorControlDesc>>> - PrepIncludeStack; + SmallVector<SmallVector<PreprocessorControlDesc>> PrepIncludeStack; // Validate that the current preprocessing control stack is empty, // since we are about to exit a file, and pop the include stack. diff --git a/llvm/lib/TableGen/TGParser.cpp b/llvm/lib/TableGen/TGParser.cpp index e867943..60ae11b 100644 --- a/llvm/lib/TableGen/TGParser.cpp +++ b/llvm/lib/TableGen/TGParser.cpp @@ -776,13 +776,14 @@ ParseSubClassReference(Record *CurRec, bool isDefm) { return Result; } - if (ParseTemplateArgValueList(Result.TemplateArgs, CurRec, Result.Rec)) { + SmallVector<SMLoc> ArgLocs; + if (ParseTemplateArgValueList(Result.TemplateArgs, ArgLocs, CurRec, + Result.Rec)) { Result.Rec = nullptr; // Error parsing value list. return Result; } - if (CheckTemplateArgValues(Result.TemplateArgs, Result.RefRange.Start, - Result.Rec)) { + if (CheckTemplateArgValues(Result.TemplateArgs, ArgLocs, Result.Rec)) { Result.Rec = nullptr; // Error checking value list. return Result; } @@ -812,7 +813,8 @@ ParseSubMultiClassReference(MultiClass *CurMC) { return Result; } - if (ParseTemplateArgValueList(Result.TemplateArgs, &CurMC->Rec, + SmallVector<SMLoc> ArgLocs; + if (ParseTemplateArgValueList(Result.TemplateArgs, ArgLocs, &CurMC->Rec, &Result.MC->Rec)) { Result.MC = nullptr; // Error parsing value list. return Result; @@ -2722,11 +2724,12 @@ const Init *TGParser::ParseSimpleValue(Record *CurRec, const RecTy *ItemType, } SmallVector<const ArgumentInit *, 8> Args; + SmallVector<SMLoc> ArgLocs; Lex.Lex(); // consume the < - if (ParseTemplateArgValueList(Args, CurRec, Class)) + if (ParseTemplateArgValueList(Args, ArgLocs, CurRec, Class)) return nullptr; // Error parsing value list. - if (CheckTemplateArgValues(Args, NameLoc.Start, Class)) + if (CheckTemplateArgValues(Args, ArgLocs, Class)) return nullptr; // Error checking template argument values. 
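CheckTemplateArgValues (rewritten later in this file's diff) now walks argument values and their source locations in lockstep and accumulates per-argument diagnostics instead of issuing a single fatal error. A minimal sketch of that accumulate-don't-abort idiom with llvm::zip_equal; the checkAll/ReportError names are illustrative, not TGParser helpers:

    #include "llvm/ADT/ArrayRef.h"
    #include "llvm/ADT/STLExtras.h"   // llvm::zip_equal, llvm::function_ref
    #include "llvm/ADT/StringRef.h"
    #include "llvm/Support/SMLoc.h"
    #include <cassert>
    using namespace llvm;

    // Returns true if any value failed validation; reports each failure at its
    // own location instead of aborting on the first one.
    static bool checkAll(ArrayRef<int> Values, ArrayRef<SMLoc> Locs,
                         function_ref<bool(SMLoc, StringRef)> ReportError) {
      assert(Values.size() == Locs.size() && "expected as many values as locations");
      bool HasError = false;
      for (auto [V, Loc] : zip_equal(Values, Locs))
        if (V < 0)
          HasError |= ReportError(Loc, "negative value");
      return HasError;
    }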
if (resolveArguments(Class, Args, NameLoc.Start)) @@ -3201,8 +3204,8 @@ void TGParser::ParseValueList(SmallVectorImpl<const Init *> &Result, // PostionalArgValueList ::= [Value {',' Value}*] // NamedArgValueList ::= [NameValue '=' Value {',' NameValue '=' Value}*] bool TGParser::ParseTemplateArgValueList( - SmallVectorImpl<const ArgumentInit *> &Result, Record *CurRec, - const Record *ArgsRec) { + SmallVectorImpl<const ArgumentInit *> &Result, + SmallVectorImpl<SMLoc> &ArgLocs, Record *CurRec, const Record *ArgsRec) { assert(Result.empty() && "Result vector is not empty"); ArrayRef<const Init *> TArgs = ArgsRec->getTemplateArgs(); @@ -3217,7 +3220,7 @@ bool TGParser::ParseTemplateArgValueList( return true; } - SMLoc ValueLoc = Lex.getLoc(); + SMLoc ValueLoc = ArgLocs.emplace_back(Lex.getLoc()); // If we are parsing named argument, we don't need to know the argument name // and argument type will be resolved after we know the name. const Init *Value = ParseValue( @@ -4417,11 +4420,15 @@ bool TGParser::ParseFile() { // If necessary, replace an argument with a cast to the required type. // The argument count has already been checked. bool TGParser::CheckTemplateArgValues( - SmallVectorImpl<const ArgumentInit *> &Values, SMLoc Loc, + SmallVectorImpl<const ArgumentInit *> &Values, ArrayRef<SMLoc> ValuesLocs, const Record *ArgsRec) { + assert(Values.size() == ValuesLocs.size() && + "expected as many values as locations"); + ArrayRef<const Init *> TArgs = ArgsRec->getTemplateArgs(); - for (const ArgumentInit *&Value : Values) { + bool HasError = false; + for (auto [Value, Loc] : llvm::zip_equal(Values, ValuesLocs)) { const Init *ArgName = nullptr; if (Value->isPositional()) ArgName = TArgs[Value->getIndex()]; @@ -4439,16 +4446,16 @@ bool TGParser::CheckTemplateArgValues( "result of template arg value cast has wrong type"); Value = Value->cloneWithValue(CastValue); } else { - PrintFatalError(Loc, "Value specified for template argument '" + - Arg->getNameInitAsString() + "' is of type " + - ArgValue->getType()->getAsString() + - "; expected type " + ArgType->getAsString() + - ": " + ArgValue->getAsString()); + HasError |= Error( + Loc, "Value specified for template argument '" + + Arg->getNameInitAsString() + "' is of type " + + ArgValue->getType()->getAsString() + "; expected type " + + ArgType->getAsString() + ": " + ArgValue->getAsString()); } } } - return false; + return HasError; } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) diff --git a/llvm/lib/TableGen/TGParser.h b/llvm/lib/TableGen/TGParser.h index cac1ba8..4509893 100644 --- a/llvm/lib/TableGen/TGParser.h +++ b/llvm/lib/TableGen/TGParser.h @@ -296,6 +296,7 @@ private: // Parser methods. void ParseValueList(SmallVectorImpl<const Init *> &Result, Record *CurRec, const RecTy *ItemType = nullptr); bool ParseTemplateArgValueList(SmallVectorImpl<const ArgumentInit *> &Result, + SmallVectorImpl<SMLoc> &ArgLocs, Record *CurRec, const Record *ArgsRec); void ParseDagArgList( SmallVectorImpl<std::pair<const Init *, const StringInit *>> &Result, @@ -321,7 +322,8 @@ private: // Parser methods. 
bool ApplyLetStack(Record *CurRec); bool ApplyLetStack(RecordsEntry &Entry); bool CheckTemplateArgValues(SmallVectorImpl<const ArgumentInit *> &Values, - SMLoc Loc, const Record *ArgsRec); + ArrayRef<SMLoc> ValuesLocs, + const Record *ArgsRec); }; } // end namespace llvm diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp index 69d07f2..9bec782 100644 --- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp +++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp @@ -2730,6 +2730,54 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) { EmitToStreamer(*OutStreamer, TmpInstSB); return; } + case AArch64::TLSDESC_AUTH_CALLSEQ: { + /// lower this to: + /// adrp x0, :tlsdesc_auth:var + /// ldr x16, [x0, #:tlsdesc_auth_lo12:var] + /// add x0, x0, #:tlsdesc_auth_lo12:var + /// blraa x16, x0 + /// (TPIDR_EL0 offset now in x0) + const MachineOperand &MO_Sym = MI->getOperand(0); + MachineOperand MO_TLSDESC_LO12(MO_Sym), MO_TLSDESC(MO_Sym); + MCOperand SymTLSDescLo12, SymTLSDesc; + MO_TLSDESC_LO12.setTargetFlags(AArch64II::MO_TLS | AArch64II::MO_PAGEOFF); + MO_TLSDESC.setTargetFlags(AArch64II::MO_TLS | AArch64II::MO_PAGE); + MCInstLowering.lowerOperand(MO_TLSDESC_LO12, SymTLSDescLo12); + MCInstLowering.lowerOperand(MO_TLSDESC, SymTLSDesc); + + MCInst Adrp; + Adrp.setOpcode(AArch64::ADRP); + Adrp.addOperand(MCOperand::createReg(AArch64::X0)); + Adrp.addOperand(SymTLSDesc); + EmitToStreamer(*OutStreamer, Adrp); + + MCInst Ldr; + Ldr.setOpcode(AArch64::LDRXui); + Ldr.addOperand(MCOperand::createReg(AArch64::X16)); + Ldr.addOperand(MCOperand::createReg(AArch64::X0)); + Ldr.addOperand(SymTLSDescLo12); + Ldr.addOperand(MCOperand::createImm(0)); + EmitToStreamer(*OutStreamer, Ldr); + + MCInst Add; + Add.setOpcode(AArch64::ADDXri); + Add.addOperand(MCOperand::createReg(AArch64::X0)); + Add.addOperand(MCOperand::createReg(AArch64::X0)); + Add.addOperand(SymTLSDescLo12); + Add.addOperand(MCOperand::createImm(AArch64_AM::getShiftValue(0))); + EmitToStreamer(*OutStreamer, Add); + + // Authenticated TLSDESC accesses are not relaxed. + // Thus, do not emit .tlsdesccall for AUTH TLSDESC. 
+ + MCInst Blraa; + Blraa.setOpcode(AArch64::BLRAA); + Blraa.addOperand(MCOperand::createReg(AArch64::X16)); + Blraa.addOperand(MCOperand::createReg(AArch64::X0)); + EmitToStreamer(*OutStreamer, Blraa); + + return; + } case AArch64::TLSDESC_CALLSEQ: { /// lower this to: /// adrp x0, :tlsdesc:var diff --git a/llvm/lib/Target/AArch64/AArch64Combine.td b/llvm/lib/Target/AArch64/AArch64Combine.td index 1b1d81f..ce19806 100644 --- a/llvm/lib/Target/AArch64/AArch64Combine.td +++ b/llvm/lib/Target/AArch64/AArch64Combine.td @@ -131,6 +131,15 @@ def ext: GICombineRule < (apply [{ applyEXT(*${root}, ${matchinfo}); }]) >; +def fullrev: GICombineRule < + (defs root:$root, shuffle_matchdata:$matchinfo), + (match (G_IMPLICIT_DEF $src2), + (G_SHUFFLE_VECTOR $src, $src1, $src2, $mask):$root, + [{ return ShuffleVectorInst::isReverseMask(${mask}.getShuffleMask(), + ${mask}.getShuffleMask().size()); }]), + (apply [{ applyFullRev(*${root}, MRI); }]) +>; + def insertelt_nonconst: GICombineRule < (defs root:$root, shuffle_matchdata:$matchinfo), (match (wip_match_opcode G_INSERT_VECTOR_ELT):$root, @@ -163,7 +172,7 @@ def form_duplane : GICombineRule < (apply [{ applyDupLane(*${root}, MRI, B, ${matchinfo}); }]) >; -def shuffle_vector_lowering : GICombineGroup<[dup, rev, ext, zip, uzp, trn, +def shuffle_vector_lowering : GICombineGroup<[dup, rev, ext, zip, uzp, trn, fullrev, form_duplane, shuf_to_ins]>; // Turn G_UNMERGE_VALUES -> G_EXTRACT_VECTOR_ELT's diff --git a/llvm/lib/Target/AArch64/AArch64FMV.td b/llvm/lib/Target/AArch64/AArch64FMV.td index fc7a94a..e0f56fd 100644 --- a/llvm/lib/Target/AArch64/AArch64FMV.td +++ b/llvm/lib/Target/AArch64/AArch64FMV.td @@ -22,64 +22,65 @@ // Something you can add to target_version or target_clones. -class FMVExtension<string n, string b, int p> { +class FMVExtension<string name, string enumeration> { // Name, as spelled in target_version or target_clones. e.g. "memtag". - string Name = n; + string Name = name; // A C++ expression giving the number of the bit in the FMV ABI. // Currently this is given as a value from the enum "CPUFeatures". - string Bit = b; + string FeatureBit = "FEAT_" # enumeration; // SubtargetFeature enabled for codegen when this FMV feature is present. - string BackendFeature = n; + string BackendFeature = name; - // The FMV priority. - int Priority = p; + // A C++ expression giving the number of the priority bit. + // Currently this is given as a value from the enum "FeatPriorities". 
+ string PriorityBit = "PRIOR_" # enumeration; } -def : FMVExtension<"aes", "FEAT_PMULL", 150>; -def : FMVExtension<"bf16", "FEAT_BF16", 280>; -def : FMVExtension<"bti", "FEAT_BTI", 510>; -def : FMVExtension<"crc", "FEAT_CRC", 110>; -def : FMVExtension<"dit", "FEAT_DIT", 180>; -def : FMVExtension<"dotprod", "FEAT_DOTPROD", 104>; -let BackendFeature = "ccpp" in def : FMVExtension<"dpb", "FEAT_DPB", 190>; -let BackendFeature = "ccdp" in def : FMVExtension<"dpb2", "FEAT_DPB2", 200>; -def : FMVExtension<"f32mm", "FEAT_SVE_F32MM", 350>; -def : FMVExtension<"f64mm", "FEAT_SVE_F64MM", 360>; -def : FMVExtension<"fcma", "FEAT_FCMA", 220>; -def : FMVExtension<"flagm", "FEAT_FLAGM", 20>; -let BackendFeature = "altnzcv" in def : FMVExtension<"flagm2", "FEAT_FLAGM2", 30>; -def : FMVExtension<"fp", "FEAT_FP", 90>; -def : FMVExtension<"fp16", "FEAT_FP16", 170>; -def : FMVExtension<"fp16fml", "FEAT_FP16FML", 175>; -let BackendFeature = "fptoint" in def : FMVExtension<"frintts", "FEAT_FRINTTS", 250>; -def : FMVExtension<"i8mm", "FEAT_I8MM", 270>; -def : FMVExtension<"jscvt", "FEAT_JSCVT", 210>; -def : FMVExtension<"ls64", "FEAT_LS64_ACCDATA", 520>; -def : FMVExtension<"lse", "FEAT_LSE", 80>; -def : FMVExtension<"memtag", "FEAT_MEMTAG2", 440>; -def : FMVExtension<"mops", "FEAT_MOPS", 650>; -def : FMVExtension<"predres", "FEAT_PREDRES", 480>; -def : FMVExtension<"rcpc", "FEAT_RCPC", 230>; -let BackendFeature = "rcpc-immo" in def : FMVExtension<"rcpc2", "FEAT_RCPC2", 240>; -def : FMVExtension<"rcpc3", "FEAT_RCPC3", 241>; -def : FMVExtension<"rdm", "FEAT_RDM", 108>; -def : FMVExtension<"rng", "FEAT_RNG", 10>; -def : FMVExtension<"sb", "FEAT_SB", 470>; -def : FMVExtension<"sha2", "FEAT_SHA2", 130>; -def : FMVExtension<"sha3", "FEAT_SHA3", 140>; -def : FMVExtension<"simd", "FEAT_SIMD", 100>; -def : FMVExtension<"sm4", "FEAT_SM4", 106>; -def : FMVExtension<"sme", "FEAT_SME", 430>; -def : FMVExtension<"sme-f64f64", "FEAT_SME_F64", 560>; -def : FMVExtension<"sme-i16i64", "FEAT_SME_I64", 570>; -def : FMVExtension<"sme2", "FEAT_SME2", 580>; -def : FMVExtension<"ssbs", "FEAT_SSBS2", 490>; -def : FMVExtension<"sve", "FEAT_SVE", 310>; -def : FMVExtension<"sve2", "FEAT_SVE2", 370>; -def : FMVExtension<"sve2-aes", "FEAT_SVE_PMULL128", 380>; -def : FMVExtension<"sve2-bitperm", "FEAT_SVE_BITPERM", 400>; -def : FMVExtension<"sve2-sha3", "FEAT_SVE_SHA3", 410>; -def : FMVExtension<"sve2-sm4", "FEAT_SVE_SM4", 420>; -def : FMVExtension<"wfxt", "FEAT_WFXT", 550>; +def : FMVExtension<"aes", "PMULL">; +def : FMVExtension<"bf16", "BF16">; +def : FMVExtension<"bti", "BTI">; +def : FMVExtension<"crc", "CRC">; +def : FMVExtension<"dit", "DIT">; +def : FMVExtension<"dotprod", "DOTPROD">; +let BackendFeature = "ccpp" in def : FMVExtension<"dpb", "DPB">; +let BackendFeature = "ccdp" in def : FMVExtension<"dpb2", "DPB2">; +def : FMVExtension<"f32mm", "SVE_F32MM">; +def : FMVExtension<"f64mm", "SVE_F64MM">; +def : FMVExtension<"fcma", "FCMA">; +def : FMVExtension<"flagm", "FLAGM">; +let BackendFeature = "altnzcv" in def : FMVExtension<"flagm2", "FLAGM2">; +def : FMVExtension<"fp", "FP">; +def : FMVExtension<"fp16", "FP16">; +def : FMVExtension<"fp16fml", "FP16FML">; +let BackendFeature = "fptoint" in def : FMVExtension<"frintts", "FRINTTS">; +def : FMVExtension<"i8mm", "I8MM">; +def : FMVExtension<"jscvt", "JSCVT">; +def : FMVExtension<"ls64", "LS64_ACCDATA">; +def : FMVExtension<"lse", "LSE">; +def : FMVExtension<"memtag", "MEMTAG2">; +def : FMVExtension<"mops", "MOPS">; +def : FMVExtension<"predres", "PREDRES">; +def : 
FMVExtension<"rcpc", "RCPC">; +let BackendFeature = "rcpc-immo" in def : FMVExtension<"rcpc2", "RCPC2">; +def : FMVExtension<"rcpc3", "RCPC3">; +def : FMVExtension<"rdm", "RDM">; +def : FMVExtension<"rng", "RNG">; +def : FMVExtension<"sb", "SB">; +def : FMVExtension<"sha2", "SHA2">; +def : FMVExtension<"sha3", "SHA3">; +def : FMVExtension<"simd", "SIMD">; +def : FMVExtension<"sm4", "SM4">; +def : FMVExtension<"sme", "SME">; +def : FMVExtension<"sme-f64f64", "SME_F64">; +def : FMVExtension<"sme-i16i64", "SME_I64">; +def : FMVExtension<"sme2", "SME2">; +def : FMVExtension<"ssbs", "SSBS2">; +def : FMVExtension<"sve", "SVE">; +def : FMVExtension<"sve2", "SVE2">; +def : FMVExtension<"sve2-aes", "SVE_PMULL128">; +def : FMVExtension<"sve2-bitperm", "SVE_BITPERM">; +def : FMVExtension<"sve2-sha3", "SVE_SHA3">; +def : FMVExtension<"sve2-sm4", "SVE_SM4">; +def : FMVExtension<"wfxt", "WFXT">; diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index a6f8f47..3ad2905 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -753,6 +753,14 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(Op, MVT::v8bf16, Expand); } + // For bf16, fpextend is custom lowered to be optionally expanded into shifts. + setOperationAction(ISD::FP_EXTEND, MVT::f32, Custom); + setOperationAction(ISD::FP_EXTEND, MVT::f64, Custom); + setOperationAction(ISD::FP_EXTEND, MVT::v4f32, Custom); + setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f32, Custom); + setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f64, Custom); + setOperationAction(ISD::STRICT_FP_EXTEND, MVT::v4f32, Custom); + auto LegalizeNarrowFP = [this](MVT ScalarVT) { for (auto Op : { ISD::SETCC, @@ -893,10 +901,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(Op, MVT::f16, Legal); } - // Strict conversion to a larger type is legal - for (auto VT : {MVT::f32, MVT::f64}) - setOperationAction(ISD::STRICT_FP_EXTEND, VT, Legal); - setOperationAction(ISD::PREFETCH, MVT::Other, Custom); setOperationAction(ISD::GET_ROUNDING, MVT::i32, Custom); @@ -1183,6 +1187,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setMaxDivRemBitWidthSupported(128); setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom); + if (Subtarget->hasSME()) + setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i1, Custom); if (Subtarget->isNeonAvailable()) { // FIXME: v1f64 shouldn't be legal if we can avoid it, because it leads to @@ -2669,6 +2675,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const { MAKE_CASE(AArch64ISD::CSINC) MAKE_CASE(AArch64ISD::THREAD_POINTER) MAKE_CASE(AArch64ISD::TLSDESC_CALLSEQ) + MAKE_CASE(AArch64ISD::TLSDESC_AUTH_CALLSEQ) MAKE_CASE(AArch64ISD::PROBED_ALLOCA) MAKE_CASE(AArch64ISD::ABDS_PRED) MAKE_CASE(AArch64ISD::ABDU_PRED) @@ -4495,6 +4502,54 @@ SDValue AArch64TargetLowering::LowerFP_EXTEND(SDValue Op, if (useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable())) return LowerFixedLengthFPExtendToSVE(Op, DAG); + bool IsStrict = Op->isStrictFPOpcode(); + SDValue Op0 = Op.getOperand(IsStrict ? 1 : 0); + EVT Op0VT = Op0.getValueType(); + if (VT == MVT::f64) { + // FP16->FP32 extends are legal for v32 and v4f32. + if (Op0VT == MVT::f32 || Op0VT == MVT::f16) + return Op; + // Split bf16->f64 extends into two fpextends. 
+ if (Op0VT == MVT::bf16 && IsStrict) { + SDValue Ext1 = + DAG.getNode(ISD::STRICT_FP_EXTEND, SDLoc(Op), {MVT::f32, MVT::Other}, + {Op0, Op.getOperand(0)}); + return DAG.getNode(ISD::STRICT_FP_EXTEND, SDLoc(Op), {VT, MVT::Other}, + {Ext1, Ext1.getValue(1)}); + } + if (Op0VT == MVT::bf16) + return DAG.getNode(ISD::FP_EXTEND, SDLoc(Op), VT, + DAG.getNode(ISD::FP_EXTEND, SDLoc(Op), MVT::f32, Op0)); + return SDValue(); + } + + if (VT.getScalarType() == MVT::f32) { + // FP16->FP32 extends are legal for v32 and v4f32. + if (Op0VT.getScalarType() == MVT::f16) + return Op; + if (Op0VT.getScalarType() == MVT::bf16) { + SDLoc DL(Op); + EVT IVT = VT.changeTypeToInteger(); + if (!Op0VT.isVector()) { + Op0 = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v4bf16, Op0); + IVT = MVT::v4i32; + } + + EVT Op0IVT = Op0.getValueType().changeTypeToInteger(); + SDValue Ext = + DAG.getNode(ISD::ANY_EXTEND, DL, IVT, DAG.getBitcast(Op0IVT, Op0)); + SDValue Shift = + DAG.getNode(ISD::SHL, DL, IVT, Ext, DAG.getConstant(16, DL, IVT)); + if (!Op0VT.isVector()) + Shift = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i32, Shift, + DAG.getConstant(0, DL, MVT::i64)); + Shift = DAG.getBitcast(VT, Shift); + return IsStrict ? DAG.getMergeValues({Shift, Op.getOperand(0)}, DL) + : Shift; + } + return SDValue(); + } + assert(Op.getValueType() == MVT::f128 && "Unexpected lowering"); return SDValue(); } @@ -7342,6 +7397,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op, case ISD::STRICT_FP_ROUND: return LowerFP_ROUND(Op, DAG); case ISD::FP_EXTEND: + case ISD::STRICT_FP_EXTEND: return LowerFP_EXTEND(Op, DAG); case ISD::FRAMEADDR: return LowerFRAMEADDR(Op, DAG); @@ -10123,8 +10179,11 @@ SDValue AArch64TargetLowering::LowerELFTLSDescCallSeq(SDValue SymAddr, SDValue Chain = DAG.getEntryNode(); SDVTList NodeTys = DAG.getVTList(MVT::Other, MVT::Glue); - Chain = - DAG.getNode(AArch64ISD::TLSDESC_CALLSEQ, DL, NodeTys, {Chain, SymAddr}); + unsigned Opcode = + DAG.getMachineFunction().getInfo<AArch64FunctionInfo>()->hasELFSignedGOT() + ? AArch64ISD::TLSDESC_AUTH_CALLSEQ + : AArch64ISD::TLSDESC_CALLSEQ; + Chain = DAG.getNode(Opcode, DL, NodeTys, {Chain, SymAddr}); SDValue Glue = Chain.getValue(1); return DAG.getCopyFromReg(Chain, DL, AArch64::X0, PtrVT, Glue); @@ -10136,8 +10195,12 @@ AArch64TargetLowering::LowerELFGlobalTLSAddress(SDValue Op, assert(Subtarget->isTargetELF() && "This function expects an ELF target"); const GlobalAddressSDNode *GA = cast<GlobalAddressSDNode>(Op); + AArch64FunctionInfo *MFI = + DAG.getMachineFunction().getInfo<AArch64FunctionInfo>(); - TLSModel::Model Model = getTargetMachine().getTLSModel(GA->getGlobal()); + TLSModel::Model Model = MFI->hasELFSignedGOT() + ? TLSModel::GeneralDynamic + : getTargetMachine().getTLSModel(GA->getGlobal()); if (!EnableAArch64ELFLocalDynamicTLSGeneration) { if (Model == TLSModel::LocalDynamic) @@ -10174,8 +10237,6 @@ AArch64TargetLowering::LowerELFGlobalTLSAddress(SDValue Op, // calculation. // These accesses will need deduplicating if there's more than one. - AArch64FunctionInfo *MFI = - DAG.getMachineFunction().getInfo<AArch64FunctionInfo>(); MFI->incNumLocalDynamicTLSAccesses(); // The call needs a relocation too for linker relaxation. 
It doesn't make @@ -18424,7 +18485,7 @@ static SDValue performUADDVAddCombine(SDValue A, SelectionDAG &DAG) { EVT VT = A.getValueType(); SDValue Op0 = A.getOperand(0); SDValue Op1 = A.getOperand(1); - if (Op0.getOpcode() != Op0.getOpcode() || + if (Op0.getOpcode() != Op1.getOpcode() || (Op0.getOpcode() != ISD::ZERO_EXTEND && Op0.getOpcode() != ISD::SIGN_EXTEND)) return SDValue(); @@ -21981,21 +22042,35 @@ SDValue tryLowerPartialReductionToDot(SDNode *N, SDLoc DL(N); SDValue Op2 = N->getOperand(2); - if (Op2->getOpcode() != ISD::MUL || - !ISD::isExtOpcode(Op2->getOperand(0)->getOpcode()) || - !ISD::isExtOpcode(Op2->getOperand(1)->getOpcode())) - return SDValue(); + unsigned Op2Opcode = Op2->getOpcode(); + SDValue MulOpLHS, MulOpRHS; + bool MulOpLHSIsSigned, MulOpRHSIsSigned; + if (ISD::isExtOpcode(Op2Opcode)) { + MulOpLHSIsSigned = MulOpRHSIsSigned = (Op2Opcode == ISD::SIGN_EXTEND); + MulOpLHS = Op2->getOperand(0); + MulOpRHS = DAG.getConstant(1, DL, MulOpLHS.getValueType()); + } else if (Op2Opcode == ISD::MUL) { + SDValue ExtMulOpLHS = Op2->getOperand(0); + SDValue ExtMulOpRHS = Op2->getOperand(1); + + unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode(); + unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode(); + if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) || + !ISD::isExtOpcode(ExtMulOpRHSOpcode)) + return SDValue(); - SDValue Acc = N->getOperand(1); - SDValue Mul = N->getOperand(2); - SDValue ExtMulOpLHS = Mul->getOperand(0); - SDValue ExtMulOpRHS = Mul->getOperand(1); + MulOpLHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND; + MulOpRHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND; + + MulOpLHS = ExtMulOpLHS->getOperand(0); + MulOpRHS = ExtMulOpRHS->getOperand(0); - SDValue MulOpLHS = ExtMulOpLHS->getOperand(0); - SDValue MulOpRHS = ExtMulOpRHS->getOperand(0); - if (MulOpLHS.getValueType() != MulOpRHS.getValueType()) + if (MulOpLHS.getValueType() != MulOpRHS.getValueType()) + return SDValue(); + } else return SDValue(); + SDValue Acc = N->getOperand(1); EVT ReducedVT = N->getValueType(0); EVT MulSrcVT = MulOpLHS.getValueType(); @@ -22009,8 +22084,6 @@ SDValue tryLowerPartialReductionToDot(SDNode *N, !(ReducedVT == MVT::v2i32 && MulSrcVT == MVT::v8i8)) return SDValue(); - bool MulOpLHSIsSigned = ExtMulOpLHS->getOpcode() == ISD::SIGN_EXTEND; - bool MulOpRHSIsSigned = ExtMulOpRHS->getOpcode() == ISD::SIGN_EXTEND; // If the extensions are mixed, we should lower it to a usdot instead unsigned Opcode = 0; if (MulOpLHSIsSigned != MulOpRHSIsSigned) { @@ -22026,10 +22099,8 @@ SDValue tryLowerPartialReductionToDot(SDNode *N, // USDOT expects the signed operand to be last if (!MulOpRHSIsSigned) std::swap(MulOpLHS, MulOpRHS); - } else if (MulOpLHSIsSigned) - Opcode = AArch64ISD::SDOT; - else - Opcode = AArch64ISD::UDOT; + } else + Opcode = MulOpLHSIsSigned ? 
AArch64ISD::SDOT : AArch64ISD::UDOT; // Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot // product followed by a zero / sign extension @@ -27413,6 +27484,15 @@ void AArch64TargetLowering::ReplaceNodeResults( Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V)); return; } + case Intrinsic::aarch64_sme_in_streaming_mode: { + SDLoc DL(N); + SDValue Chain = DAG.getEntryNode(); + SDValue RuntimePStateSM = + getRuntimePStateSM(DAG, Chain, DL, N->getValueType(0)); + Results.push_back( + DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, RuntimePStateSM)); + return; + } case Intrinsic::experimental_vector_match: case Intrinsic::get_active_lane_mask: { if (!VT.isFixedLengthVector() || VT.getVectorElementType() != MVT::i1) @@ -29648,9 +29728,16 @@ bool AArch64TargetLowering::isComplexDeinterleavingOperationSupported( if (ScalarTy->isIntegerTy() && Subtarget->hasSVE2() && VTy->isScalableTy()) { unsigned ScalarWidth = ScalarTy->getScalarSizeInBits(); + + if (Operation == ComplexDeinterleavingOperation::CDot) + return ScalarWidth == 32 || ScalarWidth == 64; return 8 <= ScalarWidth && ScalarWidth <= 64; } + // CDot is not supported outside of scalable/sve scopes + if (Operation == ComplexDeinterleavingOperation::CDot) + return false; + return (ScalarTy->isHalfTy() && Subtarget->hasFullFP16()) || ScalarTy->isFloatTy() || ScalarTy->isDoubleTy(); } @@ -29660,6 +29747,8 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR( ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB, Value *Accumulator) const { VectorType *Ty = cast<VectorType>(InputA->getType()); + if (Accumulator == nullptr) + Accumulator = Constant::getNullValue(Ty); bool IsScalable = Ty->isScalableTy(); bool IsInt = Ty->getElementType()->isIntegerTy(); @@ -29671,6 +29760,10 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR( if (TyWidth > 128) { int Stride = Ty->getElementCount().getKnownMinValue() / 2; + int AccStride = cast<VectorType>(Accumulator->getType()) + ->getElementCount() + .getKnownMinValue() / + 2; auto *HalfTy = VectorType::getHalfElementsVectorType(Ty); auto *LowerSplitA = B.CreateExtractVector(HalfTy, InputA, B.getInt64(0)); auto *LowerSplitB = B.CreateExtractVector(HalfTy, InputB, B.getInt64(0)); @@ -29680,25 +29773,26 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR( B.CreateExtractVector(HalfTy, InputB, B.getInt64(Stride)); Value *LowerSplitAcc = nullptr; Value *UpperSplitAcc = nullptr; - if (Accumulator) { - LowerSplitAcc = B.CreateExtractVector(HalfTy, Accumulator, B.getInt64(0)); - UpperSplitAcc = - B.CreateExtractVector(HalfTy, Accumulator, B.getInt64(Stride)); - } + Type *FullTy = Ty; + FullTy = Accumulator->getType(); + auto *HalfAccTy = VectorType::getHalfElementsVectorType( + cast<VectorType>(Accumulator->getType())); + LowerSplitAcc = + B.CreateExtractVector(HalfAccTy, Accumulator, B.getInt64(0)); + UpperSplitAcc = + B.CreateExtractVector(HalfAccTy, Accumulator, B.getInt64(AccStride)); auto *LowerSplitInt = createComplexDeinterleavingIR( B, OperationType, Rotation, LowerSplitA, LowerSplitB, LowerSplitAcc); auto *UpperSplitInt = createComplexDeinterleavingIR( B, OperationType, Rotation, UpperSplitA, UpperSplitB, UpperSplitAcc); - auto *Result = B.CreateInsertVector(Ty, PoisonValue::get(Ty), LowerSplitInt, - B.getInt64(0)); - return B.CreateInsertVector(Ty, Result, UpperSplitInt, B.getInt64(Stride)); + auto *Result = B.CreateInsertVector(FullTy, PoisonValue::get(FullTy), + LowerSplitInt, B.getInt64(0)); + return B.CreateInsertVector(FullTy, Result, 
UpperSplitInt, + B.getInt64(AccStride)); } if (OperationType == ComplexDeinterleavingOperation::CMulPartial) { - if (Accumulator == nullptr) - Accumulator = Constant::getNullValue(Ty); - if (IsScalable) { if (IsInt) return B.CreateIntrinsic( @@ -29750,6 +29844,13 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR( return B.CreateIntrinsic(IntId, Ty, {InputA, InputB}); } + if (OperationType == ComplexDeinterleavingOperation::CDot && IsInt && + IsScalable) { + return B.CreateIntrinsic( + Intrinsic::aarch64_sve_cdot, Accumulator->getType(), + {Accumulator, InputA, InputB, B.getInt32((int)Rotation * 90)}); + } + return nullptr; } diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h index 1b7f328..85b62be 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -83,6 +83,7 @@ enum NodeType : unsigned { // Produces the full sequence of instructions for getting the thread pointer // offset of a variable into X0, using the TLSDesc model. TLSDESC_CALLSEQ, + TLSDESC_AUTH_CALLSEQ, ADRP, // Page address of a TargetGlobalAddress operand. ADR, // ADR ADDlow, // Add the low 12 bits of a TargetGlobalAddress operand. diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td index 47c4c6c..f527f7e 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td +++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td @@ -1804,7 +1804,9 @@ class TMSystemException<bits<3> op1, string asm, list<dag> pattern> } class APASI : SimpleSystemI<0, (ins GPR64:$Xt), "apas", "\t$Xt">, Sched<[]> { + bits<5> Xt; let Inst{20-5} = 0b0111001110000000; + let Inst{4-0} = Xt; let DecoderNamespace = "APAS"; } @@ -2768,6 +2770,8 @@ class MulHi<bits<3> opc, string asm, SDNode OpNode> let Inst{23-21} = opc; let Inst{20-16} = Rm; let Inst{15} = 0; + let Inst{14-10} = 0b11111; + let Unpredictable{14-10} = 0b11111; let Inst{9-5} = Rn; let Inst{4-0} = Rd; @@ -4920,6 +4924,8 @@ class LoadExclusivePair<bits<2> sz, bit o2, bit L, bit o1, bit o0, bits<5> Rt; bits<5> Rt2; bits<5> Rn; + let Inst{20-16} = 0b11111; + let Unpredictable{20-16} = 0b11111; let Inst{14-10} = Rt2; let Inst{9-5} = Rn; let Inst{4-0} = Rt; @@ -4935,6 +4941,7 @@ class BaseLoadStoreExclusiveLSUI<bits<2> sz, bit L, bit o0, let Inst{31-30} = sz; let Inst{29-23} = 0b0010010; let Inst{22} = L; + let Inst{21} = 0b0; let Inst{15} = o0; } diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td index 629098c..c6f5cdc 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td @@ -883,6 +883,9 @@ def AArch64tlsdesc_callseq : SDNode<"AArch64ISD::TLSDESC_CALLSEQ", SDT_AArch64TLSDescCallSeq, [SDNPOutGlue, SDNPHasChain, SDNPVariadic]>; +def AArch64tlsdesc_auth_callseq : SDNode<"AArch64ISD::TLSDESC_AUTH_CALLSEQ", + SDT_AArch64TLSDescCallSeq, + [SDNPOutGlue, SDNPHasChain, SDNPVariadic]>; def AArch64WrapperLarge : SDNode<"AArch64ISD::WrapperLarge", SDT_AArch64WrapperLarge>; @@ -3312,8 +3315,16 @@ def TLSDESC_CALLSEQ : Pseudo<(outs), (ins i64imm:$sym), [(AArch64tlsdesc_callseq tglobaltlsaddr:$sym)]>, Sched<[WriteI, WriteLD, WriteI, WriteBrReg]>; +let isCall = 1, Defs = [NZCV, LR, X0, X16], hasSideEffects = 1, Size = 16, + isCodeGenOnly = 1 in +def TLSDESC_AUTH_CALLSEQ + : Pseudo<(outs), (ins i64imm:$sym), + [(AArch64tlsdesc_auth_callseq tglobaltlsaddr:$sym)]>, + Sched<[WriteI, WriteLD, WriteI, WriteBrReg]>; def : 
Pat<(AArch64tlsdesc_callseq texternalsym:$sym), (TLSDESC_CALLSEQ texternalsym:$sym)>; +def : Pat<(AArch64tlsdesc_auth_callseq texternalsym:$sym), + (TLSDESC_AUTH_CALLSEQ texternalsym:$sym)>; //===----------------------------------------------------------------------===// // Conditional branch (immediate) instruction. @@ -5112,22 +5123,6 @@ let Predicates = [HasFullFP16] in { //===----------------------------------------------------------------------===// defm FCVT : FPConversion<"fcvt">; -// Helper to get bf16 into fp32. -def cvt_bf16_to_fp32 : - OutPatFrag<(ops node:$Rn), - (f32 (COPY_TO_REGCLASS - (i32 (UBFMWri - (i32 (COPY_TO_REGCLASS (INSERT_SUBREG (f32 (IMPLICIT_DEF)), - node:$Rn, hsub), GPR32)), - (i64 (i32shift_a (i64 16))), - (i64 (i32shift_b (i64 16))))), - FPR32))>; -// Pattern for bf16 -> fp32. -def : Pat<(f32 (any_fpextend (bf16 FPR16:$Rn))), - (cvt_bf16_to_fp32 FPR16:$Rn)>; -// Pattern for bf16 -> fp64. -def : Pat<(f64 (any_fpextend (bf16 FPR16:$Rn))), - (FCVTDSr (f32 (cvt_bf16_to_fp32 FPR16:$Rn)))>; //===----------------------------------------------------------------------===// // Floating point single operand instructions. @@ -8322,8 +8317,6 @@ def : Pat<(v4i32 (anyext (v4i16 V64:$Rn))), (USHLLv4i16_shift V64:$Rn, (i32 0))> def : Pat<(v2i64 (sext (v2i32 V64:$Rn))), (SSHLLv2i32_shift V64:$Rn, (i32 0))>; def : Pat<(v2i64 (zext (v2i32 V64:$Rn))), (USHLLv2i32_shift V64:$Rn, (i32 0))>; def : Pat<(v2i64 (anyext (v2i32 V64:$Rn))), (USHLLv2i32_shift V64:$Rn, (i32 0))>; -// Vector bf16 -> fp32 is implemented morally as a zext + shift. -def : Pat<(v4f32 (any_fpextend (v4bf16 V64:$Rn))), (SHLLv4i16 V64:$Rn)>; // Also match an extend from the upper half of a 128 bit source register. def : Pat<(v8i16 (anyext (v8i8 (extract_high_v16i8 (v16i8 V128:$Rn)) ))), (USHLLv16i8_shift V128:$Rn, (i32 0))>; diff --git a/llvm/lib/Target/AArch64/AArch64PointerAuth.cpp b/llvm/lib/Target/AArch64/AArch64PointerAuth.cpp index a290a51..c3bc70a 100644 --- a/llvm/lib/Target/AArch64/AArch64PointerAuth.cpp +++ b/llvm/lib/Target/AArch64/AArch64PointerAuth.cpp @@ -144,20 +144,20 @@ void AArch64PointerAuth::signLR(MachineFunction &MF, // No SEH opcode for this one; it doesn't materialize into an // instruction on Windows. if (MFnI.branchProtectionPAuthLR() && Subtarget->hasPAuthLR()) { + emitPACCFI(*Subtarget, MBB, MBBI, DL, MachineInstr::FrameSetup, EmitCFI); BuildMI(MBB, MBBI, DL, TII->get(MFnI.shouldSignWithBKey() ? AArch64::PACIBSPPC : AArch64::PACIASPPC)) .setMIFlag(MachineInstr::FrameSetup) ->setPreInstrSymbol(MF, MFnI.getSigningInstrLabel()); - emitPACCFI(*Subtarget, MBB, MBBI, DL, MachineInstr::FrameSetup, EmitCFI); } else { BuildPACM(*Subtarget, MBB, MBBI, DL, MachineInstr::FrameSetup); + emitPACCFI(*Subtarget, MBB, MBBI, DL, MachineInstr::FrameSetup, EmitCFI); BuildMI(MBB, MBBI, DL, TII->get(MFnI.shouldSignWithBKey() ? AArch64::PACIBSP : AArch64::PACIASP)) .setMIFlag(MachineInstr::FrameSetup) ->setPreInstrSymbol(MF, MFnI.getSigningInstrLabel()); - emitPACCFI(*Subtarget, MBB, MBBI, DL, MachineInstr::FrameSetup, EmitCFI); } if (!EmitCFI && NeedsWinCFI) { @@ -212,19 +212,19 @@ void AArch64PointerAuth::authenticateLR( if (MFnI->branchProtectionPAuthLR() && Subtarget->hasPAuthLR()) { assert(PACSym && "No PAC instruction to refer to"); emitPACSymOffsetIntoX16(*TII, MBB, MBBI, DL, PACSym); + emitPACCFI(*Subtarget, MBB, MBBI, DL, MachineInstr::FrameDestroy, + EmitAsyncCFI); BuildMI(MBB, MBBI, DL, TII->get(UseBKey ? 
AArch64::AUTIBSPPCi : AArch64::AUTIASPPCi)) .addSym(PACSym) .setMIFlag(MachineInstr::FrameDestroy); - emitPACCFI(*Subtarget, MBB, MBBI, DL, MachineInstr::FrameDestroy, - EmitAsyncCFI); } else { BuildPACM(*Subtarget, MBB, MBBI, DL, MachineInstr::FrameDestroy, PACSym); + emitPACCFI(*Subtarget, MBB, MBBI, DL, MachineInstr::FrameDestroy, + EmitAsyncCFI); BuildMI(MBB, MBBI, DL, TII->get(UseBKey ? AArch64::AUTIBSP : AArch64::AUTIASP)) .setMIFlag(MachineInstr::FrameDestroy); - emitPACCFI(*Subtarget, MBB, MBBI, DL, MachineInstr::FrameDestroy, - EmitAsyncCFI); } if (NeedsWinCFI) { diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td index 8b8d73d..aee54ed 100644 --- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td @@ -979,8 +979,7 @@ defm FSCALE_2ZZ : sme2_fp_sve_destructive_vector_vg2_single<"fscale", 0b001100 defm FSCALE_4ZZ : sme2_fp_sve_destructive_vector_vg4_single<"fscale", 0b0011000>; defm FSCALE_2Z2Z : sme2_fp_sve_destructive_vector_vg2_multi<"fscale", 0b0011000>; defm FSCALE_4Z4Z : sme2_fp_sve_destructive_vector_vg4_multi<"fscale", 0b0011000>; - -} // [HasSME2, HasFP8] +} let Predicates = [HasSME2, HasFAMINMAX] in { defm FAMAX_2Z2Z : sme2_fp_sve_destructive_vector_vg2_multi<"famax", 0b0010100>; @@ -988,17 +987,16 @@ defm FAMIN_2Z2Z : sme2_fp_sve_destructive_vector_vg2_multi<"famin", 0b0010101>; defm FAMAX_4Z4Z : sme2_fp_sve_destructive_vector_vg4_multi<"famax", 0b0010100>; defm FAMIN_4Z4Z : sme2_fp_sve_destructive_vector_vg4_multi<"famin", 0b0010101>; -} //[HasSME2, HasFAMINMAX] - +} let Predicates = [HasSME_LUTv2] in { defm MOVT_TIZ : sme2_movt_zt_to_zt<"movt", 0b0011111, int_aarch64_sme_write_lane_zt, int_aarch64_sme_write_zt>; def LUTI4_4ZZT2Z : sme2_luti4_vector_vg4<0b00, 0b00,"luti4">; -} //[HasSME_LUTv2] +} let Predicates = [HasSME2p1, HasSME_LUTv2] in { def LUTI4_S_4ZZT2Z : sme2_luti4_vector_vg4_strided<0b00, 0b00, "luti4">; -} //[HasSME2p1, HasSME_LUTv2] +} let Predicates = [HasSMEF8F16] in { defm FVDOT_VG2_M2ZZI_BtoH : sme2_fp8_fdot_index_za16_vg1x2<"fvdot", 0b110, int_aarch64_sme_fp8_fvdot_lane_za16_vg1x2>; @@ -1014,17 +1012,15 @@ defm FMLAL_MZZI_BtoH : sme2_fp8_fmlal_index_za16<"fmlal", int_aarch64_ defm FMLAL_VG2_M2ZZI_BtoH : sme2_fp8_fmlal_index_za16_vgx2<"fmlal", int_aarch64_sme_fp8_fmlal_lane_za16_vg2x2>; defm FMLAL_VG4_M4ZZI_BtoH : sme2_fp8_fmlal_index_za16_vgx4<"fmlal", int_aarch64_sme_fp8_fmlal_lane_za16_vg2x4>; -// FP8 FMLAL (single) defm FMLAL_VG2_MZZ_BtoH : sme2_fp8_fmlal_single_za16<"fmlal", int_aarch64_sme_fp8_fmlal_single_za16_vg2x1>; -defm FMLAL_VG2_M2ZZ_BtoH : sme2_fp_mla_long_array_vg2_single<"fmlal", 0b001, MatrixOp16, ZZ_b, ZPR4b8, nxv16i8, int_aarch64_sme_fp8_fmlal_single_za16_vg2x2, [FPMR, FPCR]>; +defm FMLAL_VG2_M2ZZ_BtoH : sme2_fp_mla_long_array_vg2_single<"fmlal", 0b001, MatrixOp16, ZZ_b, ZPR4b8, nxv16i8, int_aarch64_sme_fp8_fmlal_single_za16_vg2x2, [FPMR, FPCR]>; defm FMLAL_VG4_M4ZZ_BtoH : sme2_fp_mla_long_array_vg4_single<"fmlal", 0b001, MatrixOp16, ZZZZ_b, ZPR4b8, nxv16i8, int_aarch64_sme_fp8_fmlal_single_za16_vg2x4, [FPMR, FPCR]>; -// FP8 FMLALL (multi) defm FMLAL_VG2_M2Z2Z_BtoH : sme2_fp_mla_long_array_vg2_multi<"fmlal", 0b100, MatrixOp16, ZZ_b_mul_r, nxv16i8, int_aarch64_sme_fp8_fmlal_multi_za16_vg2x2, [FPMR, FPCR]>; defm FMLAL_VG4_M4Z4Z_BtoH : sme2_fp_mla_long_array_vg4_multi<"fmlal", 0b100, MatrixOp16, ZZZZ_b_mul_r, nxv16i8, int_aarch64_sme_fp8_fmlal_multi_za16_vg2x4, [FPMR, FPCR]>; defm FMOPA_MPPZZ_BtoH : sme2_fp8_fmopa_za16<"fmopa", 
int_aarch64_sme_fp8_fmopa_za16>; -} //[HasSMEF8F16] +} let Predicates = [HasSMEF8F32] in { defm FDOT_VG2_M2ZZI_BtoS : sme2_fp8_fdot_index_za32_vg1x2<"fdot", int_aarch64_sme_fp8_fdot_lane_za32_vg1x2>; @@ -1042,17 +1038,15 @@ defm FMLALL_MZZI_BtoS : sme2_mla_ll_array_index_32b<"fmlall", 0b01, 0b0 defm FMLALL_VG2_M2ZZI_BtoS : sme2_mla_ll_array_vg2_index_32b<"fmlall", 0b10, 0b100, int_aarch64_sme_fp8_fmlall_lane_za32_vg4x2, [FPMR, FPCR]>; defm FMLALL_VG4_M4ZZI_BtoS : sme2_mla_ll_array_vg4_index_32b<"fmlall", 0b00, 0b1000, int_aarch64_sme_fp8_fmlall_lane_za32_vg4x4, [FPMR, FPCR]>; -// FP8 FMLALL (single) defm FMLALL_MZZ_BtoS : sme2_mla_ll_array_single<"fmlall", 0b01000, MatrixOp32, ZPR8, ZPR4b8, nxv16i8, int_aarch64_sme_fp8_fmlall_single_za32_vg4x1, [FPMR, FPCR]>; defm FMLALL_VG2_M2ZZ_BtoS : sme2_mla_ll_array_vg2_single<"fmlall", 0b000001, MatrixOp32, ZZ_b, ZPR4b8, nxv16i8, int_aarch64_sme_fp8_fmlall_single_za32_vg4x2, [FPMR, FPCR]>; defm FMLALL_VG4_M4ZZ_BtoS : sme2_mla_ll_array_vg4_single<"fmlall", 0b010001, MatrixOp32, ZZZZ_b, ZPR4b8, nxv16i8, int_aarch64_sme_fp8_fmlall_single_za32_vg4x4, [FPMR, FPCR]>; -// FP8 FMLALL (multi) defm FMLALL_VG2_M2Z2Z_BtoS : sme2_mla_ll_array_vg2_multi<"fmlall", 0b01000, MatrixOp32, ZZ_b_mul_r, nxv16i8, int_aarch64_sme_fp8_fmlall_multi_za32_vg4x2, [FPMR, FPCR]>; defm FMLALL_VG4_M4Z4Z_BtoS : sme2_mla_ll_array_vg4_multi<"fmlall", 0b01000, MatrixOp32, ZZZZ_b_mul_r, nxv16i8, int_aarch64_sme_fp8_fmlall_multi_za32_vg4x4, [FPMR, FPCR]>; defm FMOPA_MPPZZ_BtoS : sme2_fp8_fmopa_za32<"fmopa", int_aarch64_sme_fp8_fmopa_za32>; -} //[HasSMEF8F32] +} let Predicates = [HasSME2, HasSVEBFSCALE] in { defm BFSCALE : sme2_bfscale_single<"bfscale">; @@ -1077,31 +1071,31 @@ let Predicates = [HasSME2p2] in { defm FMOP4A : sme2_fmop4as_fp16_fp32_widening<0, "fmop4a">; defm FMOP4S : sme2_fmop4as_fp16_fp32_widening<1, "fmop4s">; -} // [HasSME2p2] +} let Predicates = [HasSME2p2, HasSMEB16B16] in { def BFTMOPA_M2ZZZI_HtoH : sme_tmopa_16b<0b11001, ZZ_h_mul_r, ZPR16, "bftmopa">; -} // [HasSME2p2, HasSMEB16B16] +} let Predicates = [HasSME2p2, HasSMEF8F32], Uses = [FPMR, FPCR] in { def FTMOPA_M2ZZZI_BtoS : sme_tmopa_32b<0b01000, ZZ_b_mul_r, ZPR8, "ftmopa">; -} // [HasSME2p2, HasSMEF8F32], Uses = [FPMR, FPCR] +} let Predicates = [HasSME2p2, HasSMEF8F16], Uses = [FPMR, FPCR] in { def FTMOPA_M2ZZZI_BtoH : sme_tmopa_16b<0b01001, ZZ_b_mul_r, ZPR8, "ftmopa">; defm FMOP4A : sme2_fmop4a_fp8_fp16_2way<"fmop4a">; -} // [HasSME2p2, HasSMEF8F16], Uses = [FPMR, FPCR] +} let Predicates = [HasSME2p2, HasSMEF16F16] in { def FTMOPA_M2ZZZI_HtoH : sme_tmopa_16b<0b10001, ZZ_h_mul_r, ZPR16, "ftmopa">; defm FMOP4A : sme2_fmop4as_fp16_non_widening<0, "fmop4a">; defm FMOP4S : sme2_fmop4as_fp16_non_widening<1, "fmop4s">; -} // [HasSME2p2, HasSMEF16F16] +} let Predicates = [HasSME2, HasSVEBFSCALE] in { defm BFMUL : sme2_bfmul_single<"bfmul">; defm BFMUL : sme2_bfmul_multi<"bfmul">; -} //[HasSME2, HasSVEBFSCALE] +} let Uses = [FPMR, FPCR] in { let Predicates = [HasSME2p2, HasSMEF8F32] in { diff --git a/llvm/lib/Target/AArch64/AArch64SchedNeoverseN2.td b/llvm/lib/Target/AArch64/AArch64SchedNeoverseN2.td index 737fc73..e23daec 100644 --- a/llvm/lib/Target/AArch64/AArch64SchedNeoverseN2.td +++ b/llvm/lib/Target/AArch64/AArch64SchedNeoverseN2.td @@ -512,6 +512,12 @@ def N2Write_8c_3L_4V : SchedWriteRes<[N2UnitL, N2UnitL, N2UnitL, let NumMicroOps = 7; } +def N2Write_7c_7V0 : SchedWriteRes<[N2UnitV0]> { + let Latency = 7; + let NumMicroOps = 7; + let ReleaseAtCycles = [7]; +} + 
//===----------------------------------------------------------------------===// // Define generic 8 micro-op types @@ -548,6 +554,15 @@ def N2Write_9c_4L_4V : SchedWriteRes<[N2UnitL, N2UnitL, N2UnitL, N2UnitL, } //===----------------------------------------------------------------------===// +// Define generic 9 micro-op types + +def N2Write_9c_9V0 : SchedWriteRes<[N2UnitV0]> { + let Latency = 9; + let NumMicroOps = 9; + let ReleaseAtCycles = [9]; +} + +//===----------------------------------------------------------------------===// // Define generic 10 micro-op types def N2Write_7c_5L01_5V : SchedWriteRes<[N2UnitL01, N2UnitL01, N2UnitL01, @@ -557,6 +572,12 @@ def N2Write_7c_5L01_5V : SchedWriteRes<[N2UnitL01, N2UnitL01, N2UnitL01, let NumMicroOps = 10; } +def N2Write_10c_10V0 : SchedWriteRes<[N2UnitV0]> { + let Latency = 10; + let NumMicroOps = 10; + let ReleaseAtCycles = [10]; +} + //===----------------------------------------------------------------------===// // Define generic 12 micro-op types @@ -580,6 +601,21 @@ def N2Write_7c_5L01_5S_5V : SchedWriteRes<[N2UnitL01, N2UnitL01, N2UnitL01, let NumMicroOps = 15; } +def N2Write_15c_15V0 : SchedWriteRes<[N2UnitV0]> { + let Latency = 15; + let NumMicroOps = 15; + let ReleaseAtCycles = [15]; +} + +//===----------------------------------------------------------------------===// +// Define generic 16 micro-op types + +def N2Write_16c_16V0 : SchedWriteRes<[N2UnitV0]> { + let Latency = 16; + let NumMicroOps = 16; + let ReleaseAtCycles = [16]; +} + //===----------------------------------------------------------------------===// // Define generic 18 micro-op types @@ -795,22 +831,26 @@ def : SchedAlias<WriteF, N2Write_2c_1V>; // FP compare def : SchedAlias<WriteFCmp, N2Write_2c_1V0>; +// FP divide and square root operations are performed using an iterative +// algorithm and block subsequent similar operations to the same pipeline +// until complete (Arm Neoverse N2 Software Optimization Guide, 3.14). 
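The comment just added is the rationale for the full-latency V0 write resources introduced above (N2Write_7c_7V0 through N2Write_16c_16V0): a non-pipelined divider keeps the V0 unit busy for its whole latency, so it should be modeled as releasing the unit after Latency cycles rather than after one. A minimal standalone sketch of what that change means for issue timing, using the 15-cycle FDIVDrr numbers from the aliases below (illustration only; this is not part of the patch and not how the MachineScheduler is implemented):

// Earliest issue cycles for independent FDIVDrr on a single V0 pipe, under a
// fully pipelined model (ReleaseAtCycles = 1) versus the blocking model this
// patch switches to (ReleaseAtCycles = Latency = 15).
#include <cstdio>

int main() {
  const int NumDivides = 4;
  for (int ReleaseAtCycles : {1, 15}) {
    std::printf("ReleaseAtCycles=%-2d:", ReleaseAtCycles);
    int UnitFreeAt = 0; // first cycle at which V0 can accept another divide
    for (int I = 0; I < NumDivides; ++I) {
      std::printf(" op%d@%d", I, UnitFreeAt);
      UnitFreeAt += ReleaseAtCycles;
    }
    std::printf("\n"); // 0 1 2 3   versus   0 15 30 45
  }
  return 0;
}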
+ // FP divide, square root -def : SchedAlias<WriteFDiv, N2Write_7c_1V0>; +def : SchedAlias<WriteFDiv, N2Write_7c_7V0>; // FP divide, H-form -def : InstRW<[N2Write_7c_1V0], (instrs FDIVHrr)>; +def : InstRW<[N2Write_7c_7V0], (instrs FDIVHrr)>; // FP divide, S-form -def : InstRW<[N2Write_10c_1V0], (instrs FDIVSrr)>; +def : InstRW<[N2Write_10c_10V0], (instrs FDIVSrr)>; // FP divide, D-form -def : InstRW<[N2Write_15c_1V0], (instrs FDIVDrr)>; +def : InstRW<[N2Write_15c_15V0], (instrs FDIVDrr)>; // FP square root, H-form -def : InstRW<[N2Write_7c_1V0], (instrs FSQRTHr)>; +def : InstRW<[N2Write_7c_7V0], (instrs FSQRTHr)>; // FP square root, S-form -def : InstRW<[N2Write_9c_1V0], (instrs FSQRTSr)>; +def : InstRW<[N2Write_9c_9V0], (instrs FSQRTSr)>; // FP square root, D-form -def : InstRW<[N2Write_16c_1V0], (instrs FSQRTDr)>; +def : InstRW<[N2Write_16c_16V0], (instrs FSQRTDr)>; // FP multiply def : WriteRes<WriteFMul, [N2UnitV]> { let Latency = 3; } diff --git a/llvm/lib/Target/AArch64/AArch64SystemOperands.td b/llvm/lib/Target/AArch64/AArch64SystemOperands.td index f22e024..355a9d2 100644 --- a/llvm/lib/Target/AArch64/AArch64SystemOperands.td +++ b/llvm/lib/Target/AArch64/AArch64SystemOperands.td @@ -42,10 +42,7 @@ def HasCONTEXTIDREL2 //===----------------------------------------------------------------------===// class AT<string name, bits<3> op1, bits<4> crn, bits<4> crm, - bits<3> op2> : SearchableTable { - let SearchableFields = ["Name", "Encoding"]; - let EnumValueField = "Encoding"; - + bits<3> op2> { string Name = name; bits<14> Encoding; let Encoding{13-11} = op1; @@ -55,6 +52,27 @@ class AT<string name, bits<3> op1, bits<4> crn, bits<4> crm, code Requires = [{ {} }]; } +def ATValues : GenericEnum { + let FilterClass = "AT"; + let NameField = "Name"; + let ValueField = "Encoding"; +} + +def ATsList : GenericTable { + let FilterClass = "AT"; + let Fields = ["Name", "Encoding", "Requires"]; +} + +def lookupATByName : SearchIndex { + let Table = ATsList; + let Key = ["Name"]; +} + +def lookupATByEncoding : SearchIndex { + let Table = ATsList; + let Key = ["Encoding"]; +} + def : AT<"S1E1R", 0b000, 0b0111, 0b1000, 0b000>; def : AT<"S1E2R", 0b100, 0b0111, 0b1000, 0b000>; def : AT<"S1E3R", 0b110, 0b0111, 0b1000, 0b000>; @@ -82,14 +100,32 @@ def : AT<"S1E3A", 0b110, 0b0111, 0b1001, 0b010>; // DMB/DSB (data barrier) instruction options. 
//===----------------------------------------------------------------------===// -class DB<string name, bits<4> encoding> : SearchableTable { - let SearchableFields = ["Name", "Encoding"]; - let EnumValueField = "Encoding"; - +class DB<string name, bits<4> encoding> { string Name = name; bits<4> Encoding = encoding; } +def DBValues : GenericEnum { + let FilterClass = "DB"; + let NameField = "Name"; + let ValueField = "Encoding"; +} + +def DBsList : GenericTable { + let FilterClass = "DB"; + let Fields = ["Name", "Encoding"]; +} + +def lookupDBByName : SearchIndex { + let Table = DBsList; + let Key = ["Name"]; +} + +def lookupDBByEncoding : SearchIndex { + let Table = DBsList; + let Key = ["Encoding"]; +} + def : DB<"oshld", 0x1>; def : DB<"oshst", 0x2>; def : DB<"osh", 0x3>; @@ -103,16 +139,39 @@ def : DB<"ld", 0xd>; def : DB<"st", 0xe>; def : DB<"sy", 0xf>; -class DBnXS<string name, bits<4> encoding, bits<5> immValue> : SearchableTable { - let SearchableFields = ["Name", "Encoding", "ImmValue"]; - let EnumValueField = "Encoding"; - +class DBnXS<string name, bits<4> encoding, bits<5> immValue> { string Name = name; bits<4> Encoding = encoding; bits<5> ImmValue = immValue; code Requires = [{ {AArch64::FeatureXS} }]; } +def DBnXSValues : GenericEnum { + let FilterClass = "DBnXS"; + let NameField = "Name"; + let ValueField = "Encoding"; +} + +def DBnXSsList : GenericTable { + let FilterClass = "DBnXS"; + let Fields = ["Name", "Encoding", "ImmValue", "Requires"]; +} + +def lookupDBnXSByName : SearchIndex { + let Table = DBnXSsList; + let Key = ["Name"]; +} + +def lookupDBnXSByEncoding : SearchIndex { + let Table = DBnXSsList; + let Key = ["Encoding"]; +} + +def lookupDBnXSByImmValue : SearchIndex { + let Table = DBnXSsList; + let Key = ["ImmValue"]; +} + def : DBnXS<"oshnxs", 0x3, 0x10>; def : DBnXS<"nshnxs", 0x7, 0x14>; def : DBnXS<"ishnxs", 0xb, 0x18>; @@ -123,10 +182,7 @@ def : DBnXS<"synxs", 0xf, 0x1c>; //===----------------------------------------------------------------------===// class DC<string name, bits<3> op1, bits<4> crn, bits<4> crm, - bits<3> op2> : SearchableTable { - let SearchableFields = ["Name", "Encoding"]; - let EnumValueField = "Encoding"; - + bits<3> op2> { string Name = name; bits<14> Encoding; let Encoding{13-11} = op1; @@ -136,6 +192,27 @@ class DC<string name, bits<3> op1, bits<4> crn, bits<4> crm, code Requires = [{ {} }]; } +def DCValues : GenericEnum { + let FilterClass = "DC"; + let NameField = "Name"; + let ValueField = "Encoding"; +} + +def DCsList : GenericTable { + let FilterClass = "DC"; + let Fields = ["Name", "Encoding", "Requires"]; +} + +def lookupDCByName : SearchIndex { + let Table = DCsList; + let Key = ["Name"]; +} + +def lookupDCByEncoding : SearchIndex { + let Table = DCsList; + let Key = ["Encoding"]; +} + def : DC<"ZVA", 0b011, 0b0111, 0b0100, 0b001>; def : DC<"IVAC", 0b000, 0b0111, 0b0110, 0b001>; def : DC<"ISW", 0b000, 0b0111, 0b0110, 0b010>; @@ -193,10 +270,7 @@ def : DC<"CGDVAOC", 0b011, 0b0111, 0b1011, 0b111>; //===----------------------------------------------------------------------===// class IC<string name, bits<3> op1, bits<4> crn, bits<4> crm, bits<3> op2, - bit needsreg> : SearchableTable { - let SearchableFields = ["Name", "Encoding"]; - let EnumValueField = "Encoding"; - + bit needsreg> { string Name = name; bits<14> Encoding; let Encoding{13-11} = op1; @@ -206,6 +280,27 @@ class IC<string name, bits<3> op1, bits<4> crn, bits<4> crm, bits<3> op2, bit NeedsReg = needsreg; } +def ICValues : GenericEnum { + let FilterClass = "IC"; + 
let NameField = "Name"; + let ValueField = "Encoding"; +} + +def ICsList : GenericTable { + let FilterClass = "IC"; + let Fields = ["Name", "Encoding", "NeedsReg"]; +} + +def lookupICByName : SearchIndex { + let Table = ICsList; + let Key = ["Name"]; +} + +def lookupICByEncoding : SearchIndex { + let Table = ICsList; + let Key = ["Encoding"]; +} + def : IC<"IALLUIS", 0b000, 0b0111, 0b0001, 0b000, 0>; def : IC<"IALLU", 0b000, 0b0111, 0b0101, 0b000, 0>; def : IC<"IVAU", 0b011, 0b0111, 0b0101, 0b001, 1>; @@ -214,25 +309,40 @@ def : IC<"IVAU", 0b011, 0b0111, 0b0101, 0b001, 1>; // ISB (instruction-fetch barrier) instruction options. //===----------------------------------------------------------------------===// -class ISB<string name, bits<4> encoding> : SearchableTable{ - let SearchableFields = ["Name", "Encoding"]; - let EnumValueField = "Encoding"; - +class ISB<string name, bits<4> encoding> { string Name = name; bits<4> Encoding; let Encoding = encoding; } +def ISBValues : GenericEnum { + let FilterClass = "ISB"; + let NameField = "Name"; + let ValueField = "Encoding"; +} + +def ISBsList : GenericTable { + let FilterClass = "ISB"; + let Fields = ["Name", "Encoding"]; +} + +def lookupISBByName : SearchIndex { + let Table = ISBsList; + let Key = ["Name"]; +} + +def lookupISBByEncoding : SearchIndex { + let Table = ISBsList; + let Key = ["Encoding"]; +} + def : ISB<"sy", 0xf>; //===----------------------------------------------------------------------===// // TSB (Trace synchronization barrier) instruction options. //===----------------------------------------------------------------------===// -class TSB<string name, bits<4> encoding> : SearchableTable{ - let SearchableFields = ["Name", "Encoding"]; - let EnumValueField = "Encoding"; - +class TSB<string name, bits<4> encoding> { string Name = name; bits<4> Encoding; let Encoding = encoding; @@ -240,6 +350,27 @@ class TSB<string name, bits<4> encoding> : SearchableTable{ code Requires = [{ {AArch64::FeatureTRACEV8_4} }]; } +def TSBValues : GenericEnum { + let FilterClass = "TSB"; + let NameField = "Name"; + let ValueField = "Encoding"; +} + +def TSBsList : GenericTable { + let FilterClass = "TSB"; + let Fields = ["Name", "Encoding", "Requires"]; +} + +def lookupTSBByName : SearchIndex { + let Table = TSBsList; + let Key = ["Name"]; +} + +def lookupTSBByEncoding : SearchIndex { + let Table = TSBsList; + let Key = ["Encoding"]; +} + def : TSB<"csync", 0>; //===----------------------------------------------------------------------===// @@ -248,10 +379,7 @@ def : TSB<"csync", 0>; class PRFM<string type, bits<2> type_encoding, string target, bits<2> target_encoding, - string policy, bits<1> policy_encoding> : SearchableTable { - let SearchableFields = ["Name", "Encoding"]; - let EnumValueField = "Encoding"; - + string policy, bits<1> policy_encoding> { string Name = type # target # policy; bits<5> Encoding; let Encoding{4-3} = type_encoding; @@ -261,6 +389,27 @@ class PRFM<string type, bits<2> type_encoding, code Requires = [{ {} }]; } +def PRFMValues : GenericEnum { + let FilterClass = "PRFM"; + let NameField = "Name"; + let ValueField = "Encoding"; +} + +def PRFMsList : GenericTable { + let FilterClass = "PRFM"; + let Fields = ["Name", "Encoding", "Requires"]; +} + +def lookupPRFMByName : SearchIndex { + let Table = PRFMsList; + let Key = ["Name"]; +} + +def lookupPRFMByEncoding : SearchIndex { + let Table = PRFMsList; + let Key = ["Encoding"]; +} + def : PRFM<"pld", 0b00, "l1", 0b00, "keep", 0b0>; def : PRFM<"pld", 0b00, "l1", 0b00, "strm", 
0b1>; def : PRFM<"pld", 0b00, "l2", 0b01, "keep", 0b0>; @@ -296,16 +445,34 @@ def : PRFM<"pst", 0b10, "slc", 0b11, "strm", 0b1>; // SVE Prefetch instruction options. //===----------------------------------------------------------------------===// -class SVEPRFM<string name, bits<4> encoding> : SearchableTable { - let SearchableFields = ["Name", "Encoding"]; - let EnumValueField = "Encoding"; - +class SVEPRFM<string name, bits<4> encoding> { string Name = name; bits<4> Encoding; let Encoding = encoding; code Requires = [{ {} }]; } +def SVEPRFMValues : GenericEnum { + let FilterClass = "SVEPRFM"; + let NameField = "Name"; + let ValueField = "Encoding"; +} + +def SVEPRFMsList : GenericTable { + let FilterClass = "SVEPRFM"; + let Fields = ["Name", "Encoding", "Requires"]; +} + +def lookupSVEPRFMByName : SearchIndex { + let Table = SVEPRFMsList; + let Key = ["Name"]; +} + +def lookupSVEPRFMByEncoding : SearchIndex { + let Table = SVEPRFMsList; + let Key = ["Encoding"]; +} + let Requires = [{ {AArch64::FeatureSVE} }] in { def : SVEPRFM<"pldl1keep", 0x00>; def : SVEPRFM<"pldl1strm", 0x01>; @@ -325,10 +492,7 @@ def : SVEPRFM<"pstl3strm", 0x0d>; // RPRFM (prefetch) instruction options. //===----------------------------------------------------------------------===// -class RPRFM<string name, bits<1> type_encoding, bits<5> policy_encoding> : SearchableTable { - let SearchableFields = ["Name", "Encoding"]; - let EnumValueField = "Encoding"; - +class RPRFM<string name, bits<1> type_encoding, bits<5> policy_encoding> { string Name = name; bits<6> Encoding; let Encoding{0} = type_encoding; @@ -336,6 +500,27 @@ class RPRFM<string name, bits<1> type_encoding, bits<5> policy_encoding> : Searc code Requires = [{ {} }]; } +def RPRFMValues : GenericEnum { + let FilterClass = "RPRFM"; + let NameField = "Name"; + let ValueField = "Encoding"; +} + +def RPRFMsList : GenericTable { + let FilterClass = "RPRFM"; + let Fields = ["Name", "Encoding", "Requires"]; +} + +def lookupRPRFMByName : SearchIndex { + let Table = RPRFMsList; + let Key = ["Name"]; +} + +def lookupRPRFMByEncoding : SearchIndex { + let Table = RPRFMsList; + let Key = ["Encoding"]; +} + def : RPRFM<"pldkeep", 0b0, 0b00000>; def : RPRFM<"pstkeep", 0b1, 0b00000>; def : RPRFM<"pldstrm", 0b0, 0b00010>; @@ -345,15 +530,33 @@ def : RPRFM<"pststrm", 0b1, 0b00010>; // SVE Predicate patterns //===----------------------------------------------------------------------===// -class SVEPREDPAT<string name, bits<5> encoding> : SearchableTable { - let SearchableFields = ["Name", "Encoding"]; - let EnumValueField = "Encoding"; - +class SVEPREDPAT<string name, bits<5> encoding> { string Name = name; bits<5> Encoding; let Encoding = encoding; } +def SVEPREDPATValues : GenericEnum { + let FilterClass = "SVEPREDPAT"; + let NameField = "Name"; + let ValueField = "Encoding"; +} + +def SVEPREDPATsList : GenericTable { + let FilterClass = "SVEPREDPAT"; + let Fields = ["Name", "Encoding"]; +} + +def lookupSVEPREDPATByName : SearchIndex { + let Table = SVEPREDPATsList; + let Key = ["Name"]; +} + +def lookupSVEPREDPATByEncoding : SearchIndex { + let Table = SVEPREDPATsList; + let Key = ["Encoding"]; +} + def : SVEPREDPAT<"pow2", 0x00>; def : SVEPREDPAT<"vl1", 0x01>; def : SVEPREDPAT<"vl2", 0x02>; @@ -376,15 +579,33 @@ def : SVEPREDPAT<"all", 0x1f>; // SVE Predicate-as-counter patterns //===----------------------------------------------------------------------===// -class SVEVECLENSPECIFIER<string name, bits<1> encoding> : SearchableTable { - let SearchableFields = ["Name", 
"Encoding"]; - let EnumValueField = "Encoding"; - +class SVEVECLENSPECIFIER<string name, bits<1> encoding> { string Name = name; bits<1> Encoding; let Encoding = encoding; } +def SVEVECLENSPECIFIERValues : GenericEnum { + let FilterClass = "SVEVECLENSPECIFIER"; + let NameField = "Name"; + let ValueField = "Encoding"; +} + +def SVEVECLENSPECIFIERsList : GenericTable { + let FilterClass = "SVEVECLENSPECIFIER"; + let Fields = ["Name", "Encoding"]; +} + +def lookupSVEVECLENSPECIFIERByName : SearchIndex { + let Table = SVEVECLENSPECIFIERsList; + let Key = ["Name"]; +} + +def lookupSVEVECLENSPECIFIERByEncoding : SearchIndex { + let Table = SVEVECLENSPECIFIERsList; + let Key = ["Encoding"]; +} + def : SVEVECLENSPECIFIER<"vlx2", 0x0>; def : SVEVECLENSPECIFIER<"vlx4", 0x1>; @@ -395,15 +616,28 @@ def : SVEVECLENSPECIFIER<"vlx4", 0x1>; // is used for a few instructions that only accept a limited set of exact FP // immediates values. //===----------------------------------------------------------------------===// -class ExactFPImm<string name, string repr, bits<4> enum > : SearchableTable { - let SearchableFields = ["Enum", "Repr"]; - let EnumValueField = "Enum"; - +class ExactFPImm<string name, string repr, bits<4> enum > { string Name = name; bits<4> Enum = enum; string Repr = repr; } +def ExactFPImmValues : GenericEnum { + let FilterClass = "ExactFPImm"; + let NameField = "Name"; + let ValueField = "Enum"; +} + +def ExactFPImmsList : GenericTable { + let FilterClass = "ExactFPImm"; + let Fields = ["Enum", "Repr"]; +} + +def lookupExactFPImmByEnum : SearchIndex { + let Table = ExactFPImmsList; + let Key = ["Enum"]; +} + def : ExactFPImm<"zero", "0.0", 0x0>; def : ExactFPImm<"half", "0.5", 0x1>; def : ExactFPImm<"one", "1.0", 0x2>; @@ -413,10 +647,7 @@ def : ExactFPImm<"two", "2.0", 0x3>; // PState instruction options. 
//===----------------------------------------------------------------------===// -class PStateImm0_15<string name, bits<3> op1, bits<3> op2> : SearchableTable { - let SearchableFields = ["Name", "Encoding"]; - let EnumValueField = "Encoding"; - +class PStateImm0_15<string name, bits<3> op1, bits<3> op2> { string Name = name; bits<6> Encoding; let Encoding{5-3} = op1; @@ -424,10 +655,28 @@ class PStateImm0_15<string name, bits<3> op1, bits<3> op2> : SearchableTable { code Requires = [{ {} }]; } -class PStateImm0_1<string name, bits<3> op1, bits<3> op2, bits<3> crm_high> : SearchableTable { - let SearchableFields = ["Name", "Encoding"]; - let EnumValueField = "Encoding"; +def PStateImm0_15Values : GenericEnum { + let FilterClass = "PStateImm0_15"; + let NameField = "Name"; + let ValueField = "Encoding"; +} + +def PStateImm0_15sList : GenericTable { + let FilterClass = "PStateImm0_15"; + let Fields = ["Name", "Encoding", "Requires"]; +} +def lookupPStateImm0_15ByName : SearchIndex { + let Table = PStateImm0_15sList; + let Key = ["Name"]; +} + +def lookupPStateImm0_15ByEncoding : SearchIndex { + let Table = PStateImm0_15sList; + let Key = ["Encoding"]; +} + +class PStateImm0_1<string name, bits<3> op1, bits<3> op2, bits<3> crm_high> { string Name = name; bits<9> Encoding; let Encoding{8-6} = crm_high; @@ -436,6 +685,27 @@ class PStateImm0_1<string name, bits<3> op1, bits<3> op2, bits<3> crm_high> : Se code Requires = [{ {} }]; } +def PStateImm0_1Values : GenericEnum { + let FilterClass = "PStateImm0_1"; + let NameField = "Name"; + let ValueField = "Encoding"; +} + +def PStateImm0_1sList : GenericTable { + let FilterClass = "PStateImm0_1"; + let Fields = ["Name", "Encoding", "Requires"]; +} + +def lookupPStateImm0_1ByName : SearchIndex { + let Table = PStateImm0_1sList; + let Key = ["Name"]; +} + +def lookupPStateImm0_1ByEncoding : SearchIndex { + let Table = PStateImm0_1sList; + let Key = ["Encoding"]; +} + // Name, Op1, Op2 def : PStateImm0_15<"SPSel", 0b000, 0b101>; def : PStateImm0_15<"DAIFSet", 0b011, 0b110>; @@ -467,16 +737,34 @@ def : PStateImm0_1<"PM", 0b001, 0b000, 0b001>; // SVCR instruction options. //===----------------------------------------------------------------------===// -class SVCR<string name, bits<3> encoding> : SearchableTable { - let SearchableFields = ["Name", "Encoding"]; - let EnumValueField = "Encoding"; - +class SVCR<string name, bits<3> encoding> { string Name = name; bits<3> Encoding; let Encoding = encoding; code Requires = [{ {} }]; } +def SVCRValues : GenericEnum { + let FilterClass = "SVCR"; + let NameField = "Name"; + let ValueField = "Encoding"; +} + +def SVCRsList : GenericTable { + let FilterClass = "SVCR"; + let Fields = ["Name", "Encoding", "Requires"]; +} + +def lookupSVCRByName : SearchIndex { + let Table = SVCRsList; + let Key = ["Name"]; +} + +def lookupSVCRByEncoding : SearchIndex { + let Table = SVCRsList; + let Key = ["Encoding"]; +} + let Requires = [{ {AArch64::FeatureSME} }] in { def : SVCR<"SVCRSM", 0b001>; def : SVCR<"SVCRZA", 0b010>; @@ -487,30 +775,66 @@ def : SVCR<"SVCRSMZA", 0b011>; // PSB instruction options. 
//===----------------------------------------------------------------------===// -class PSB<string name, bits<5> encoding> : SearchableTable { - let SearchableFields = ["Name", "Encoding"]; - let EnumValueField = "Encoding"; - +class PSB<string name, bits<5> encoding> { string Name = name; bits<5> Encoding; let Encoding = encoding; } +def PSBValues : GenericEnum { + let FilterClass = "PSB"; + let NameField = "Name"; + let ValueField = "Encoding"; +} + +def PSBsList : GenericTable { + let FilterClass = "PSB"; + let Fields = ["Name", "Encoding"]; +} + +def lookupPSBByName : SearchIndex { + let Table = PSBsList; + let Key = ["Name"]; +} + +def lookupPSBByEncoding : SearchIndex { + let Table = PSBsList; + let Key = ["Encoding"]; +} + def : PSB<"csync", 0x11>; //===----------------------------------------------------------------------===// // BTI instruction options. //===----------------------------------------------------------------------===// -class BTI<string name, bits<3> encoding> : SearchableTable { - let SearchableFields = ["Name", "Encoding"]; - let EnumValueField = "Encoding"; - +class BTI<string name, bits<3> encoding> { string Name = name; bits<3> Encoding; let Encoding = encoding; } +def BTIValues : GenericEnum { + let FilterClass = "BTI"; + let NameField = "Name"; + let ValueField = "Encoding"; +} + +def BTIsList : GenericTable { + let FilterClass = "BTI"; + let Fields = ["Name", "Encoding"]; +} + +def lookupBTIByName : SearchIndex { + let Table = BTIsList; + let Key = ["Name"]; +} + +def lookupBTIByEncoding : SearchIndex { + let Table = BTIsList; + let Key = ["Encoding"]; +} + def : BTI<"c", 0b010>; def : BTI<"j", 0b100>; def : BTI<"jc", 0b110>; @@ -667,12 +991,8 @@ defm : TLBI<"VMALLWS2E1OS", 0b100, 0b1000, 0b0101, 0b010, 0>; //===----------------------------------------------------------------------===// class SysReg<string name, bits<2> op0, bits<3> op1, bits<4> crn, bits<4> crm, - bits<3> op2> : SearchableTable { - let SearchableFields = ["Name", "Encoding"]; - let EnumValueField = "Encoding"; - + bits<3> op2> { string Name = name; - string AltName = name; bits<16> Encoding; let Encoding{15-14} = op0; let Encoding{13-11} = op1; @@ -684,6 +1004,26 @@ class SysReg<string name, bits<2> op0, bits<3> op1, bits<4> crn, bits<4> crm, code Requires = [{ {} }]; } +def SysRegValues : GenericEnum { + let FilterClass = "SysReg"; + let NameField = "Name"; + let ValueField = "Encoding"; +} + +def SysRegsList : GenericTable { + let FilterClass = "SysReg"; + let Fields = ["Name", "Encoding", "Readable", "Writeable", "Requires"]; + + let PrimaryKey = ["Encoding"]; + let PrimaryKeyName = "lookupSysRegByEncoding"; + let PrimaryKeyReturnRange = true; +} + +def lookupSysRegByName : SearchIndex { + let Table = SysRegsList; + let Key = ["Name"]; +} + class RWSysReg<string name, bits<2> op0, bits<3> op1, bits<4> crn, bits<4> crm, bits<3> op2> : SysReg<name, op0, op1, crn, crm, op2> { @@ -969,9 +1309,7 @@ def : RWSysReg<"TTBR0_EL1", 0b11, 0b000, 0b0010, 0b0000, 0b000>; def : RWSysReg<"TTBR0_EL3", 0b11, 0b110, 0b0010, 0b0000, 0b000>; let Requires = [{ {AArch64::FeatureEL2VMSA} }] in { -def : RWSysReg<"TTBR0_EL2", 0b11, 0b100, 0b0010, 0b0000, 0b000> { - let AltName = "VSCTLR_EL2"; -} +def : RWSysReg<"TTBR0_EL2", 0b11, 0b100, 0b0010, 0b0000, 0b000>; def : RWSysReg<"VTTBR_EL2", 0b11, 0b100, 0b0010, 0b0001, 0b000>; } @@ -1358,9 +1696,7 @@ def : RWSysReg<"ICH_LR15_EL2", 0b11, 0b100, 0b1100, 0b1101, 0b111>; let Requires = [{ {AArch64::HasV8_0rOps} }] in { //Virtualization System Control Register // Op0 Op1 
CRn CRm Op2 -def : RWSysReg<"VSCTLR_EL2", 0b11, 0b100, 0b0010, 0b0000, 0b000> { - let AltName = "TTBR0_EL2"; -} +def : RWSysReg<"VSCTLR_EL2", 0b11, 0b100, 0b0010, 0b0000, 0b000>; //MPU Type Register // Op0 Op1 CRn CRm Op2 @@ -2026,12 +2362,8 @@ def : RWSysReg<"ACTLRALIAS_EL1", 0b11, 0b000, 0b0001, 0b0100, 0b101>; //===----------------------------------------------------------------------===// class PHint<bits<2> op0, bits<3> op1, bits<4> crn, bits<4> crm, - bits<3> op2, string name> : SearchableTable { - let SearchableFields = ["Name", "Encoding"]; - let EnumValueField = "Encoding"; - + bits<3> op2, string name> { string Name = name; - string AltName = name; bits<16> Encoding; let Encoding{15-14} = op0; let Encoding{13-11} = op1; @@ -2041,6 +2373,27 @@ class PHint<bits<2> op0, bits<3> op1, bits<4> crn, bits<4> crm, code Requires = [{ {} }]; } +def PHintValues : GenericEnum { + let FilterClass = "PHint"; + let NameField = "Name"; + let ValueField = "Encoding"; +} + +def PHintsList : GenericTable { + let FilterClass = "PHint"; + let Fields = ["Name", "Encoding", "Requires"]; +} + +def lookupPHintByName : SearchIndex { + let Table = PHintsList; + let Key = ["Name"]; +} + +def lookupPHintByEncoding : SearchIndex { + let Table = PHintsList; + let Key = ["Encoding"]; +} + let Requires = [{ {AArch64::FeaturePCDPHINT} }] in { def KEEP : PHint<0b00, 0b000, 0b0000, 0b0000, 0b000, "keep">; def STRM : PHint<0b00, 0b000, 0b0000, 0b0000, 0b001, "strm">; diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index 0566a87..25b6731 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -35,6 +35,9 @@ using namespace llvm::PatternMatch; static cl::opt<bool> EnableFalkorHWPFUnrollFix("enable-falkor-hwpf-unroll-fix", cl::init(true), cl::Hidden); +static cl::opt<bool> SVEPreferFixedOverScalableIfEqualCost( + "sve-prefer-fixed-over-scalable-if-equal", cl::Hidden); + static cl::opt<unsigned> SVEGatherOverhead("sve-gather-overhead", cl::init(10), cl::Hidden); @@ -256,7 +259,7 @@ bool AArch64TTIImpl::areInlineCompatible(const Function *Caller, CalleeAttrs.set(SMEAttrs::SM_Enabled, true); } - if (CalleeAttrs.isNewZA()) + if (CalleeAttrs.isNewZA() || CalleeAttrs.isNewZT0()) return false; if (CallerAttrs.requiresLazySave(CalleeAttrs) || @@ -1635,10 +1638,8 @@ instCombineSVEVectorBinOp(InstCombiner &IC, IntrinsicInst &II) { !match(OpPredicate, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>( m_ConstantInt<AArch64SVEPredPattern::all>()))) return std::nullopt; - IRBuilderBase::FastMathFlagGuard FMFGuard(IC.Builder); - IC.Builder.setFastMathFlags(II.getFastMathFlags()); - auto BinOp = - IC.Builder.CreateBinOp(BinOpCode, II.getOperand(1), II.getOperand(2)); + auto BinOp = IC.Builder.CreateBinOpFMF( + BinOpCode, II.getOperand(1), II.getOperand(2), II.getFastMathFlags()); return IC.replaceInstUsesWith(II, BinOp); } @@ -2760,6 +2761,21 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, return AdjustCost( BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I)); + static const TypeConversionCostTblEntry BF16Tbl[] = { + {ISD::FP_ROUND, MVT::bf16, MVT::f32, 1}, // bfcvt + {ISD::FP_ROUND, MVT::bf16, MVT::f64, 1}, // bfcvt + {ISD::FP_ROUND, MVT::v4bf16, MVT::v4f32, 1}, // bfcvtn + {ISD::FP_ROUND, MVT::v8bf16, MVT::v8f32, 2}, // bfcvtn+bfcvtn2 + {ISD::FP_ROUND, MVT::v2bf16, MVT::v2f64, 2}, // bfcvtn+fcvtn + {ISD::FP_ROUND, MVT::v4bf16, MVT::v4f64, 3}, // 
fcvtn+fcvtl2+bfcvtn + {ISD::FP_ROUND, MVT::v8bf16, MVT::v8f64, 6}, // 2 * fcvtn+fcvtn2+bfcvtn + }; + + if (ST->hasBF16()) + if (const auto *Entry = ConvertCostTableLookup( + BF16Tbl, ISD, DstTy.getSimpleVT(), SrcTy.getSimpleVT())) + return AdjustCost(Entry->Cost); + static const TypeConversionCostTblEntry ConversionTbl[] = { {ISD::TRUNCATE, MVT::v2i8, MVT::v2i64, 1}, // xtn {ISD::TRUNCATE, MVT::v2i16, MVT::v2i64, 1}, // xtn @@ -2847,6 +2863,14 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, {ISD::FP_EXTEND, MVT::v2f64, MVT::v2f16, 2}, // fcvtl+fcvtl {ISD::FP_EXTEND, MVT::v4f64, MVT::v4f16, 3}, // fcvtl+fcvtl2+fcvtl {ISD::FP_EXTEND, MVT::v8f64, MVT::v8f16, 6}, // 2 * fcvtl+fcvtl2+fcvtl + // BF16 (uses shift) + {ISD::FP_EXTEND, MVT::f32, MVT::bf16, 1}, // shl + {ISD::FP_EXTEND, MVT::f64, MVT::bf16, 2}, // shl+fcvt + {ISD::FP_EXTEND, MVT::v4f32, MVT::v4bf16, 1}, // shll + {ISD::FP_EXTEND, MVT::v8f32, MVT::v8bf16, 2}, // shll+shll2 + {ISD::FP_EXTEND, MVT::v2f64, MVT::v2bf16, 2}, // shll+fcvtl + {ISD::FP_EXTEND, MVT::v4f64, MVT::v4bf16, 3}, // shll+fcvtl+fcvtl2 + {ISD::FP_EXTEND, MVT::v8f64, MVT::v8bf16, 6}, // 2 * shll+fcvtl+fcvtl2 // FP Ext and trunc {ISD::FP_ROUND, MVT::f32, MVT::f64, 1}, // fcvt {ISD::FP_ROUND, MVT::v2f32, MVT::v2f64, 1}, // fcvtn @@ -2859,6 +2883,15 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, {ISD::FP_ROUND, MVT::v2f16, MVT::v2f64, 2}, // fcvtn+fcvtn {ISD::FP_ROUND, MVT::v4f16, MVT::v4f64, 3}, // fcvtn+fcvtn2+fcvtn {ISD::FP_ROUND, MVT::v8f16, MVT::v8f64, 6}, // 2 * fcvtn+fcvtn2+fcvtn + // BF16 (more complex, with +bf16 is handled above) + {ISD::FP_ROUND, MVT::bf16, MVT::f32, 8}, // Expansion is ~8 insns + {ISD::FP_ROUND, MVT::bf16, MVT::f64, 9}, // fcvtn + above + {ISD::FP_ROUND, MVT::v2bf16, MVT::v2f32, 8}, + {ISD::FP_ROUND, MVT::v4bf16, MVT::v4f32, 8}, + {ISD::FP_ROUND, MVT::v8bf16, MVT::v8f32, 15}, + {ISD::FP_ROUND, MVT::v2bf16, MVT::v2f64, 9}, + {ISD::FP_ROUND, MVT::v4bf16, MVT::v4f64, 10}, + {ISD::FP_ROUND, MVT::v8bf16, MVT::v8f64, 19}, // LowerVectorINT_TO_FP: {ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i32, 1}, @@ -4705,10 +4738,21 @@ InstructionCost AArch64TTIImpl::getShuffleCost( } Kind = improveShuffleKindFromMask(Kind, Mask, Tp, Index, SubTp); - // Treat extractsubvector as single op permutation. bool IsExtractSubvector = Kind == TTI::SK_ExtractSubvector; - if (IsExtractSubvector && LT.second.isFixedLengthVector()) + // A sebvector extract can be implemented with a ext (or trivial extract, if + // from lane 0). This currently only handles low or high extracts to prevent + // SLP vectorizer regressions. + if (IsExtractSubvector && LT.second.isFixedLengthVector()) { + if (LT.second.is128BitVector() && + cast<FixedVectorType>(SubTp)->getNumElements() == + LT.second.getVectorNumElements() / 2) { + if (Index == 0) + return 0; + if (Index == (int)LT.second.getVectorNumElements() / 2) + return 1; + } Kind = TTI::SK_PermuteSingleSrc; + } // Check for broadcast loads, which are supported by the LD1R instruction. 
// In terms of code-size, the shuffle vector is free when a load + dup get @@ -4919,6 +4963,12 @@ static bool containsDecreasingPointers(Loop *TheLoop, return false; } +bool AArch64TTIImpl::preferFixedOverScalableIfEqualCost() const { + if (SVEPreferFixedOverScalableIfEqualCost.getNumOccurrences()) + return SVEPreferFixedOverScalableIfEqualCost; + return ST->useFixedOverScalableIfEqualCost(); +} + unsigned AArch64TTIImpl::getEpilogueVectorizationMinVF() const { return ST->getEpilogueVectorizationMinVF(); } @@ -5283,11 +5333,17 @@ bool AArch64TTIImpl::isProfitableToSinkOperands( } } - // Sink vscales closer to uses for better isel + auto ShouldSinkCondition = [](Value *Cond) -> bool { + auto *II = dyn_cast<IntrinsicInst>(Cond); + return II && II->getIntrinsicID() == Intrinsic::vector_reduce_or && + isa<ScalableVectorType>(II->getOperand(0)->getType()); + }; + switch (I->getOpcode()) { case Instruction::GetElementPtr: case Instruction::Add: case Instruction::Sub: + // Sink vscales closer to uses for better isel for (unsigned Op = 0; Op < I->getNumOperands(); ++Op) { if (shouldSinkVScale(I->getOperand(Op), Ops)) { Ops.push_back(&I->getOperandUse(Op)); @@ -5295,6 +5351,23 @@ bool AArch64TTIImpl::isProfitableToSinkOperands( } } break; + case Instruction::Select: { + if (!ShouldSinkCondition(I->getOperand(0))) + return false; + + Ops.push_back(&I->getOperandUse(0)); + return true; + } + case Instruction::Br: { + if (cast<BranchInst>(I)->isUnconditional()) + return false; + + if (!ShouldSinkCondition(cast<BranchInst>(I)->getCondition())) + return false; + + Ops.push_back(&I->getOperandUse(0)); + return true; + } default: break; } diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h index 83b86e3..214fb4e 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h @@ -387,9 +387,7 @@ public: return TailFoldingStyle::DataWithoutLaneMask; } - bool preferFixedOverScalableIfEqualCost() const { - return ST->useFixedOverScalableIfEqualCost(); - } + bool preferFixedOverScalableIfEqualCost() const; unsigned getEpilogueVectorizationMinVF() const; diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp index 4b7d415..93461e3 100644 --- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp +++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp @@ -454,6 +454,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST) {nxv2s64, p0, nxv2s64, 8}, }) .clampScalar(0, s8, s64) + .minScalarOrElt(0, s8) .lowerIf([=](const LegalityQuery &Query) { return Query.Types[0].isScalar() && Query.Types[0] != Query.MMODescrs[0].MemoryTy; @@ -466,14 +467,19 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST) .clampMaxNumElements(0, p0, 2) .lowerIfMemSizeNotPow2() // TODO: Use BITCAST for v2i8, v2i16 after G_TRUNC gets sorted out - .bitcastIf(typeInSet(0, {v4s8}), + .bitcastIf(all(typeInSet(0, {v4s8}), + LegalityPredicate([=](const LegalityQuery &Query) { + return Query.Types[0].getSizeInBits() == + Query.MMODescrs[0].MemoryTy.getSizeInBits(); + })), [=](const LegalityQuery &Query) { const LLT VecTy = Query.Types[0]; return std::pair(0, LLT::scalar(VecTy.getSizeInBits())); }) .customIf(IsPtrVecPred) .scalarizeIf(typeInSet(0, {v2s16, v2s8}), 0) - .scalarizeIf(scalarOrEltWiderThan(0, 64), 0); + .scalarizeIf(scalarOrEltWiderThan(0, 64), 0) + .lower(); 
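In the G_STORE rules just above, the v4s8 bitcast case now additionally requires that the stored vector is exactly as wide as the memory type, so the rewrite to a scalar store only fires when the store writes the whole vector; other cases are left to the remaining rules, ending in the new .lower(). Conceptually the bitcast turns a four-lane byte store into a single 32-bit scalar store. A small host-side illustration of that reinterpretation, under the assumption of a little-endian layout (illustration only, not LLVM code):

// Reinterpreting a <4 x s8> store value as a single s32: only sound when all
// 32 bits are written, which is what the added size-equality predicate checks.
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  uint8_t Lanes[4] = {0x11, 0x22, 0x33, 0x44}; // the four v4s8 lanes
  uint32_t AsScalar = 0;
  std::memcpy(&AsScalar, Lanes, sizeof(AsScalar)); // the "bitcast"
  std::printf("0x%08x\n", AsScalar); // 0x44332211 on a little-endian host
  return 0;
}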
getActionDefinitionsBuilder(G_INDEXED_STORE) // Idx 0 == Ptr, Idx 1 == Val @@ -861,6 +867,13 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST) .legalForCartesianProduct({s32, v2s16, v4s8}) .legalForCartesianProduct({s64, v8s8, v4s16, v2s32}) .legalForCartesianProduct({s128, v16s8, v8s16, v4s32, v2s64, v2p0}) + .customIf([=](const LegalityQuery &Query) { + // Handle casts from i1 vectors to scalars. + LLT DstTy = Query.Types[0]; + LLT SrcTy = Query.Types[1]; + return DstTy.isScalar() && SrcTy.isVector() && + SrcTy.getScalarSizeInBits() == 1; + }) .lowerIf([=](const LegalityQuery &Query) { return Query.Types[0].isVector() != Query.Types[1].isVector(); }) @@ -1062,10 +1075,11 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST) return llvm::is_contained( {v2s64, v2s32, v4s32, v4s16, v16s8, v8s8, v8s16}, DstTy); }) - // G_SHUFFLE_VECTOR can have scalar sources (from 1 x s vectors), we - // just want those lowered into G_BUILD_VECTOR + // G_SHUFFLE_VECTOR can have scalar sources (from 1 x s vectors) or scalar + // destinations, we just want those lowered into G_BUILD_VECTOR or + // G_EXTRACT_ELEMENT. .lowerIf([=](const LegalityQuery &Query) { - return !Query.Types[1].isVector(); + return !Query.Types[0].isVector() || !Query.Types[1].isVector(); }) .moreElementsIf( [](const LegalityQuery &Query) { @@ -1404,11 +1418,28 @@ bool AArch64LegalizerInfo::legalizeCustom( return Helper.lowerAbsToCNeg(MI); case TargetOpcode::G_ICMP: return legalizeICMP(MI, MRI, MIRBuilder); + case TargetOpcode::G_BITCAST: + return legalizeBitcast(MI, Helper); } llvm_unreachable("expected switch to return"); } +bool AArch64LegalizerInfo::legalizeBitcast(MachineInstr &MI, + LegalizerHelper &Helper) const { + assert(MI.getOpcode() == TargetOpcode::G_BITCAST && "Unexpected opcode"); + auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs(); + // We're trying to handle casts from i1 vectors to scalars but reloading from + // stack. + if (!DstTy.isScalar() || !SrcTy.isVector() || + SrcTy.getElementType() != LLT::scalar(1)) + return false; + + Helper.createStackStoreLoad(DstReg, SrcReg); + MI.eraseFromParent(); + return true; +} + bool AArch64LegalizerInfo::legalizeFunnelShift(MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &MIRBuilder, diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.h b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.h index 00d85a3..bcb29432 100644 --- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.h +++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.h @@ -66,6 +66,7 @@ private: LegalizerHelper &Helper) const; bool legalizeDynStackAlloc(MachineInstr &MI, LegalizerHelper &Helper) const; bool legalizePrefetch(MachineInstr &MI, LegalizerHelper &Helper) const; + bool legalizeBitcast(MachineInstr &MI, LegalizerHelper &Helper) const; const AArch64Subtarget *ST; }; } // End llvm namespace. 
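legalizeBitcast above handles a G_BITCAST from an i1 vector to a scalar by spilling through a stack slot (Helper.createStackStoreLoad): the lane mask is materialized in memory and reloaded as an integer. As a point of reference for what that bit pattern represents, here is a small host-side sketch of packing a boolean lane vector into a scalar mask, assuming the little-endian convention that lane 0 maps to bit 0 (illustration only, not LLVM code and not the lowering itself):

// Packing an <8 x i1>-style lane mask into an i8-sized scalar.
#include <cstdint>
#include <cstdio>

static uint8_t packMask(const bool (&Lanes)[8]) {
  uint8_t Mask = 0;
  for (int I = 0; I < 8; ++I)
    Mask |= static_cast<uint8_t>(Lanes[I]) << I; // lane I -> bit I
  return Mask;
}

int main() {
  const bool Lanes[8] = {true, false, true, true, false, false, false, true};
  std::printf("0x%02x\n", packMask(Lanes)); // 0x8d
  return 0;
}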
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp index 5fe2e3c..6bba70d 100644 --- a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp +++ b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp @@ -405,6 +405,19 @@ void applyEXT(MachineInstr &MI, ShuffleVectorPseudo &MatchInfo) { MI.eraseFromParent(); } +void applyFullRev(MachineInstr &MI, MachineRegisterInfo &MRI) { + Register Dst = MI.getOperand(0).getReg(); + Register Src = MI.getOperand(1).getReg(); + LLT DstTy = MRI.getType(Dst); + assert(DstTy.getSizeInBits() == 128 && + "Expected 128bit vector in applyFullRev"); + MachineIRBuilder MIRBuilder(MI); + auto Cst = MIRBuilder.buildConstant(LLT::scalar(32), 8); + auto Rev = MIRBuilder.buildInstr(AArch64::G_REV64, {DstTy}, {Src}); + MIRBuilder.buildInstr(AArch64::G_EXT, {Dst}, {Rev, Rev, Cst}); + MI.eraseFromParent(); +} + bool matchNonConstInsert(MachineInstr &MI, MachineRegisterInfo &MRI) { assert(MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT); diff --git a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.cpp b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.cpp index ae84bc9..875b505 100644 --- a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.cpp +++ b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.cpp @@ -1874,26 +1874,25 @@ void AArch64InstPrinter::printBarriernXSOption(const MCInst *MI, unsigned OpNo, markup(O, Markup::Immediate) << "#" << Val; } -static bool isValidSysReg(const AArch64SysReg::SysReg *Reg, bool Read, +static bool isValidSysReg(const AArch64SysReg::SysReg &Reg, bool Read, const MCSubtargetInfo &STI) { - return (Reg && (Read ? Reg->Readable : Reg->Writeable) && - Reg->haveFeatures(STI.getFeatureBits())); + return (Read ? Reg.Readable : Reg.Writeable) && + Reg.haveFeatures(STI.getFeatureBits()); } -// Looks up a system register either by encoding or by name. Some system +// Looks up a system register either by encoding. Some system // registers share the same encoding between different architectures, -// therefore a tablegen lookup by encoding will return an entry regardless -// of the register's predication on a specific subtarget feature. To work -// around this problem we keep an alternative name for such registers and -// look them up by that name if the first lookup was unsuccessful. +// to work around this tablegen will return a range of registers with the same +// encodings. We need to check each register in the range to see if it valid. 
static const AArch64SysReg::SysReg *lookupSysReg(unsigned Val, bool Read, const MCSubtargetInfo &STI) { - const AArch64SysReg::SysReg *Reg = AArch64SysReg::lookupSysRegByEncoding(Val); - - if (Reg && !isValidSysReg(Reg, Read, STI)) - Reg = AArch64SysReg::lookupSysRegByName(Reg->AltName); + auto Range = AArch64SysReg::lookupSysRegByEncoding(Val); + for (auto &Reg : Range) { + if (isValidSysReg(Reg, Read, STI)) + return &Reg; + } - return Reg; + return nullptr; } void AArch64InstPrinter::printMRSSystemRegister(const MCInst *MI, unsigned OpNo, @@ -1917,7 +1916,7 @@ void AArch64InstPrinter::printMRSSystemRegister(const MCInst *MI, unsigned OpNo, const AArch64SysReg::SysReg *Reg = lookupSysReg(Val, true /*Read*/, STI); - if (isValidSysReg(Reg, true /*Read*/, STI)) + if (Reg) O << Reg->Name; else O << AArch64SysReg::genericRegisterString(Val); @@ -1944,7 +1943,7 @@ void AArch64InstPrinter::printMSRSystemRegister(const MCInst *MI, unsigned OpNo, const AArch64SysReg::SysReg *Reg = lookupSysReg(Val, false /*Read*/, STI); - if (isValidSysReg(Reg, false /*Read*/, STI)) + if (Reg) O << Reg->Name; else O << AArch64SysReg::genericRegisterString(Val); diff --git a/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.cpp b/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.cpp index d83c22e..7767028 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.cpp +++ b/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.cpp @@ -18,7 +18,7 @@ using namespace llvm; namespace llvm { namespace AArch64AT { -#define GET_AT_IMPL +#define GET_ATsList_IMPL #include "AArch64GenSystemOperands.inc" } } @@ -26,128 +26,121 @@ namespace llvm { namespace llvm { namespace AArch64DBnXS { -#define GET_DBNXS_IMPL +#define GET_DBnXSsList_IMPL #include "AArch64GenSystemOperands.inc" } } namespace llvm { namespace AArch64DB { -#define GET_DB_IMPL +#define GET_DBsList_IMPL #include "AArch64GenSystemOperands.inc" } } namespace llvm { namespace AArch64DC { -#define GET_DC_IMPL +#define GET_DCsList_IMPL #include "AArch64GenSystemOperands.inc" } } namespace llvm { namespace AArch64IC { -#define GET_IC_IMPL +#define GET_ICsList_IMPL #include "AArch64GenSystemOperands.inc" } } namespace llvm { namespace AArch64ISB { -#define GET_ISB_IMPL +#define GET_ISBsList_IMPL #include "AArch64GenSystemOperands.inc" } } namespace llvm { namespace AArch64TSB { -#define GET_TSB_IMPL -#include "AArch64GenSystemOperands.inc" - } -} - -namespace llvm { - namespace AArch64PRCTX { -#define GET_PRCTX_IMPL +#define GET_TSBsList_IMPL #include "AArch64GenSystemOperands.inc" } } namespace llvm { namespace AArch64PRFM { -#define GET_PRFM_IMPL +#define GET_PRFMsList_IMPL #include "AArch64GenSystemOperands.inc" } } namespace llvm { namespace AArch64SVEPRFM { -#define GET_SVEPRFM_IMPL +#define GET_SVEPRFMsList_IMPL #include "AArch64GenSystemOperands.inc" } } namespace llvm { namespace AArch64RPRFM { -#define GET_RPRFM_IMPL +#define GET_RPRFMsList_IMPL #include "AArch64GenSystemOperands.inc" } // namespace AArch64RPRFM } // namespace llvm namespace llvm { namespace AArch64SVEPredPattern { -#define GET_SVEPREDPAT_IMPL +#define GET_SVEPREDPATsList_IMPL #include "AArch64GenSystemOperands.inc" } } namespace llvm { namespace AArch64SVEVecLenSpecifier { -#define GET_SVEVECLENSPECIFIER_IMPL +#define GET_SVEVECLENSPECIFIERsList_IMPL #include "AArch64GenSystemOperands.inc" } // namespace AArch64SVEVecLenSpecifier } // namespace llvm namespace llvm { namespace AArch64ExactFPImm { -#define GET_EXACTFPIMM_IMPL +#define GET_ExactFPImmsList_IMPL #include "AArch64GenSystemOperands.inc" } } 
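The renamed GET_<Table>sList_IMPL blocks in this file instantiate the generated tables together with the lookup helpers declared by the SearchIndex records in AArch64SystemOperands.td: by-name indexes produce a pointer-returning lookup such as lookupSysRegByName, while the SysReg primary key now returns a range of entries because of PrimaryKeyReturnRange = true (consumed by the InstPrinter change above). A minimal sketch of a by-name consumer on the assembler side; the wrapper and include paths are assumptions, and only the lookup call and the SysReg fields come from this patch:

// Hypothetical helper (not from this commit): resolve a system-register name
// via the generated by-name index, then gate it on the active feature bits.
#include "Utils/AArch64BaseInfo.h"
#include "llvm/ADT/StringRef.h"

static bool resolveSysRegForWrite(llvm::StringRef Name,
                                  const llvm::FeatureBitset &Features,
                                  unsigned &Encoding) {
  const llvm::AArch64SysReg::SysReg *Reg =
      llvm::AArch64SysReg::lookupSysRegByName(Name);
  if (!Reg || !Reg->Writeable || !Reg->haveFeatures(Features))
    return false;
  Encoding = Reg->Encoding;
  return true;
}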
namespace llvm { namespace AArch64PState { -#define GET_PSTATEIMM0_15_IMPL +#define GET_PStateImm0_15sList_IMPL #include "AArch64GenSystemOperands.inc" -#define GET_PSTATEIMM0_1_IMPL +#define GET_PStateImm0_1sList_IMPL #include "AArch64GenSystemOperands.inc" } } namespace llvm { namespace AArch64PSBHint { -#define GET_PSB_IMPL +#define GET_PSBsList_IMPL #include "AArch64GenSystemOperands.inc" } } namespace llvm { namespace AArch64PHint { -#define GET_PHINT_IMPL +#define GET_PHintsList_IMPL #include "AArch64GenSystemOperands.inc" } // namespace AArch64PHint } // namespace llvm namespace llvm { namespace AArch64BTIHint { -#define GET_BTI_IMPL +#define GET_BTIsList_IMPL #include "AArch64GenSystemOperands.inc" } } namespace llvm { namespace AArch64SysReg { -#define GET_SYSREG_IMPL +#define GET_SysRegsList_IMPL #include "AArch64GenSystemOperands.inc" } } @@ -194,7 +187,7 @@ namespace llvm { namespace llvm { namespace AArch64SVCR { -#define GET_SVCR_IMPL +#define GET_SVCRsList_IMPL #include "AArch64GenSystemOperands.inc" } } diff --git a/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h b/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h index e0ccba4..b8d3236 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h +++ b/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h @@ -371,79 +371,89 @@ namespace AArch64SVCR { struct SVCR : SysAlias{ using SysAlias::SysAlias; }; - #define GET_SVCR_DECL - #include "AArch64GenSystemOperands.inc" +#define GET_SVCRValues_DECL +#define GET_SVCRsList_DECL +#include "AArch64GenSystemOperands.inc" } namespace AArch64AT{ struct AT : SysAlias { using SysAlias::SysAlias; }; - #define GET_AT_DECL - #include "AArch64GenSystemOperands.inc" +#define GET_ATValues_DECL +#define GET_ATsList_DECL +#include "AArch64GenSystemOperands.inc" } namespace AArch64DB { struct DB : SysAlias { using SysAlias::SysAlias; }; - #define GET_DB_DECL - #include "AArch64GenSystemOperands.inc" +#define GET_DBValues_DECL +#define GET_DBsList_DECL +#include "AArch64GenSystemOperands.inc" } namespace AArch64DBnXS { struct DBnXS : SysAliasImm { using SysAliasImm::SysAliasImm; }; - #define GET_DBNXS_DECL - #include "AArch64GenSystemOperands.inc" +#define GET_DBnXSValues_DECL +#define GET_DBnXSsList_DECL +#include "AArch64GenSystemOperands.inc" } namespace AArch64DC { struct DC : SysAlias { using SysAlias::SysAlias; }; - #define GET_DC_DECL - #include "AArch64GenSystemOperands.inc" +#define GET_DCValues_DECL +#define GET_DCsList_DECL +#include "AArch64GenSystemOperands.inc" } namespace AArch64IC { struct IC : SysAliasReg { using SysAliasReg::SysAliasReg; }; - #define GET_IC_DECL - #include "AArch64GenSystemOperands.inc" +#define GET_ICValues_DECL +#define GET_ICsList_DECL +#include "AArch64GenSystemOperands.inc" } namespace AArch64ISB { struct ISB : SysAlias { using SysAlias::SysAlias; }; - #define GET_ISB_DECL - #include "AArch64GenSystemOperands.inc" +#define GET_ISBValues_DECL +#define GET_ISBsList_DECL +#include "AArch64GenSystemOperands.inc" } namespace AArch64TSB { struct TSB : SysAlias { using SysAlias::SysAlias; }; - #define GET_TSB_DECL - #include "AArch64GenSystemOperands.inc" +#define GET_TSBValues_DECL +#define GET_TSBsList_DECL +#include "AArch64GenSystemOperands.inc" } namespace AArch64PRFM { struct PRFM : SysAlias { using SysAlias::SysAlias; }; - #define GET_PRFM_DECL - #include "AArch64GenSystemOperands.inc" +#define GET_PRFMValues_DECL +#define GET_PRFMsList_DECL +#include "AArch64GenSystemOperands.inc" } namespace AArch64SVEPRFM { struct SVEPRFM : SysAlias { using SysAlias::SysAlias; 
}; -#define GET_SVEPRFM_DECL +#define GET_SVEPRFMValues_DECL +#define GET_SVEPRFMsList_DECL #include "AArch64GenSystemOperands.inc" } @@ -451,7 +461,8 @@ namespace AArch64RPRFM { struct RPRFM : SysAlias { using SysAlias::SysAlias; }; -#define GET_RPRFM_DECL +#define GET_RPRFMValues_DECL +#define GET_RPRFMsList_DECL #include "AArch64GenSystemOperands.inc" } // namespace AArch64RPRFM @@ -460,7 +471,8 @@ namespace AArch64SVEPredPattern { const char *Name; uint16_t Encoding; }; -#define GET_SVEPREDPAT_DECL +#define GET_SVEPREDPATValues_DECL +#define GET_SVEPREDPATsList_DECL #include "AArch64GenSystemOperands.inc" } @@ -469,7 +481,8 @@ namespace AArch64SVEVecLenSpecifier { const char *Name; uint16_t Encoding; }; -#define GET_SVEVECLENSPECIFIER_DECL +#define GET_SVEVECLENSPECIFIERValues_DECL +#define GET_SVEVECLENSPECIFIERsList_DECL #include "AArch64GenSystemOperands.inc" } // namespace AArch64SVEVecLenSpecifier @@ -551,12 +564,12 @@ LLVM_DECLARE_ENUM_AS_BITMASK(TailFoldingOpts, /* LargestValue */ (long)TailFoldingOpts::Reverse); namespace AArch64ExactFPImm { - struct ExactFPImm { - const char *Name; - int Enum; - const char *Repr; - }; -#define GET_EXACTFPIMM_DECL +struct ExactFPImm { + int Enum; + const char *Repr; +}; +#define GET_ExactFPImmValues_DECL +#define GET_ExactFPImmsList_DECL #include "AArch64GenSystemOperands.inc" } @@ -564,28 +577,30 @@ namespace AArch64PState { struct PStateImm0_15 : SysAlias{ using SysAlias::SysAlias; }; - #define GET_PSTATEIMM0_15_DECL - #include "AArch64GenSystemOperands.inc" +#define GET_PStateImm0_15Values_DECL +#define GET_PStateImm0_15sList_DECL +#include "AArch64GenSystemOperands.inc" struct PStateImm0_1 : SysAlias{ using SysAlias::SysAlias; }; - #define GET_PSTATEIMM0_1_DECL - #include "AArch64GenSystemOperands.inc" +#define GET_PStateImm0_1Values_DECL +#define GET_PStateImm0_1sList_DECL +#include "AArch64GenSystemOperands.inc" } namespace AArch64PSBHint { struct PSB : SysAlias { using SysAlias::SysAlias; }; - #define GET_PSB_DECL - #include "AArch64GenSystemOperands.inc" +#define GET_PSBValues_DECL +#define GET_PSBsList_DECL +#include "AArch64GenSystemOperands.inc" } namespace AArch64PHint { struct PHint { const char *Name; - const char *AltName; unsigned Encoding; FeatureBitset FeaturesRequired; @@ -595,7 +610,8 @@ struct PHint { } }; -#define GET_PHINT_DECL +#define GET_PHintValues_DECL +#define GET_PHintsList_DECL #include "AArch64GenSystemOperands.inc" const PHint *lookupPHintByName(StringRef); @@ -606,8 +622,9 @@ namespace AArch64BTIHint { struct BTI : SysAlias { using SysAlias::SysAlias; }; - #define GET_BTI_DECL - #include "AArch64GenSystemOperands.inc" +#define GET_BTIValues_DECL +#define GET_BTIsList_DECL +#include "AArch64GenSystemOperands.inc" } namespace AArch64SME { @@ -701,7 +718,6 @@ AArch64StringToVectorLayout(StringRef LayoutStr) { namespace AArch64SysReg { struct SysReg { const char Name[32]; - const char AltName[32]; unsigned Encoding; bool Readable; bool Writeable; @@ -713,11 +729,9 @@ namespace AArch64SysReg { } }; - #define GET_SYSREG_DECL - #include "AArch64GenSystemOperands.inc" - - const SysReg *lookupSysRegByName(StringRef); - const SysReg *lookupSysRegByEncoding(uint16_t); +#define GET_SysRegsList_DECL +#define GET_SysRegValues_DECL +#include "AArch64GenSystemOperands.inc" uint32_t parseGenericRegister(StringRef Name); std::string genericRegisterString(uint32_t Bits); @@ -731,14 +745,6 @@ namespace AArch64TLBI { #include "AArch64GenSystemOperands.inc" } -namespace AArch64PRCTX { - struct PRCTX : SysAliasReg { - using 
SysAliasReg::SysAliasReg; - }; - #define GET_PRCTX_DECL - #include "AArch64GenSystemOperands.inc" -} - namespace AArch64II { /// Target Operand Flag enum. enum TOF { diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp b/llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp index e844904..0f97988 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp @@ -1523,7 +1523,8 @@ Value *AMDGPUCodeGenPrepareImpl::shrinkDivRem64(IRBuilder<> &Builder, bool IsDiv = Opc == Instruction::SDiv || Opc == Instruction::UDiv; bool IsSigned = Opc == Instruction::SDiv || Opc == Instruction::SRem; - int NumDivBits = getDivNumBits(I, Num, Den, 32, IsSigned); + unsigned BitWidth = Num->getType()->getScalarSizeInBits(); + int NumDivBits = getDivNumBits(I, Num, Den, BitWidth - 32, IsSigned); if (NumDivBits == -1) return nullptr; diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCombine.td b/llvm/lib/Target/AMDGPU/AMDGPUCombine.td index 985fa8f..da47aaf 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUCombine.td +++ b/llvm/lib/Target/AMDGPU/AMDGPUCombine.td @@ -124,6 +124,16 @@ def sign_extension_in_reg : GICombineRule< [{ return matchCombineSignExtendInReg(*${sign_inreg}, ${matchinfo}); }]), (apply [{ applyCombineSignExtendInReg(*${sign_inreg}, ${matchinfo}); }])>; +// Do the following combines : +// fmul x, select(y, A, B) -> fldexp (x, select i32 (y, a, b)) +// fmul x, select(y, -A, -B) -> fldexp ((fneg x), select i32 (y, a, b)) +def combine_fmul_with_select_to_fldexp : GICombineRule< + (defs root:$root, build_fn_matchinfo:$matchinfo), + (match (G_FMUL $dst, $x, $select):$root, + (G_SELECT $select, $y, $A, $B):$sel, + [{ return Helper.matchCombineFmulWithSelectToFldexp(*${root}, *${sel}, ${matchinfo}); }]), + (apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>; + let Predicates = [Has16BitInsts, NotHasMed3_16] in { // For gfx8, expand f16-fmed3-as-f32 into a min/max f16 sequence. 
This @@ -153,13 +163,13 @@ def gfx8_combines : GICombineGroup<[expand_promoted_fmed3]>; def AMDGPUPreLegalizerCombiner: GICombiner< "AMDGPUPreLegalizerCombinerImpl", - [all_combines, clamp_i64_to_i16, foldable_fneg]> { + [all_combines, combine_fmul_with_select_to_fldexp, clamp_i64_to_i16, foldable_fneg]> { let CombineAllMethodName = "tryCombineAllImpl"; } def AMDGPUPostLegalizerCombiner: GICombiner< "AMDGPUPostLegalizerCombinerImpl", - [all_combines, gfx6gfx7_combines, gfx8_combines, + [all_combines, gfx6gfx7_combines, gfx8_combines, combine_fmul_with_select_to_fldexp, uchar_to_float, cvt_f32_ubyteN, remove_fcanonicalize, foldable_fneg, rcp_sqrt_to_rsq, fdiv_by_sqrt_to_rsq_f16, sign_extension_in_reg, smulu64]> { let CombineAllMethodName = "tryCombineAllImpl"; diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCombinerHelper.cpp b/llvm/lib/Target/AMDGPU/AMDGPUCombinerHelper.cpp index e5a376a..f6f9f4b 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUCombinerHelper.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUCombinerHelper.cpp @@ -17,6 +17,13 @@ using namespace llvm; using namespace MIPatternMatch; +AMDGPUCombinerHelper::AMDGPUCombinerHelper( + GISelChangeObserver &Observer, MachineIRBuilder &B, bool IsPreLegalize, + GISelKnownBits *KB, MachineDominatorTree *MDT, const LegalizerInfo *LI, + const GCNSubtarget &STI) + : CombinerHelper(Observer, B, IsPreLegalize, KB, MDT, LI), STI(STI), + TII(*STI.getInstrInfo()) {} + LLVM_READNONE static bool fnegFoldsIntoMI(const MachineInstr &MI) { switch (MI.getOpcode()) { @@ -445,3 +452,67 @@ void AMDGPUCombinerHelper::applyExpandPromotedF16FMed3(MachineInstr &MI, Builder.buildFMinNumIEEE(MI.getOperand(0), B1, C1); MI.eraseFromParent(); } + +bool AMDGPUCombinerHelper::matchCombineFmulWithSelectToFldexp( + MachineInstr &MI, MachineInstr &Sel, + std::function<void(MachineIRBuilder &)> &MatchInfo) { + assert(MI.getOpcode() == TargetOpcode::G_FMUL); + assert(Sel.getOpcode() == TargetOpcode::G_SELECT); + assert(MI.getOperand(2).getReg() == Sel.getOperand(0).getReg()); + + Register Dst = MI.getOperand(0).getReg(); + LLT DestTy = MRI.getType(Dst); + LLT ScalarDestTy = DestTy.getScalarType(); + + if ((ScalarDestTy != LLT::float64() && ScalarDestTy != LLT::float32() && + ScalarDestTy != LLT::float16()) || + !MRI.hasOneNonDBGUse(Sel.getOperand(0).getReg())) + return false; + + Register SelectCondReg = Sel.getOperand(1).getReg(); + MachineInstr *SelectTrue = MRI.getVRegDef(Sel.getOperand(2).getReg()); + MachineInstr *SelectFalse = MRI.getVRegDef(Sel.getOperand(3).getReg()); + + const auto SelectTrueVal = + isConstantOrConstantSplatVectorFP(*SelectTrue, MRI); + if (!SelectTrueVal) + return false; + const auto SelectFalseVal = + isConstantOrConstantSplatVectorFP(*SelectFalse, MRI); + if (!SelectFalseVal) + return false; + + if (SelectTrueVal->isNegative() != SelectFalseVal->isNegative()) + return false; + + // For f32, only non-inline constants should be transformed. 
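The combine added above rewrites a multiply by a select between two same-signed power-of-two constants into an fldexp of their exponents, negating x when the constants are negative. A minimal standalone C++ sketch of the underlying identity, using std::ldexp in place of G_FLDEXP and illustrative constants 8.0f/0.5f (exponents 3/-1); this is not the combiner code itself:

#include <cassert>
#include <cmath>

// x * select(y, 2^a, 2^b) == ldexp(x, select(y, a, b)); a shared negative
// sign on both constants folds into a negation of x. Multiplying by a power
// of two only adjusts the exponent, so for these inputs the results match
// bit-for-bit.
static float mulSelect(float x, bool y)      { return x * (y ? 8.0f : 0.5f); }
static float ldexpSelect(float x, bool y)    { return std::ldexp(x, y ? 3 : -1); }
static float mulSelectNeg(float x, bool y)   { return x * (y ? -8.0f : -0.5f); }
static float ldexpSelectNeg(float x, bool y) { return std::ldexp(-x, y ? 3 : -1); }

int main() {
  for (float x : {1.5f, -2.25f, 100.0f})
    for (bool y : {false, true}) {
      assert(mulSelect(x, y) == ldexpSelect(x, y));
      assert(mulSelectNeg(x, y) == ldexpSelectNeg(x, y));
    }
  return 0;
}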
+ if (ScalarDestTy == LLT::float32() && TII.isInlineConstant(*SelectTrueVal) && + TII.isInlineConstant(*SelectFalseVal)) + return false; + + int SelectTrueLog2Val = SelectTrueVal->getExactLog2Abs(); + if (SelectTrueLog2Val == INT_MIN) + return false; + int SelectFalseLog2Val = SelectFalseVal->getExactLog2Abs(); + if (SelectFalseLog2Val == INT_MIN) + return false; + + MatchInfo = [=, &MI](MachineIRBuilder &Builder) { + LLT IntDestTy = DestTy.changeElementType(LLT::scalar(32)); + auto NewSel = Builder.buildSelect( + IntDestTy, SelectCondReg, + Builder.buildConstant(IntDestTy, SelectTrueLog2Val), + Builder.buildConstant(IntDestTy, SelectFalseLog2Val)); + + Register XReg = MI.getOperand(1).getReg(); + if (SelectTrueVal->isNegative()) { + auto NegX = + Builder.buildFNeg(DestTy, XReg, MRI.getVRegDef(XReg)->getFlags()); + Builder.buildFLdexp(Dst, NegX, NewSel, MI.getFlags()); + } else { + Builder.buildFLdexp(Dst, XReg, NewSel, MI.getFlags()); + } + }; + + return true; +} diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCombinerHelper.h b/llvm/lib/Target/AMDGPU/AMDGPUCombinerHelper.h index 6510abe..893b3f5 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUCombinerHelper.h +++ b/llvm/lib/Target/AMDGPU/AMDGPUCombinerHelper.h @@ -15,13 +15,22 @@ #ifndef LLVM_LIB_TARGET_AMDGPU_AMDGPUCOMBINERHELPER_H #define LLVM_LIB_TARGET_AMDGPU_AMDGPUCOMBINERHELPER_H +#include "GCNSubtarget.h" #include "llvm/CodeGen/GlobalISel/Combiner.h" #include "llvm/CodeGen/GlobalISel/CombinerHelper.h" namespace llvm { class AMDGPUCombinerHelper : public CombinerHelper { +protected: + const GCNSubtarget &STI; + const SIInstrInfo &TII; + public: using CombinerHelper::CombinerHelper; + AMDGPUCombinerHelper(GISelChangeObserver &Observer, MachineIRBuilder &B, + bool IsPreLegalize, GISelKnownBits *KB, + MachineDominatorTree *MDT, const LegalizerInfo *LI, + const GCNSubtarget &STI); bool matchFoldableFneg(MachineInstr &MI, MachineInstr *&MatchInfo); void applyFoldableFneg(MachineInstr &MI, MachineInstr *&MatchInfo); @@ -30,6 +39,10 @@ public: Register Src1, Register Src2); void applyExpandPromotedF16FMed3(MachineInstr &MI, Register Src0, Register Src1, Register Src2); + + bool matchCombineFmulWithSelectToFldexp( + MachineInstr &MI, MachineInstr &Sel, + std::function<void(MachineIRBuilder &)> &MatchInfo); }; } // namespace llvm diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp index d9eaf82..27e9018 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp @@ -1997,7 +1997,7 @@ bool AMDGPUDAGToDAGISel::SelectScratchSVAddr(SDNode *N, SDValue Addr, if (checkFlatScratchSVSSwizzleBug(VAddr, SAddr, ImmOffset)) return false; SAddr = SelectSAddrFI(CurDAG, SAddr); - Offset = CurDAG->getTargetConstant(ImmOffset, SDLoc(), MVT::i32); + Offset = CurDAG->getSignedTargetConstant(ImmOffset, SDLoc(), MVT::i32); return true; } diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp index 3be865f..041b9b4 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp @@ -1125,8 +1125,9 @@ static int getV_CMPOpcode(CmpInst::Predicate P, unsigned Size, unsigned FakeS16Opc, unsigned S32Opc, unsigned S64Opc) { if (Size == 16) + // FIXME-TRUE16 use TrueS16Opc when realtrue16 is supported for CMP code return ST.hasTrue16BitInsts() - ? ST.useRealTrue16Insts() ? TrueS16Opc : FakeS16Opc + ? ST.useRealTrue16Insts() ? 
FakeS16Opc : FakeS16Opc : S16Opc; if (Size == 32) return S32Opc; diff --git a/llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp b/llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp index 54d927c..888817e 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp @@ -134,7 +134,7 @@ AMDGPUPostLegalizerCombinerImpl::AMDGPUPostLegalizerCombinerImpl( const GCNSubtarget &STI, MachineDominatorTree *MDT, const LegalizerInfo *LI) : Combiner(MF, CInfo, TPC, &KB, CSEInfo), RuleConfig(RuleConfig), STI(STI), TII(*STI.getInstrInfo()), - Helper(Observer, B, /*IsPreLegalize*/ false, &KB, MDT, LI), + Helper(Observer, B, /*IsPreLegalize*/ false, &KB, MDT, LI, STI), #define GET_GICOMBINER_CONSTRUCTOR_INITS #include "AMDGPUGenPostLegalizeGICombiner.inc" #undef GET_GICOMBINER_CONSTRUCTOR_INITS diff --git a/llvm/lib/Target/AMDGPU/AMDGPUPreLegalizerCombiner.cpp b/llvm/lib/Target/AMDGPU/AMDGPUPreLegalizerCombiner.cpp index ff8189c..e1564d5 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUPreLegalizerCombiner.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUPreLegalizerCombiner.cpp @@ -94,7 +94,7 @@ AMDGPUPreLegalizerCombinerImpl::AMDGPUPreLegalizerCombinerImpl( const AMDGPUPreLegalizerCombinerImplRuleConfig &RuleConfig, const GCNSubtarget &STI, MachineDominatorTree *MDT, const LegalizerInfo *LI) : Combiner(MF, CInfo, TPC, &KB, CSEInfo), RuleConfig(RuleConfig), STI(STI), - Helper(Observer, B, /*IsPreLegalize*/ true, &KB, MDT, LI), + Helper(Observer, B, /*IsPreLegalize*/ true, &KB, MDT, LI, STI), #define GET_GICOMBINER_CONSTRUCTOR_INITS #include "AMDGPUGenPreLegalizeGICombiner.inc" #undef GET_GICOMBINER_CONSTRUCTOR_INITS diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp index d94c400..08e23cb 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp @@ -1190,9 +1190,13 @@ bool AMDGPURegisterBankInfo::applyMappingDynStackAlloc( const RegisterBank *SizeBank = getRegBank(AllocSize, MRI, *TRI); - // TODO: Need to emit a wave reduction to get the maximum size. - if (SizeBank != &AMDGPU::SGPRRegBank) - return false; + if (SizeBank != &AMDGPU::SGPRRegBank) { + auto WaveReduction = + B.buildIntrinsic(Intrinsic::amdgcn_wave_reduce_umax, {LLT::scalar(32)}) + .addUse(AllocSize) + .addImm(0); + AllocSize = WaveReduction.getReg(0); + } LLT PtrTy = MRI.getType(Dst); LLT IntPtrTy = LLT::scalar(PtrTy.getSizeInBits()); diff --git a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp index ed956a1..d8f441d 100644 --- a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp +++ b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp @@ -9760,10 +9760,14 @@ unsigned AMDGPUAsmParser::validateTargetOperandClass(MCParsedAsmOperand &Op, case MCK_SReg_64: case MCK_SReg_64_XEXEC: // Null is defined as a 32-bit register but - // it should also be enabled with 64-bit operands. - // The following code enables it for SReg_64 operands + // it should also be enabled with 64-bit operands or larger. + // The following code enables it for SReg_64 and larger operands // used as source and destination. Remaining source // operands are handled in isInlinableImm. + case MCK_SReg_96: + case MCK_SReg_128: + case MCK_SReg_256: + case MCK_SReg_512: return Operand.isNull() ? 
Match_Success : Match_InvalidOperand; default: return Match_InvalidOperand; diff --git a/llvm/lib/Target/AMDGPU/BUFInstructions.td b/llvm/lib/Target/AMDGPU/BUFInstructions.td index a351f45..f2686bd 100644 --- a/llvm/lib/Target/AMDGPU/BUFInstructions.td +++ b/llvm/lib/Target/AMDGPU/BUFInstructions.td @@ -168,7 +168,7 @@ class getMTBUFInsDA<list<RegisterClass> vdataList, dag SOffset = !if(hasRestrictedSOffset, (ins SReg_32:$soffset), (ins SCSrc_b32:$soffset)); - dag NonVaddrInputs = !con((ins SReg_128:$srsrc), SOffset, + dag NonVaddrInputs = !con((ins SReg_128_XNULL:$srsrc), SOffset, (ins Offset:$offset, FORMAT:$format, CPol_0:$cpol, i1imm_0:$swz)); dag Inputs = !if(!empty(vaddrList), @@ -418,7 +418,7 @@ class getMUBUFInsDA<list<RegisterClass> vdataList, RegisterOperand vdata_op = getLdStVDataRegisterOperand<vdataClass, isTFE>.ret; dag SOffset = !if(hasRestrictedSOffset, (ins SReg_32:$soffset), (ins SCSrc_b32:$soffset)); - dag NonVaddrInputs = !con((ins SReg_128:$srsrc), SOffset, (ins Offset:$offset, CPol_0:$cpol, i1imm_0:$swz)); + dag NonVaddrInputs = !con((ins SReg_128_XNULL:$srsrc), SOffset, (ins Offset:$offset, CPol_0:$cpol, i1imm_0:$swz)); dag Inputs = !if(!empty(vaddrList), NonVaddrInputs, !con((ins vaddrClass:$vaddr), NonVaddrInputs)); dag ret = !if(!empty(vdataList), Inputs, !con((ins vdata_op:$vdata), Inputs)); @@ -680,7 +680,7 @@ multiclass MUBUF_Pseudo_Stores<string opName, ValueType store_vt = i32> { class MUBUF_Pseudo_Store_Lds<string opName> : MUBUF_Pseudo<opName, (outs), - (ins SReg_128:$srsrc, SCSrc_b32:$soffset, Offset:$offset, CPol:$cpol, i1imm:$swz), + (ins SReg_128_XNULL:$srsrc, SCSrc_b32:$soffset, Offset:$offset, CPol:$cpol, i1imm:$swz), " $srsrc, $soffset$offset lds$cpol"> { let LGKM_CNT = 1; let mayLoad = 1; @@ -703,7 +703,7 @@ class getMUBUFAtomicInsDA<RegisterClass vdataClass, bit vdata_in, bit hasRestric dag VData = !if(vdata_in, (ins vdata_op:$vdata_in), (ins vdata_op:$vdata)); dag Data = !if(!empty(vaddrList), VData, !con(VData, (ins vaddrClass:$vaddr))); dag SOffset = !if(hasRestrictedSOffset, (ins SReg_32:$soffset), (ins SCSrc_b32:$soffset)); - dag MainInputs = !con((ins SReg_128:$srsrc), SOffset, (ins Offset:$offset)); + dag MainInputs = !con((ins SReg_128_XNULL:$srsrc), SOffset, (ins Offset:$offset)); dag CPol = !if(vdata_in, (ins CPol_GLC_WithDefault:$cpol), (ins CPol_NonGLC_WithDefault:$cpol)); diff --git a/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp b/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp index 5908351..d236327 100644 --- a/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp +++ b/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp @@ -279,7 +279,9 @@ DECODE_OPERAND_REG_7(SReg_64_XEXEC, OPW64) DECODE_OPERAND_REG_7(SReg_64_XEXEC_XNULL, OPW64) DECODE_OPERAND_REG_7(SReg_96, OPW96) DECODE_OPERAND_REG_7(SReg_128, OPW128) +DECODE_OPERAND_REG_7(SReg_128_XNULL, OPW128) DECODE_OPERAND_REG_7(SReg_256, OPW256) +DECODE_OPERAND_REG_7(SReg_256_XNULL, OPW256) DECODE_OPERAND_REG_7(SReg_512, OPW512) DECODE_OPERAND_REG_8(AGPR_32) @@ -1692,6 +1694,11 @@ AMDGPUDisassembler::decodeNonVGPRSrcOp(const OpWidthTy Width, unsigned Val, case OPW64: case OPWV232: return decodeSpecialReg64(Val); + case OPW96: + case OPW128: + case OPW256: + case OPW512: + return decodeSpecialReg96Plus(Val); default: llvm_unreachable("unexpected immediate type"); } @@ -1778,6 +1785,24 @@ MCOperand AMDGPUDisassembler::decodeSpecialReg64(unsigned Val) const { return errOperand(Val, "unknown operand encoding " + Twine(Val)); } +MCOperand 
AMDGPUDisassembler::decodeSpecialReg96Plus(unsigned Val) const { + using namespace AMDGPU; + + switch (Val) { + case 124: + if (isGFX11Plus()) + return createRegOperand(SGPR_NULL); + break; + case 125: + if (!isGFX11Plus()) + return createRegOperand(SGPR_NULL); + break; + default: + break; + } + return errOperand(Val, "unknown operand encoding " + Twine(Val)); +} + MCOperand AMDGPUDisassembler::decodeSDWASrc(const OpWidthTy Width, const unsigned Val, unsigned ImmWidth, diff --git a/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.h b/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.h index b19e4b7..9a06cc3 100644 --- a/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.h +++ b/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.h @@ -259,6 +259,7 @@ public: MCOperand decodeVOPDDstYOp(MCInst &Inst, unsigned Val) const; MCOperand decodeSpecialReg32(unsigned Val) const; MCOperand decodeSpecialReg64(unsigned Val) const; + MCOperand decodeSpecialReg96Plus(unsigned Val) const; MCOperand decodeSDWASrc(const OpWidthTy Width, unsigned Val, unsigned ImmWidth, diff --git a/llvm/lib/Target/AMDGPU/MIMGInstructions.td b/llvm/lib/Target/AMDGPU/MIMGInstructions.td index 4722d33..1b94d6c 100644 --- a/llvm/lib/Target/AMDGPU/MIMGInstructions.td +++ b/llvm/lib/Target/AMDGPU/MIMGInstructions.td @@ -422,7 +422,7 @@ class MIMG_NoSampler_Helper <mimgopc op, string asm, RegisterClass addr_rc, string dns=""> : MIMG_gfx6789 <op.GFX10M, (outs dst_rc:$vdata), dns> { - let InOperandList = !con((ins addr_rc:$vaddr, SReg_256:$srsrc, + let InOperandList = !con((ins addr_rc:$vaddr, SReg_256_XNULL:$srsrc, DMask:$dmask, UNorm:$unorm, CPol:$cpol, R128A16:$r128, TFE:$tfe, LWE:$lwe, DA:$da), !if(BaseOpcode.HasD16, (ins D16:$d16), (ins))); @@ -435,7 +435,7 @@ class MIMG_NoSampler_Helper_gfx90a <mimgopc op, string asm, RegisterClass addr_rc, string dns=""> : MIMG_gfx90a <op.GFX10M, (outs getLdStRegisterOperand<dst_rc>.ret:$vdata), dns> { - let InOperandList = !con((ins addr_rc:$vaddr, SReg_256:$srsrc, + let InOperandList = !con((ins addr_rc:$vaddr, SReg_256_XNULL:$srsrc, DMask:$dmask, UNorm:$unorm, CPol:$cpol, R128A16:$r128, LWE:$lwe, DA:$da), !if(BaseOpcode.HasD16, (ins D16:$d16), (ins))); @@ -447,7 +447,7 @@ class MIMG_NoSampler_gfx10<mimgopc op, string opcode, RegisterClass DataRC, RegisterClass AddrRC, string dns=""> : MIMG_gfx10<op.GFX10M, (outs DataRC:$vdata), dns> { - let InOperandList = !con((ins AddrRC:$vaddr0, SReg_256:$srsrc, DMask:$dmask, + let InOperandList = !con((ins AddrRC:$vaddr0, SReg_256_XNULL:$srsrc, DMask:$dmask, Dim:$dim, UNorm:$unorm, CPol:$cpol, R128A16:$r128, A16:$a16, TFE:$tfe, LWE:$lwe), !if(BaseOpcode.HasD16, (ins D16:$d16), (ins))); @@ -460,7 +460,7 @@ class MIMG_NoSampler_nsa_gfx10<mimgopc op, string opcode, string dns=""> : MIMG_nsa_gfx10<op.GFX10M, (outs DataRC:$vdata), num_addrs, dns> { let InOperandList = !con(AddrIns, - (ins SReg_256:$srsrc, DMask:$dmask, + (ins SReg_256_XNULL:$srsrc, DMask:$dmask, Dim:$dim, UNorm:$unorm, CPol:$cpol, R128A16:$r128, A16:$a16, TFE:$tfe, LWE:$lwe), !if(BaseOpcode.HasD16, (ins D16:$d16), (ins))); @@ -472,7 +472,7 @@ class MIMG_NoSampler_gfx11<mimgopc op, string opcode, RegisterClass DataRC, RegisterClass AddrRC, string dns=""> : MIMG_gfx11<op.GFX11, (outs DataRC:$vdata), dns> { - let InOperandList = !con((ins AddrRC:$vaddr0, SReg_256:$srsrc, DMask:$dmask, + let InOperandList = !con((ins AddrRC:$vaddr0, SReg_256_XNULL:$srsrc, DMask:$dmask, Dim:$dim, UNorm:$unorm, CPol:$cpol, R128A16:$r128, A16:$a16, TFE:$tfe, LWE:$lwe), !if(BaseOpcode.HasD16, (ins 
D16:$d16), (ins))); @@ -485,7 +485,7 @@ class MIMG_NoSampler_nsa_gfx11<mimgopc op, string opcode, string dns=""> : MIMG_nsa_gfx11<op.GFX11, (outs DataRC:$vdata), num_addrs, dns> { let InOperandList = !con(AddrIns, - (ins SReg_256:$srsrc, DMask:$dmask, + (ins SReg_256_XNULL:$srsrc, DMask:$dmask, Dim:$dim, UNorm:$unorm, CPol:$cpol, R128A16:$r128, A16:$a16, TFE:$tfe, LWE:$lwe), !if(BaseOpcode.HasD16, (ins D16:$d16), (ins))); @@ -498,7 +498,7 @@ class VIMAGE_NoSampler_gfx12<mimgopc op, string opcode, string dns=""> : VIMAGE_gfx12<op.GFX12, (outs DataRC:$vdata), num_addrs, dns> { let InOperandList = !con(AddrIns, - (ins SReg_256:$rsrc, DMask:$dmask, Dim:$dim, + (ins SReg_256_XNULL:$rsrc, DMask:$dmask, Dim:$dim, CPol:$cpol, R128A16:$r128, A16:$a16, TFE:$tfe), !if(BaseOpcode.HasD16, (ins D16:$d16), (ins))); let AsmString = opcode#" $vdata, "#AddrAsm#", $rsrc$dmask$dim$cpol$r128$a16$tfe" @@ -510,8 +510,8 @@ class VSAMPLE_Sampler_gfx12<mimgopc op, string opcode, RegisterClass DataRC, string dns=""> : VSAMPLE_gfx12<op.GFX12, (outs DataRC:$vdata), num_addrs, dns, Addr3RC> { let InOperandList = !con(AddrIns, - (ins SReg_256:$rsrc), - !if(BaseOpcode.Sampler, (ins SReg_128:$samp), (ins)), + (ins SReg_256_XNULL:$rsrc), + !if(BaseOpcode.Sampler, (ins SReg_128_XNULL:$samp), (ins)), (ins DMask:$dmask, Dim:$dim, UNorm:$unorm, CPol:$cpol, R128A16:$r128, A16:$a16, TFE:$tfe, LWE:$lwe), @@ -527,8 +527,8 @@ class VSAMPLE_Sampler_nortn_gfx12<mimgopc op, string opcode, string dns=""> : VSAMPLE_gfx12<op.GFX12, (outs), num_addrs, dns, Addr3RC> { let InOperandList = !con(AddrIns, - (ins SReg_256:$rsrc), - !if(BaseOpcode.Sampler, (ins SReg_128:$samp), (ins)), + (ins SReg_256_XNULL:$rsrc), + !if(BaseOpcode.Sampler, (ins SReg_128_XNULL:$samp), (ins)), (ins DMask:$dmask, Dim:$dim, UNorm:$unorm, CPol:$cpol, R128A16:$r128, A16:$a16, TFE:$tfe, LWE:$lwe), @@ -679,7 +679,7 @@ class MIMG_Store_Helper <mimgopc op, string asm, RegisterClass addr_rc, string dns = ""> : MIMG_gfx6789<op.GFX10M, (outs), dns> { - let InOperandList = !con((ins data_rc:$vdata, addr_rc:$vaddr, SReg_256:$srsrc, + let InOperandList = !con((ins data_rc:$vdata, addr_rc:$vaddr, SReg_256_XNULL:$srsrc, DMask:$dmask, UNorm:$unorm, CPol:$cpol, R128A16:$r128, TFE:$tfe, LWE:$lwe, DA:$da), !if(BaseOpcode.HasD16, (ins D16:$d16), (ins))); @@ -693,7 +693,7 @@ class MIMG_Store_Helper_gfx90a <mimgopc op, string asm, string dns = ""> : MIMG_gfx90a<op.GFX10M, (outs), dns> { let InOperandList = !con((ins getLdStRegisterOperand<data_rc>.ret:$vdata, - addr_rc:$vaddr, SReg_256:$srsrc, + addr_rc:$vaddr, SReg_256_XNULL:$srsrc, DMask:$dmask, UNorm:$unorm, CPol:$cpol, R128A16:$r128, LWE:$lwe, DA:$da), !if(BaseOpcode.HasD16, (ins D16:$d16), (ins))); @@ -705,7 +705,7 @@ class MIMG_Store_gfx10<mimgopc op, string opcode, RegisterClass DataRC, RegisterClass AddrRC, string dns=""> : MIMG_gfx10<op.GFX10M, (outs), dns> { - let InOperandList = !con((ins DataRC:$vdata, AddrRC:$vaddr0, SReg_256:$srsrc, + let InOperandList = !con((ins DataRC:$vdata, AddrRC:$vaddr0, SReg_256_XNULL:$srsrc, DMask:$dmask, Dim:$dim, UNorm:$unorm, CPol:$cpol, R128A16:$r128, A16:$a16, TFE:$tfe, LWE:$lwe), !if(BaseOpcode.HasD16, (ins D16:$d16), (ins))); @@ -719,7 +719,7 @@ class MIMG_Store_nsa_gfx10<mimgopc op, string opcode, : MIMG_nsa_gfx10<op.GFX10M, (outs), num_addrs, dns> { let InOperandList = !con((ins DataRC:$vdata), AddrIns, - (ins SReg_256:$srsrc, DMask:$dmask, + (ins SReg_256_XNULL:$srsrc, DMask:$dmask, Dim:$dim, UNorm:$unorm, CPol:$cpol, R128A16:$r128, A16:$a16, TFE:$tfe, LWE:$lwe), !if(BaseOpcode.HasD16, 
(ins D16:$d16), (ins))); @@ -731,7 +731,7 @@ class MIMG_Store_gfx11<mimgopc op, string opcode, RegisterClass DataRC, RegisterClass AddrRC, string dns=""> : MIMG_gfx11<op.GFX11, (outs), dns> { - let InOperandList = !con((ins DataRC:$vdata, AddrRC:$vaddr0, SReg_256:$srsrc, + let InOperandList = !con((ins DataRC:$vdata, AddrRC:$vaddr0, SReg_256_XNULL:$srsrc, DMask:$dmask, Dim:$dim, UNorm:$unorm, CPol:$cpol, R128A16:$r128, A16:$a16, TFE:$tfe, LWE:$lwe), !if(BaseOpcode.HasD16, (ins D16:$d16), (ins))); @@ -745,7 +745,7 @@ class MIMG_Store_nsa_gfx11<mimgopc op, string opcode, : MIMG_nsa_gfx11<op.GFX11, (outs), num_addrs, dns> { let InOperandList = !con((ins DataRC:$vdata), AddrIns, - (ins SReg_256:$srsrc, DMask:$dmask, + (ins SReg_256_XNULL:$srsrc, DMask:$dmask, Dim:$dim, UNorm:$unorm, CPol:$cpol, R128A16:$r128, A16:$a16, TFE:$tfe, LWE:$lwe), !if(BaseOpcode.HasD16, (ins D16:$d16), (ins))); @@ -759,7 +759,7 @@ class VIMAGE_Store_gfx12<mimgopc op, string opcode, : VIMAGE_gfx12<op.GFX12, (outs), num_addrs, dns> { let InOperandList = !con((ins DataRC:$vdata), AddrIns, - (ins SReg_256:$rsrc, DMask:$dmask, Dim:$dim, + (ins SReg_256_XNULL:$rsrc, DMask:$dmask, Dim:$dim, CPol:$cpol, R128A16:$r128, A16:$a16, TFE:$tfe), !if(BaseOpcode.HasD16, (ins D16:$d16), (ins))); let AsmString = opcode#" $vdata, "#AddrAsm#", $rsrc$dmask$dim$cpol$r128$a16$tfe" @@ -875,7 +875,7 @@ class MIMG_Atomic_gfx6789_base <bits<8> op, string asm, RegisterClass data_rc, : MIMG_gfx6789 <op, (outs data_rc:$vdst), dns> { let Constraints = "$vdst = $vdata"; - let InOperandList = (ins data_rc:$vdata, addr_rc:$vaddr, SReg_256:$srsrc, + let InOperandList = (ins data_rc:$vdata, addr_rc:$vaddr, SReg_256_XNULL:$srsrc, DMask:$dmask, UNorm:$unorm, CPol:$cpol, R128A16:$r128, TFE:$tfe, LWE:$lwe, DA:$da); let AsmString = asm#" $vdst, $vaddr, $srsrc$dmask$unorm$cpol$r128$tfe$lwe$da"; @@ -887,7 +887,7 @@ class MIMG_Atomic_gfx90a_base <bits<8> op, string asm, RegisterClass data_rc, let Constraints = "$vdst = $vdata"; let InOperandList = (ins getLdStRegisterOperand<data_rc>.ret:$vdata, - addr_rc:$vaddr, SReg_256:$srsrc, + addr_rc:$vaddr, SReg_256_XNULL:$srsrc, DMask:$dmask, UNorm:$unorm, CPol:$cpol, R128A16:$r128, LWE:$lwe, DA:$da); let AsmString = asm#" $vdst, $vaddr, $srsrc$dmask$unorm$cpol$r128$lwe$da"; @@ -921,7 +921,7 @@ class MIMG_Atomic_gfx10<mimgopc op, string opcode, !if(enableDisasm, "GFX10", "")> { let Constraints = "$vdst = $vdata"; - let InOperandList = (ins DataRC:$vdata, AddrRC:$vaddr0, SReg_256:$srsrc, + let InOperandList = (ins DataRC:$vdata, AddrRC:$vaddr0, SReg_256_XNULL:$srsrc, DMask:$dmask, Dim:$dim, UNorm:$unorm, CPol:$cpol, R128A16:$r128, A16:$a16, TFE:$tfe, LWE:$lwe); let AsmString = opcode#" $vdst, $vaddr0, $srsrc$dmask$dim$unorm$cpol$r128$a16$tfe$lwe"; @@ -936,7 +936,7 @@ class MIMG_Atomic_nsa_gfx10<mimgopc op, string opcode, let InOperandList = !con((ins DataRC:$vdata), AddrIns, - (ins SReg_256:$srsrc, DMask:$dmask, + (ins SReg_256_XNULL:$srsrc, DMask:$dmask, Dim:$dim, UNorm:$unorm, CPol:$cpol, R128A16:$r128, A16:$a16, TFE:$tfe, LWE:$lwe)); let AsmString = opcode#" $vdata, "#AddrAsm#", $srsrc$dmask$dim$unorm$cpol$r128$a16$tfe$lwe"; @@ -949,7 +949,7 @@ class MIMG_Atomic_gfx11<mimgopc op, string opcode, !if(enableDisasm, "GFX11", "")> { let Constraints = "$vdst = $vdata"; - let InOperandList = (ins DataRC:$vdata, AddrRC:$vaddr0, SReg_256:$srsrc, + let InOperandList = (ins DataRC:$vdata, AddrRC:$vaddr0, SReg_256_XNULL:$srsrc, DMask:$dmask, Dim:$dim, UNorm:$unorm, CPol:$cpol, R128A16:$r128, A16:$a16, TFE:$tfe, LWE:$lwe); let 
AsmString = opcode#" $vdst, $vaddr0, $srsrc$dmask$dim$unorm$cpol$r128$a16$tfe$lwe"; @@ -964,7 +964,7 @@ class MIMG_Atomic_nsa_gfx11<mimgopc op, string opcode, let InOperandList = !con((ins DataRC:$vdata), AddrIns, - (ins SReg_256:$srsrc, DMask:$dmask, + (ins SReg_256_XNULL:$srsrc, DMask:$dmask, Dim:$dim, UNorm:$unorm, CPol:$cpol, R128A16:$r128, A16:$a16, TFE:$tfe, LWE:$lwe)); let AsmString = opcode#" $vdata, "#AddrAsm#", $srsrc$dmask$dim$unorm$cpol$r128$a16$tfe$lwe"; @@ -978,7 +978,7 @@ class VIMAGE_Atomic_gfx12<mimgopc op, string opcode, RegisterClass DataRC, let InOperandList = !con((ins DataRC:$vdata), AddrIns, - (ins SReg_256:$rsrc, DMask:$dmask, Dim:$dim, + (ins SReg_256_XNULL:$rsrc, DMask:$dmask, Dim:$dim, CPol:$cpol, R128A16:$r128, A16:$a16, TFE:$tfe)); let AsmString = !if(!empty(renamed), opcode, renamed)#" $vdata, "#AddrAsm# ", $rsrc$dmask$dim$cpol$r128$a16$tfe"; @@ -1128,7 +1128,7 @@ multiclass MIMG_Atomic_Renamed <mimgopc op, string asm, string renamed, class MIMG_Sampler_Helper <mimgopc op, string asm, RegisterClass dst_rc, RegisterClass src_rc, string dns=""> : MIMG_gfx6789 <op.VI, (outs dst_rc:$vdata), dns> { - let InOperandList = !con((ins src_rc:$vaddr, SReg_256:$srsrc, SReg_128:$ssamp, + let InOperandList = !con((ins src_rc:$vaddr, SReg_256_XNULL:$srsrc, SReg_128_XNULL:$ssamp, DMask:$dmask, UNorm:$unorm, CPol:$cpol, R128A16:$r128, TFE:$tfe, LWE:$lwe, DA:$da), !if(BaseOpcode.HasD16, (ins D16:$d16), (ins))); @@ -1139,7 +1139,7 @@ class MIMG_Sampler_Helper <mimgopc op, string asm, RegisterClass dst_rc, class MIMG_Sampler_gfx90a<mimgopc op, string asm, RegisterClass dst_rc, RegisterClass src_rc, string dns=""> : MIMG_gfx90a<op.GFX10M, (outs getLdStRegisterOperand<dst_rc>.ret:$vdata), dns> { - let InOperandList = !con((ins src_rc:$vaddr, SReg_256:$srsrc, SReg_128:$ssamp, + let InOperandList = !con((ins src_rc:$vaddr, SReg_256_XNULL:$srsrc, SReg_128_XNULL:$ssamp, DMask:$dmask, UNorm:$unorm, CPol:$cpol, R128A16:$r128, LWE:$lwe, DA:$da), !if(BaseOpcode.HasD16, (ins D16:$d16), (ins))); @@ -1149,7 +1149,7 @@ class MIMG_Sampler_gfx90a<mimgopc op, string asm, RegisterClass dst_rc, class MIMG_Sampler_OpList_gfx10p<dag OpPrefix, bit HasD16> { dag ret = !con(OpPrefix, - (ins SReg_256:$srsrc, SReg_128:$ssamp, + (ins SReg_256_XNULL:$srsrc, SReg_128_XNULL:$ssamp, DMask:$dmask, Dim:$dim, UNorm:$unorm, CPol:$cpol, R128A16:$r128, A16:$a16, TFE:$tfe, LWE:$lwe), !if(HasD16, (ins D16:$d16), (ins))); @@ -1524,7 +1524,7 @@ class MIMG_IntersectRay_Helper<bit Is64, bit IsA16> { class MIMG_IntersectRay_gfx10<mimgopc op, string opcode, RegisterClass AddrRC> : MIMG_gfx10<op.GFX10M, (outs VReg_128:$vdata), "GFX10"> { - let InOperandList = (ins AddrRC:$vaddr0, SReg_128:$srsrc, A16:$a16); + let InOperandList = (ins AddrRC:$vaddr0, SReg_128_XNULL:$srsrc, A16:$a16); let AsmString = opcode#" $vdata, $vaddr0, $srsrc$a16"; let nsa = 0; @@ -1532,13 +1532,13 @@ class MIMG_IntersectRay_gfx10<mimgopc op, string opcode, RegisterClass AddrRC> class MIMG_IntersectRay_nsa_gfx10<mimgopc op, string opcode, int num_addrs> : MIMG_nsa_gfx10<op.GFX10M, (outs VReg_128:$vdata), num_addrs, "GFX10"> { - let InOperandList = !con(nsah.AddrIns, (ins SReg_128:$srsrc, A16:$a16)); + let InOperandList = !con(nsah.AddrIns, (ins SReg_128_XNULL:$srsrc, A16:$a16)); let AsmString = opcode#" $vdata, "#nsah.AddrAsm#", $srsrc$a16"; } class MIMG_IntersectRay_gfx11<mimgopc op, string opcode, RegisterClass AddrRC> : MIMG_gfx11<op.GFX11, (outs VReg_128:$vdata), "GFX11"> { - let InOperandList = (ins AddrRC:$vaddr0, SReg_128:$srsrc, A16:$a16); + let 
InOperandList = (ins AddrRC:$vaddr0, SReg_128_XNULL:$srsrc, A16:$a16); let AsmString = opcode#" $vdata, $vaddr0, $srsrc$a16"; let nsa = 0; @@ -1548,7 +1548,7 @@ class MIMG_IntersectRay_nsa_gfx11<mimgopc op, string opcode, int num_addrs, list<RegisterClass> addr_types> : MIMG_nsa_gfx11<op.GFX11, (outs VReg_128:$vdata), num_addrs, "GFX11", addr_types> { - let InOperandList = !con(nsah.AddrIns, (ins SReg_128:$srsrc, A16:$a16)); + let InOperandList = !con(nsah.AddrIns, (ins SReg_128_XNULL:$srsrc, A16:$a16)); let AsmString = opcode#" $vdata, "#nsah.AddrAsm#", $srsrc$a16"; } @@ -1556,7 +1556,7 @@ class VIMAGE_IntersectRay_gfx12<mimgopc op, string opcode, int num_addrs, list<RegisterClass> addr_types> : VIMAGE_gfx12<op.GFX12, (outs VReg_128:$vdata), num_addrs, "GFX12", addr_types> { - let InOperandList = !con(nsah.AddrIns, (ins SReg_128:$rsrc, A16:$a16)); + let InOperandList = !con(nsah.AddrIns, (ins SReg_128_XNULL:$rsrc, A16:$a16)); let AsmString = opcode#" $vdata, "#nsah.AddrAsm#", $rsrc$a16"; } diff --git a/llvm/lib/Target/AMDGPU/SIFoldOperands.cpp b/llvm/lib/Target/AMDGPU/SIFoldOperands.cpp index 4fb5cb0..2bc1913 100644 --- a/llvm/lib/Target/AMDGPU/SIFoldOperands.cpp +++ b/llvm/lib/Target/AMDGPU/SIFoldOperands.cpp @@ -199,7 +199,7 @@ static unsigned macToMad(unsigned Opc) { case AMDGPU::V_FMAC_F16_e64: return AMDGPU::V_FMA_F16_gfx9_e64; case AMDGPU::V_FMAC_F16_fake16_e64: - return AMDGPU::V_FMA_F16_gfx9_e64; + return AMDGPU::V_FMA_F16_gfx9_fake16_e64; case AMDGPU::V_FMAC_LEGACY_F32_e64: return AMDGPU::V_FMA_LEGACY_F32_e64; case AMDGPU::V_FMAC_F64_e64: @@ -1096,21 +1096,8 @@ void SIFoldOperandsImpl::foldOperand( B.addImm(Defs[I].second); } LLVM_DEBUG(dbgs() << "Folded " << *UseMI); - return; } - if (Size != 4) - return; - - Register Reg0 = UseMI->getOperand(0).getReg(); - Register Reg1 = UseMI->getOperand(1).getReg(); - if (TRI->isAGPR(*MRI, Reg0) && TRI->isVGPR(*MRI, Reg1)) - UseMI->setDesc(TII->get(AMDGPU::V_ACCVGPR_WRITE_B32_e64)); - else if (TRI->isVGPR(*MRI, Reg0) && TRI->isAGPR(*MRI, Reg1)) - UseMI->setDesc(TII->get(AMDGPU::V_ACCVGPR_READ_B32_e64)); - else if (ST->hasGFX90AInsts() && TRI->isAGPR(*MRI, Reg0) && - TRI->isAGPR(*MRI, Reg1)) - UseMI->setDesc(TII->get(AMDGPU::V_ACCVGPR_MOV_B32)); return; } diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp index 58b061f..0ac84f4 100644 --- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp +++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp @@ -4017,29 +4017,26 @@ SDValue SITargetLowering::LowerCall(CallLoweringInfo &CLI, } // This is similar to the default implementation in ExpandDYNAMIC_STACKALLOC, -// except for stack growth direction(default: downwards, AMDGPU: upwards) and -// applying the wave size scale to the increment amount. -SDValue SITargetLowering::lowerDYNAMIC_STACKALLOCImpl(SDValue Op, - SelectionDAG &DAG) const { +// except for: +// 1. Stack growth direction(default: downwards, AMDGPU: upwards), and +// 2. 
Scale size where, scale = wave-reduction(alloca-size) * wave-size +SDValue SITargetLowering::LowerDYNAMIC_STACKALLOC(SDValue Op, + SelectionDAG &DAG) const { const MachineFunction &MF = DAG.getMachineFunction(); const SIMachineFunctionInfo *Info = MF.getInfo<SIMachineFunctionInfo>(); SDLoc dl(Op); EVT VT = Op.getValueType(); - SDValue Tmp1 = Op; - SDValue Tmp2 = Op.getValue(1); - SDValue Tmp3 = Op.getOperand(2); - SDValue Chain = Tmp1.getOperand(0); - + SDValue Chain = Op.getOperand(0); Register SPReg = Info->getStackPtrOffsetReg(); // Chain the dynamic stack allocation so that it doesn't modify the stack // pointer when other instructions are using the stack. Chain = DAG.getCALLSEQ_START(Chain, 0, 0, dl); - SDValue Size = Tmp2.getOperand(1); + SDValue Size = Op.getOperand(1); SDValue BaseAddr = DAG.getCopyFromReg(Chain, dl, SPReg, VT); - Align Alignment = cast<ConstantSDNode>(Tmp3)->getAlignValue(); + Align Alignment = cast<ConstantSDNode>(Op.getOperand(2))->getAlignValue(); const TargetFrameLowering *TFL = Subtarget->getFrameLowering(); assert(TFL->getStackGrowthDirection() == TargetFrameLowering::StackGrowsUp && @@ -4057,30 +4054,36 @@ SDValue SITargetLowering::lowerDYNAMIC_STACKALLOCImpl(SDValue Op, DAG.getSignedConstant(-ScaledAlignment, dl, VT)); } - SDValue ScaledSize = DAG.getNode( - ISD::SHL, dl, VT, Size, - DAG.getConstant(Subtarget->getWavefrontSizeLog2(), dl, MVT::i32)); - - SDValue NewSP = DAG.getNode(ISD::ADD, dl, VT, BaseAddr, ScaledSize); // Value + assert(Size.getValueType() == MVT::i32 && "Size must be 32-bit"); + SDValue NewSP; + if (isa<ConstantSDNode>(Size)) { + // For constant sized alloca, scale alloca size by wave-size + SDValue ScaledSize = DAG.getNode( + ISD::SHL, dl, VT, Size, + DAG.getConstant(Subtarget->getWavefrontSizeLog2(), dl, MVT::i32)); + NewSP = DAG.getNode(ISD::ADD, dl, VT, BaseAddr, ScaledSize); // Value + } else { + // For dynamic sized alloca, perform wave-wide reduction to get max of + // alloca size(divergent) and then scale it by wave-size + SDValue WaveReduction = + DAG.getTargetConstant(Intrinsic::amdgcn_wave_reduce_umax, dl, MVT::i32); + Size = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, MVT::i32, WaveReduction, + Size, DAG.getConstant(0, dl, MVT::i32)); + SDValue ScaledSize = DAG.getNode( + ISD::SHL, dl, VT, Size, + DAG.getConstant(Subtarget->getWavefrontSizeLog2(), dl, MVT::i32)); + NewSP = + DAG.getNode(ISD::ADD, dl, VT, BaseAddr, ScaledSize); // Value in vgpr. + SDValue ReadFirstLaneID = + DAG.getTargetConstant(Intrinsic::amdgcn_readfirstlane, dl, MVT::i32); + NewSP = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, MVT::i32, ReadFirstLaneID, + NewSP); + } Chain = DAG.getCopyToReg(Chain, dl, SPReg, NewSP); // Output chain - Tmp2 = DAG.getCALLSEQ_END(Chain, 0, 0, SDValue(), dl); + SDValue CallSeqEnd = DAG.getCALLSEQ_END(Chain, 0, 0, SDValue(), dl); - return DAG.getMergeValues({BaseAddr, Tmp2}, dl); -} - -SDValue SITargetLowering::LowerDYNAMIC_STACKALLOC(SDValue Op, - SelectionDAG &DAG) const { - // We only handle constant sizes here to allow non-entry block, static sized - // allocas. A truly dynamic value is more difficult to support because we - // don't know if the size value is uniform or not. If the size isn't uniform, - // we would need to do a wave reduction to get the maximum size to know how - // much to increment the uniform stack pointer. - SDValue Size = Op.getOperand(1); - if (isa<ConstantSDNode>(Size)) - return lowerDYNAMIC_STACKALLOCImpl(Op, DAG); // Use "generic" expansion. 
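The rewritten lowering above handles a truly dynamic (possibly divergent) alloca size by taking the wave-wide maximum via amdgcn_wave_reduce_umax, scaling it by the wavefront size, and bumping the upward-growing stack pointer; constant sizes skip the reduction. A small model of that pointer arithmetic, with hypothetical names (WaveAllocModel is not an LLVM type), assuming byte units and ignoring alignment:

#include <algorithm>
#include <cstdint>
#include <vector>

struct WaveAllocModel {
  uint32_t SP;                // AMDGPU stack grows upwards
  uint32_t WavefrontSizeLog2; // 5 for wave32, 6 for wave64

  // One dynamic alloca: lanes may request different sizes, so the wave
  // reserves max(size) for every lane and hands back the old SP as the base
  // address, mirroring BaseAddr/NewSP in the lowering above.
  uint32_t alloca(const std::vector<uint32_t> &LaneSizes) {
    uint32_t MaxSize = *std::max_element(LaneSizes.begin(), LaneSizes.end());
    uint32_t Base = SP;
    SP += MaxSize << WavefrontSizeLog2; // uniform, wave-size-scaled bump
    return Base;
  }
};

int main() {
  WaveAllocModel M{/*SP=*/0x1000, /*WavefrontSizeLog2=*/5};
  uint32_t Base = M.alloca({16, 64, 32, 8}); // divergent sizes, max = 64
  return (Base == 0x1000 && M.SP == 0x1000 + (64u << 5)) ? 0 : 1;
}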
- - return AMDGPUTargetLowering::LowerDYNAMIC_STACKALLOC(Op, DAG); + return DAG.getMergeValues({BaseAddr, CallSeqEnd}, dl); } SDValue SITargetLowering::LowerSTACKSAVE(SDValue Op, SelectionDAG &DAG) const { @@ -13982,6 +13985,43 @@ SDValue SITargetLowering::tryFoldToMad64_32(SDNode *N, return Accum; } +SDValue +SITargetLowering::foldAddSub64WithZeroLowBitsTo32(SDNode *N, + DAGCombinerInfo &DCI) const { + SDValue RHS = N->getOperand(1); + auto *CRHS = dyn_cast<ConstantSDNode>(RHS); + if (!CRHS) + return SDValue(); + + // TODO: Worth using computeKnownBits? Maybe expensive since it's so + // common. + uint64_t Val = CRHS->getZExtValue(); + if (countr_zero(Val) >= 32) { + SelectionDAG &DAG = DCI.DAG; + SDLoc SL(N); + SDValue LHS = N->getOperand(0); + + // Avoid carry machinery if we know the low half of the add does not + // contribute to the final result. + // + // add i64:x, K if computeTrailingZeros(K) >= 32 + // => build_pair (add x.hi, K.hi), x.lo + + // Breaking the 64-bit add here with this strange constant is unlikely + // to interfere with addressing mode patterns. + + SDValue Hi = getHiHalf64(LHS, DAG); + SDValue ConstHi32 = DAG.getConstant(Hi_32(Val), SL, MVT::i32); + SDValue AddHi = + DAG.getNode(N->getOpcode(), SL, MVT::i32, Hi, ConstHi32, N->getFlags()); + + SDValue Lo = DAG.getNode(ISD::TRUNCATE, SL, MVT::i32, LHS); + return DAG.getNode(ISD::BUILD_PAIR, SL, MVT::i64, Lo, AddHi); + } + + return SDValue(); +} + // Collect the ultimate src of each of the mul node's operands, and confirm // each operand is 8 bytes. static std::optional<ByteProvider<SDValue>> @@ -14258,6 +14298,11 @@ SDValue SITargetLowering::performAddCombine(SDNode *N, return V; } + if (VT == MVT::i64) { + if (SDValue Folded = foldAddSub64WithZeroLowBitsTo32(N, DCI)) + return Folded; + } + if ((isMul(LHS) || isMul(RHS)) && Subtarget->hasDot7Insts() && (Subtarget->hasDot1Insts() || Subtarget->hasDot8Insts())) { SDValue TempNode(N, 0); @@ -14443,6 +14488,11 @@ SDValue SITargetLowering::performSubCombine(SDNode *N, SelectionDAG &DAG = DCI.DAG; EVT VT = N->getValueType(0); + if (VT == MVT::i64) { + if (SDValue Folded = foldAddSub64WithZeroLowBitsTo32(N, DCI)) + return Folded; + } + if (VT != MVT::i32) return SDValue(); diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.h b/llvm/lib/Target/AMDGPU/SIISelLowering.h index 631f265..299c8f5 100644 --- a/llvm/lib/Target/AMDGPU/SIISelLowering.h +++ b/llvm/lib/Target/AMDGPU/SIISelLowering.h @@ -212,6 +212,9 @@ private: unsigned getFusedOpcode(const SelectionDAG &DAG, const SDNode *N0, const SDNode *N1) const; SDValue tryFoldToMad64_32(SDNode *N, DAGCombinerInfo &DCI) const; + SDValue foldAddSub64WithZeroLowBitsTo32(SDNode *N, + DAGCombinerInfo &DCI) const; + SDValue performAddCombine(SDNode *N, DAGCombinerInfo &DCI) const; SDValue performAddCarrySubCarryCombine(SDNode *N, DAGCombinerInfo &DCI) const; SDValue performSubCombine(SDNode *N, DAGCombinerInfo &DCI) const; @@ -421,7 +424,6 @@ public: SDValue LowerCall(CallLoweringInfo &CLI, SmallVectorImpl<SDValue> &InVals) const override; - SDValue lowerDYNAMIC_STACKALLOCImpl(SDValue Op, SelectionDAG &DAG) const; SDValue LowerDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const; SDValue LowerSTACKSAVE(SDValue Op, SelectionDAG &DAG) const; SDValue lowerGET_ROUNDING(SDValue Op, SelectionDAG &DAG) const; diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp b/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp index f97ea40..e6f333f 100644 --- a/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp +++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp @@ -3805,6 +3805,36 
@@ static void updateLiveVariables(LiveVariables *LV, MachineInstr &MI, } } +static unsigned getNewFMAInst(const GCNSubtarget &ST, unsigned Opc) { + switch (Opc) { + case AMDGPU::V_MAC_F16_e32: + case AMDGPU::V_MAC_F16_e64: + return AMDGPU::V_MAD_F16_e64; + case AMDGPU::V_MAC_F32_e32: + case AMDGPU::V_MAC_F32_e64: + return AMDGPU::V_MAD_F32_e64; + case AMDGPU::V_MAC_LEGACY_F32_e32: + case AMDGPU::V_MAC_LEGACY_F32_e64: + return AMDGPU::V_MAD_LEGACY_F32_e64; + case AMDGPU::V_FMAC_LEGACY_F32_e32: + case AMDGPU::V_FMAC_LEGACY_F32_e64: + return AMDGPU::V_FMA_LEGACY_F32_e64; + case AMDGPU::V_FMAC_F16_e32: + case AMDGPU::V_FMAC_F16_e64: + case AMDGPU::V_FMAC_F16_fake16_e64: + return ST.hasTrue16BitInsts() ? AMDGPU::V_FMA_F16_gfx9_fake16_e64 + : AMDGPU::V_FMA_F16_gfx9_e64; + case AMDGPU::V_FMAC_F32_e32: + case AMDGPU::V_FMAC_F32_e64: + return AMDGPU::V_FMA_F32_e64; + case AMDGPU::V_FMAC_F64_e32: + case AMDGPU::V_FMAC_F64_e64: + return AMDGPU::V_FMA_F64_e64; + default: + llvm_unreachable("invalid instruction"); + } +} + MachineInstr *SIInstrInfo::convertToThreeAddress(MachineInstr &MI, LiveVariables *LV, LiveIntervals *LIS) const { @@ -4040,14 +4070,8 @@ MachineInstr *SIInstrInfo::convertToThreeAddress(MachineInstr &MI, if (Src0Literal && !ST.hasVOP3Literal()) return nullptr; - unsigned NewOpc = IsFMA ? IsF16 ? AMDGPU::V_FMA_F16_gfx9_e64 - : IsF64 ? AMDGPU::V_FMA_F64_e64 - : IsLegacy - ? AMDGPU::V_FMA_LEGACY_F32_e64 - : AMDGPU::V_FMA_F32_e64 - : IsF16 ? AMDGPU::V_MAD_F16_e64 - : IsLegacy ? AMDGPU::V_MAD_LEGACY_F32_e64 - : AMDGPU::V_MAD_F32_e64; + unsigned NewOpc = getNewFMAInst(ST, Opc); + if (pseudoToMCOpcode(NewOpc) == -1) return nullptr; @@ -6866,9 +6890,8 @@ SIInstrInfo::legalizeOperands(MachineInstr &MI, AMDGPU::getNamedOperandIdx(MI.getOpcode(), AMDGPU::OpName::srsrc); if (RsrcIdx != -1) { MachineOperand *Rsrc = &MI.getOperand(RsrcIdx); - if (Rsrc->isReg() && !RI.isSGPRClass(MRI.getRegClass(Rsrc->getReg()))) { + if (Rsrc->isReg() && !RI.isSGPRReg(MRI, Rsrc->getReg())) isRsrcLegal = false; - } } // The operands are legal. @@ -9294,6 +9317,7 @@ static bool isRenamedInGFX9(int Opcode) { case AMDGPU::V_DIV_FIXUP_F16_gfx9_e64: case AMDGPU::V_DIV_FIXUP_F16_gfx9_fake16_e64: case AMDGPU::V_FMA_F16_gfx9_e64: + case AMDGPU::V_FMA_F16_gfx9_fake16_e64: case AMDGPU::V_INTERP_P2_F16: case AMDGPU::V_MAD_F16_e64: case AMDGPU::V_MAD_U16_e64: diff --git a/llvm/lib/Target/AMDGPU/SIInstructions.td b/llvm/lib/Target/AMDGPU/SIInstructions.td index 789ce88..ee83dff 100644 --- a/llvm/lib/Target/AMDGPU/SIInstructions.td +++ b/llvm/lib/Target/AMDGPU/SIInstructions.td @@ -2674,8 +2674,8 @@ let OtherPredicates = [NotHasTrue16BitInsts] in { } // end OtherPredicates = [NotHasTrue16BitInsts] let OtherPredicates = [HasTrue16BitInsts] in { - def : FPToI1Pat<V_CMP_EQ_F16_t16_e64, CONST.FP16_ONE, i16, f16, fp_to_uint>; - def : FPToI1Pat<V_CMP_EQ_F16_t16_e64, CONST.FP16_NEG_ONE, i16, f16, fp_to_sint>; + def : FPToI1Pat<V_CMP_EQ_F16_fake16_e64, CONST.FP16_ONE, i16, f16, fp_to_uint>; + def : FPToI1Pat<V_CMP_EQ_F16_fake16_e64, CONST.FP16_NEG_ONE, i16, f16, fp_to_sint>; } // end OtherPredicates = [HasTrue16BitInsts] def : FPToI1Pat<V_CMP_EQ_F32_e64, CONST.FP32_ONE, i32, f32, fp_to_uint>; @@ -3055,7 +3055,7 @@ def : GCNPat< (V_BFREV_B32_e64 (i32 (EXTRACT_SUBREG VReg_64:$a, sub0))), sub1)>; // If fcanonicalize's operand is implicitly canonicalized, we only need a copy. 
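The new getNewFMAInst helper above centralizes the mapping from two-address MAC/FMAC opcodes to their three-address MAD/FMA forms used by convertToThreeAddress (selecting the gfx9 fake16 FMA variant when True16 instructions are available). As a scalar illustration of why the two forms differ, a hedged sketch using std::fma; fmac and fma3 are illustrative names, not LLVM or ISA entry points:

#include <cmath>

// Two-address form: the destination doubles as the accumulator, so the old
// value of d is an implicit source operand (tied def/use).
static void fmac(float &d, float a, float b) { d = std::fma(a, b, d); }

// Three-address form produced by convertToThreeAddress: the accumulator is an
// explicit, independent operand, which removes the tied-operand constraint.
static float fma3(float a, float b, float c) { return std::fma(a, b, c); }

int main() {
  float d = 4.0f, saved = d;
  fmac(d, 2.0f, 3.0f);               // d = 2*3 + 4 = 10
  float e = fma3(2.0f, 3.0f, saved); // same result, accumulator passed in
  return (d == e) ? 0 : 1;
}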
-let AddedComplexity = 1000 in { +let AddedComplexity = 8 in { foreach vt = [f16, v2f16, f32, v2f32, f64] in { def : GCNPat< (fcanonicalize (vt is_canonicalized:$src)), @@ -3710,12 +3710,15 @@ def : IntMinMaxPat<V_MAXMIN_U32_e64, umin, umax_oneuse>; def : IntMinMaxPat<V_MINMAX_U32_e64, umax, umin_oneuse>; def : FPMinMaxPat<V_MINMAX_F32_e64, f32, fmaxnum_like, fminnum_like_oneuse>; def : FPMinMaxPat<V_MAXMIN_F32_e64, f32, fminnum_like, fmaxnum_like_oneuse>; -def : FPMinMaxPat<V_MINMAX_F16_e64, f16, fmaxnum_like, fminnum_like_oneuse>; -def : FPMinMaxPat<V_MAXMIN_F16_e64, f16, fminnum_like, fmaxnum_like_oneuse>; def : FPMinCanonMaxPat<V_MINMAX_F32_e64, f32, fmaxnum_like, fminnum_like_oneuse>; def : FPMinCanonMaxPat<V_MAXMIN_F32_e64, f32, fminnum_like, fmaxnum_like_oneuse>; -def : FPMinCanonMaxPat<V_MINMAX_F16_e64, f16, fmaxnum_like, fminnum_like_oneuse>; -def : FPMinCanonMaxPat<V_MAXMIN_F16_e64, f16, fminnum_like, fmaxnum_like_oneuse>; +} + +let True16Predicate = UseFakeTrue16Insts in { +def : FPMinMaxPat<V_MINMAX_F16_fake16_e64, f16, fmaxnum_like, fminnum_like_oneuse>; +def : FPMinMaxPat<V_MAXMIN_F16_fake16_e64, f16, fminnum_like, fmaxnum_like_oneuse>; +def : FPMinCanonMaxPat<V_MINMAX_F16_fake16_e64, f16, fmaxnum_like, fminnum_like_oneuse>; +def : FPMinCanonMaxPat<V_MAXMIN_F16_fake16_e64, f16, fminnum_like, fmaxnum_like_oneuse>; } let SubtargetPredicate = isGFX9Plus in { @@ -3723,6 +3726,10 @@ let True16Predicate = NotHasTrue16BitInsts in { defm : Int16Med3Pat<V_MED3_I16_e64, smin, smax, VSrc_b16>; defm : Int16Med3Pat<V_MED3_U16_e64, umin, umax, VSrc_b16>; } +let True16Predicate = UseRealTrue16Insts in { + defm : Int16Med3Pat<V_MED3_I16_t16_e64, smin, smax, VSrcT_b16>; + defm : Int16Med3Pat<V_MED3_U16_t16_e64, umin, umax, VSrcT_b16>; +} let True16Predicate = UseFakeTrue16Insts in { defm : Int16Med3Pat<V_MED3_I16_fake16_e64, smin, smax, VSrc_b16>; defm : Int16Med3Pat<V_MED3_U16_fake16_e64, umin, umax, VSrc_b16>; diff --git a/llvm/lib/Target/AMDGPU/SIRegisterInfo.td b/llvm/lib/Target/AMDGPU/SIRegisterInfo.td index 16a7a9c..7c98ccd 100644 --- a/llvm/lib/Target/AMDGPU/SIRegisterInfo.td +++ b/llvm/lib/Target/AMDGPU/SIRegisterInfo.td @@ -153,14 +153,16 @@ class SIRegisterClass <string n, list<ValueType> rTypes, int Align, dag rList> } multiclass SIRegLoHi16 <string n, bits<8> regIdx, bit ArtificialHigh = 1, - bit isVGPR = 0, bit isAGPR = 0> { + bit isVGPR = 0, bit isAGPR = 0, + list<int> DwarfEncodings = [-1, -1]> { def _LO16 : SIReg<n#".l", regIdx, isVGPR, isAGPR>; def _HI16 : SIReg<!if(ArtificialHigh, "", n#".h"), regIdx, isVGPR, isAGPR, /* isHi16 */ 1> { let isArtificial = ArtificialHigh; } def "" : RegisterWithSubRegs<n, [!cast<Register>(NAME#"_LO16"), - !cast<Register>(NAME#"_HI16")]> { + !cast<Register>(NAME#"_HI16")]>, + DwarfRegNum<DwarfEncodings> { let Namespace = "AMDGPU"; let SubRegIndices = [lo16, hi16]; let CoveredBySubRegs = !not(ArtificialHigh); @@ -197,7 +199,8 @@ def VCC : RegisterWithSubRegs<"vcc", [VCC_LO, VCC_HI]> { let HWEncoding = VCC_LO.HWEncoding; } -defm EXEC_LO : SIRegLoHi16<"exec_lo", 126>, DwarfRegNum<[1, 1]>; +defm EXEC_LO : SIRegLoHi16<"exec_lo", 126, /*ArtificialHigh=*/1, /*isVGPR=*/0, + /*isAGPR=*/0, /*DwarfEncodings=*/[1, 1]>; defm EXEC_HI : SIRegLoHi16<"exec_hi", 127>; def EXEC : RegisterWithSubRegs<"exec", [EXEC_LO, EXEC_HI]>, DwarfRegNum<[17, 1]> { @@ -337,25 +340,26 @@ def FLAT_SCR : FlatReg<FLAT_SCR_LO, FLAT_SCR_HI, 0>; // SGPR registers foreach Index = 0...105 in { defm SGPR#Index : - SIRegLoHi16 <"s"#Index, Index>, - DwarfRegNum<[!if(!le(Index, 63), 
!add(Index, 32), !add(Index, 1024)), - !if(!le(Index, 63), !add(Index, 32), !add(Index, 1024))]>; + SIRegLoHi16 <"s"#Index, Index, /*ArtificialHigh=*/1, + /*isVGPR=*/0, /*isAGPR=*/0, /*DwarfEncodings=*/ + [!if(!le(Index, 63), !add(Index, 32), !add(Index, 1024)), + !if(!le(Index, 63), !add(Index, 32), !add(Index, 1024))]>; } // VGPR registers foreach Index = 0...255 in { defm VGPR#Index : - SIRegLoHi16 <"v"#Index, Index, /* ArtificialHigh= */ 0, - /* isVGPR= */ 1, /* isAGPR= */ 0>, - DwarfRegNum<[!add(Index, 2560), !add(Index, 1536)]>; + SIRegLoHi16 <"v"#Index, Index, /*ArtificialHigh=*/ 0, + /*isVGPR=*/ 1, /*isAGPR=*/ 0, /*DwarfEncodings=*/ + [!add(Index, 2560), !add(Index, 1536)]>; } // AccVGPR registers foreach Index = 0...255 in { defm AGPR#Index : - SIRegLoHi16 <"a"#Index, Index, /* ArtificialHigh= */ 1, - /* isVGPR= */ 0, /* isAGPR= */ 1>, - DwarfRegNum<[!add(Index, 3072), !add(Index, 2048)]>; + SIRegLoHi16 <"a"#Index, Index, /*ArtificialHigh=*/ 1, + /*isVGPR=*/ 0, /*isAGPR=*/ 1, /*DwarfEncodings=*/ + [!add(Index, 3072), !add(Index, 2048)]>; } //===----------------------------------------------------------------------===// @@ -809,6 +813,9 @@ def SReg_32 : SIRegisterClass<"AMDGPU", [i32, f32, i16, f16, bf16, v2i16, v2f16, let BaseClassOrder = 32; } +def SGPR_NULL128 : SIReg<"null">; +def SGPR_NULL256 : SIReg<"null">; + let GeneratePressureSet = 0 in { def SRegOrLds_32 : SIRegisterClass<"AMDGPU", [i32, f32, i16, f16, bf16, v2i16, v2f16, v2bf16], 32, (add SReg_32, LDS_DIRECT_CLASS)> { @@ -885,6 +892,7 @@ multiclass SRegClass<int numRegs, list<ValueType> regTypes, SIRegisterTuples regList, SIRegisterTuples ttmpList = regList, + bit hasNull = 0, int copyCost = !sra(!add(numRegs, 1), 1)> { defvar hasTTMP = !ne(regList, ttmpList); defvar suffix = !cast<string>(!mul(numRegs, 32)); @@ -901,7 +909,7 @@ multiclass SRegClass<int numRegs, } } - def SReg_ # suffix : + def SReg_ # suffix # !if(hasNull, "_XNULL", ""): SIRegisterClass<"AMDGPU", regTypes, 32, !con((add !cast<RegisterClass>(sgprName)), !if(hasTTMP, @@ -910,15 +918,24 @@ multiclass SRegClass<int numRegs, let isAllocatable = 0; let BaseClassOrder = !mul(numRegs, 32); } + + if hasNull then { + def SReg_ # suffix : + SIRegisterClass<"AMDGPU", regTypes, 32, + (add !cast<RegisterClass>("SReg_" # suffix # "_XNULL"), !cast<Register>("SGPR_NULL" # suffix))> { + let isAllocatable = 0; + let BaseClassOrder = !mul(numRegs, 32); + } + } } } defm "" : SRegClass<3, Reg96Types.types, SGPR_96Regs, TTMP_96Regs>; -defm "" : SRegClass<4, Reg128Types.types, SGPR_128Regs, TTMP_128Regs>; +defm "" : SRegClass<4, Reg128Types.types, SGPR_128Regs, TTMP_128Regs, /*hasNull*/ true>; defm "" : SRegClass<5, [v5i32, v5f32], SGPR_160Regs, TTMP_160Regs>; defm "" : SRegClass<6, [v6i32, v6f32, v3i64, v3f64], SGPR_192Regs, TTMP_192Regs>; defm "" : SRegClass<7, [v7i32, v7f32], SGPR_224Regs, TTMP_224Regs>; -defm "" : SRegClass<8, [v8i32, v8f32, v4i64, v4f64, v16i16, v16f16, v16bf16], SGPR_256Regs, TTMP_256Regs>; +defm "" : SRegClass<8, [v8i32, v8f32, v4i64, v4f64, v16i16, v16f16, v16bf16], SGPR_256Regs, TTMP_256Regs, /*hasNull*/ true>; defm "" : SRegClass<9, [v9i32, v9f32], SGPR_288Regs, TTMP_288Regs>; defm "" : SRegClass<10, [v10i32, v10f32], SGPR_320Regs, TTMP_320Regs>; defm "" : SRegClass<11, [v11i32, v11f32], SGPR_352Regs, TTMP_352Regs>; diff --git a/llvm/lib/Target/AMDGPU/SIShrinkInstructions.cpp b/llvm/lib/Target/AMDGPU/SIShrinkInstructions.cpp index 42df457..979812e 100644 --- a/llvm/lib/Target/AMDGPU/SIShrinkInstructions.cpp +++ 
b/llvm/lib/Target/AMDGPU/SIShrinkInstructions.cpp @@ -455,6 +455,7 @@ void SIShrinkInstructions::shrinkMadFma(MachineInstr &MI) const { break; case AMDGPU::V_FMA_F16_e64: case AMDGPU::V_FMA_F16_gfx9_e64: + case AMDGPU::V_FMA_F16_gfx9_fake16_e64: NewOpcode = ST->hasTrue16BitInsts() ? AMDGPU::V_FMAAK_F16_fake16 : AMDGPU::V_FMAAK_F16; break; @@ -484,6 +485,7 @@ void SIShrinkInstructions::shrinkMadFma(MachineInstr &MI) const { break; case AMDGPU::V_FMA_F16_e64: case AMDGPU::V_FMA_F16_gfx9_e64: + case AMDGPU::V_FMA_F16_gfx9_fake16_e64: NewOpcode = ST->hasTrue16BitInsts() ? AMDGPU::V_FMAMK_F16_fake16 : AMDGPU::V_FMAMK_F16; break; @@ -956,7 +958,8 @@ bool SIShrinkInstructions::run(MachineFunction &MF) { MI.getOpcode() == AMDGPU::V_FMA_F32_e64 || MI.getOpcode() == AMDGPU::V_MAD_F16_e64 || MI.getOpcode() == AMDGPU::V_FMA_F16_e64 || - MI.getOpcode() == AMDGPU::V_FMA_F16_gfx9_e64) { + MI.getOpcode() == AMDGPU::V_FMA_F16_gfx9_e64 || + MI.getOpcode() == AMDGPU::V_FMA_F16_gfx9_fake16_e64) { shrinkMadFma(MI); continue; } diff --git a/llvm/lib/Target/AMDGPU/SMInstructions.td b/llvm/lib/Target/AMDGPU/SMInstructions.td index 1aeb4e8..37dcc100 100644 --- a/llvm/lib/Target/AMDGPU/SMInstructions.td +++ b/llvm/lib/Target/AMDGPU/SMInstructions.td @@ -332,19 +332,19 @@ defm S_LOAD_I16 : SM_Pseudo_Loads <SReg_64, SReg_32_XM0_XEXEC>; defm S_LOAD_U16 : SM_Pseudo_Loads <SReg_64, SReg_32_XM0_XEXEC>; let is_buffer = 1 in { -defm S_BUFFER_LOAD_DWORD : SM_Pseudo_Loads <SReg_128, SReg_32_XM0_XEXEC>; +defm S_BUFFER_LOAD_DWORD : SM_Pseudo_Loads <SReg_128_XNULL, SReg_32_XM0_XEXEC>; // FIXME: exec_lo/exec_hi appear to be allowed for SMRD loads on // SI/CI, bit disallowed for SMEM on VI. -defm S_BUFFER_LOAD_DWORDX2 : SM_Pseudo_Loads <SReg_128, SReg_64_XEXEC>; +defm S_BUFFER_LOAD_DWORDX2 : SM_Pseudo_Loads <SReg_128_XNULL, SReg_64_XEXEC>; let SubtargetPredicate = HasScalarDwordx3Loads in - defm S_BUFFER_LOAD_DWORDX3 : SM_Pseudo_Loads <SReg_128, SReg_96>; -defm S_BUFFER_LOAD_DWORDX4 : SM_Pseudo_Loads <SReg_128, SReg_128>; -defm S_BUFFER_LOAD_DWORDX8 : SM_Pseudo_Loads <SReg_128, SReg_256>; -defm S_BUFFER_LOAD_DWORDX16 : SM_Pseudo_Loads <SReg_128, SReg_512>; -defm S_BUFFER_LOAD_I8 : SM_Pseudo_Loads <SReg_128, SReg_32_XM0_XEXEC>; -defm S_BUFFER_LOAD_U8 : SM_Pseudo_Loads <SReg_128, SReg_32_XM0_XEXEC>; -defm S_BUFFER_LOAD_I16 : SM_Pseudo_Loads <SReg_128, SReg_32_XM0_XEXEC>; -defm S_BUFFER_LOAD_U16 : SM_Pseudo_Loads <SReg_128, SReg_32_XM0_XEXEC>; + defm S_BUFFER_LOAD_DWORDX3 : SM_Pseudo_Loads <SReg_128_XNULL, SReg_96>; +defm S_BUFFER_LOAD_DWORDX4 : SM_Pseudo_Loads <SReg_128_XNULL, SReg_128>; +defm S_BUFFER_LOAD_DWORDX8 : SM_Pseudo_Loads <SReg_128_XNULL, SReg_256>; +defm S_BUFFER_LOAD_DWORDX16 : SM_Pseudo_Loads <SReg_128_XNULL, SReg_512>; +defm S_BUFFER_LOAD_I8 : SM_Pseudo_Loads <SReg_128_XNULL, SReg_32_XM0_XEXEC>; +defm S_BUFFER_LOAD_U8 : SM_Pseudo_Loads <SReg_128_XNULL, SReg_32_XM0_XEXEC>; +defm S_BUFFER_LOAD_I16 : SM_Pseudo_Loads <SReg_128_XNULL, SReg_32_XM0_XEXEC>; +defm S_BUFFER_LOAD_U16 : SM_Pseudo_Loads <SReg_128_XNULL, SReg_32_XM0_XEXEC>; } let SubtargetPredicate = HasScalarStores in { @@ -353,9 +353,9 @@ defm S_STORE_DWORDX2 : SM_Pseudo_Stores <SReg_64, SReg_64_XEXEC>; defm S_STORE_DWORDX4 : SM_Pseudo_Stores <SReg_64, SReg_128>; let is_buffer = 1 in { -defm S_BUFFER_STORE_DWORD : SM_Pseudo_Stores <SReg_128, SReg_32_XM0_XEXEC>; -defm S_BUFFER_STORE_DWORDX2 : SM_Pseudo_Stores <SReg_128, SReg_64_XEXEC>; -defm S_BUFFER_STORE_DWORDX4 : SM_Pseudo_Stores <SReg_128, SReg_128>; +defm S_BUFFER_STORE_DWORD : SM_Pseudo_Stores 
<SReg_128_XNULL, SReg_32_XM0_XEXEC>; +defm S_BUFFER_STORE_DWORDX2 : SM_Pseudo_Stores <SReg_128_XNULL, SReg_64_XEXEC>; +defm S_BUFFER_STORE_DWORDX4 : SM_Pseudo_Stores <SReg_128_XNULL, SReg_128>; } } // End SubtargetPredicate = HasScalarStores @@ -375,7 +375,7 @@ def S_DCACHE_WB_VOL : SM_Inval_Pseudo <"s_dcache_wb_vol", int_amdgcn_s_dcache_wb defm S_ATC_PROBE : SM_Pseudo_Probe <SReg_64>; let is_buffer = 1 in { -defm S_ATC_PROBE_BUFFER : SM_Pseudo_Probe <SReg_128>; +defm S_ATC_PROBE_BUFFER : SM_Pseudo_Probe <SReg_128_XNULL>; } } // SubtargetPredicate = isGFX8Plus @@ -401,33 +401,33 @@ defm S_SCRATCH_STORE_DWORDX4 : SM_Pseudo_Stores <SReg_64, SReg_128>; let SubtargetPredicate = HasScalarAtomics in { let is_buffer = 1 in { -defm S_BUFFER_ATOMIC_SWAP : SM_Pseudo_Atomics <SReg_128, SReg_32_XM0_XEXEC>; -defm S_BUFFER_ATOMIC_CMPSWAP : SM_Pseudo_Atomics <SReg_128, SReg_64_XEXEC>; -defm S_BUFFER_ATOMIC_ADD : SM_Pseudo_Atomics <SReg_128, SReg_32_XM0_XEXEC>; -defm S_BUFFER_ATOMIC_SUB : SM_Pseudo_Atomics <SReg_128, SReg_32_XM0_XEXEC>; -defm S_BUFFER_ATOMIC_SMIN : SM_Pseudo_Atomics <SReg_128, SReg_32_XM0_XEXEC>; -defm S_BUFFER_ATOMIC_UMIN : SM_Pseudo_Atomics <SReg_128, SReg_32_XM0_XEXEC>; -defm S_BUFFER_ATOMIC_SMAX : SM_Pseudo_Atomics <SReg_128, SReg_32_XM0_XEXEC>; -defm S_BUFFER_ATOMIC_UMAX : SM_Pseudo_Atomics <SReg_128, SReg_32_XM0_XEXEC>; -defm S_BUFFER_ATOMIC_AND : SM_Pseudo_Atomics <SReg_128, SReg_32_XM0_XEXEC>; -defm S_BUFFER_ATOMIC_OR : SM_Pseudo_Atomics <SReg_128, SReg_32_XM0_XEXEC>; -defm S_BUFFER_ATOMIC_XOR : SM_Pseudo_Atomics <SReg_128, SReg_32_XM0_XEXEC>; -defm S_BUFFER_ATOMIC_INC : SM_Pseudo_Atomics <SReg_128, SReg_32_XM0_XEXEC>; -defm S_BUFFER_ATOMIC_DEC : SM_Pseudo_Atomics <SReg_128, SReg_32_XM0_XEXEC>; - -defm S_BUFFER_ATOMIC_SWAP_X2 : SM_Pseudo_Atomics <SReg_128, SReg_64_XEXEC>; -defm S_BUFFER_ATOMIC_CMPSWAP_X2 : SM_Pseudo_Atomics <SReg_128, SReg_128>; -defm S_BUFFER_ATOMIC_ADD_X2 : SM_Pseudo_Atomics <SReg_128, SReg_64_XEXEC>; -defm S_BUFFER_ATOMIC_SUB_X2 : SM_Pseudo_Atomics <SReg_128, SReg_64_XEXEC>; -defm S_BUFFER_ATOMIC_SMIN_X2 : SM_Pseudo_Atomics <SReg_128, SReg_64_XEXEC>; -defm S_BUFFER_ATOMIC_UMIN_X2 : SM_Pseudo_Atomics <SReg_128, SReg_64_XEXEC>; -defm S_BUFFER_ATOMIC_SMAX_X2 : SM_Pseudo_Atomics <SReg_128, SReg_64_XEXEC>; -defm S_BUFFER_ATOMIC_UMAX_X2 : SM_Pseudo_Atomics <SReg_128, SReg_64_XEXEC>; -defm S_BUFFER_ATOMIC_AND_X2 : SM_Pseudo_Atomics <SReg_128, SReg_64_XEXEC>; -defm S_BUFFER_ATOMIC_OR_X2 : SM_Pseudo_Atomics <SReg_128, SReg_64_XEXEC>; -defm S_BUFFER_ATOMIC_XOR_X2 : SM_Pseudo_Atomics <SReg_128, SReg_64_XEXEC>; -defm S_BUFFER_ATOMIC_INC_X2 : SM_Pseudo_Atomics <SReg_128, SReg_64_XEXEC>; -defm S_BUFFER_ATOMIC_DEC_X2 : SM_Pseudo_Atomics <SReg_128, SReg_64_XEXEC>; +defm S_BUFFER_ATOMIC_SWAP : SM_Pseudo_Atomics <SReg_128_XNULL, SReg_32_XM0_XEXEC>; +defm S_BUFFER_ATOMIC_CMPSWAP : SM_Pseudo_Atomics <SReg_128_XNULL, SReg_64_XEXEC>; +defm S_BUFFER_ATOMIC_ADD : SM_Pseudo_Atomics <SReg_128_XNULL, SReg_32_XM0_XEXEC>; +defm S_BUFFER_ATOMIC_SUB : SM_Pseudo_Atomics <SReg_128_XNULL, SReg_32_XM0_XEXEC>; +defm S_BUFFER_ATOMIC_SMIN : SM_Pseudo_Atomics <SReg_128_XNULL, SReg_32_XM0_XEXEC>; +defm S_BUFFER_ATOMIC_UMIN : SM_Pseudo_Atomics <SReg_128_XNULL, SReg_32_XM0_XEXEC>; +defm S_BUFFER_ATOMIC_SMAX : SM_Pseudo_Atomics <SReg_128_XNULL, SReg_32_XM0_XEXEC>; +defm S_BUFFER_ATOMIC_UMAX : SM_Pseudo_Atomics <SReg_128_XNULL, SReg_32_XM0_XEXEC>; +defm S_BUFFER_ATOMIC_AND : SM_Pseudo_Atomics <SReg_128_XNULL, SReg_32_XM0_XEXEC>; +defm S_BUFFER_ATOMIC_OR : SM_Pseudo_Atomics <SReg_128_XNULL, 
SReg_32_XM0_XEXEC>; +defm S_BUFFER_ATOMIC_XOR : SM_Pseudo_Atomics <SReg_128_XNULL, SReg_32_XM0_XEXEC>; +defm S_BUFFER_ATOMIC_INC : SM_Pseudo_Atomics <SReg_128_XNULL, SReg_32_XM0_XEXEC>; +defm S_BUFFER_ATOMIC_DEC : SM_Pseudo_Atomics <SReg_128_XNULL, SReg_32_XM0_XEXEC>; + +defm S_BUFFER_ATOMIC_SWAP_X2 : SM_Pseudo_Atomics <SReg_128_XNULL, SReg_64_XEXEC>; +defm S_BUFFER_ATOMIC_CMPSWAP_X2 : SM_Pseudo_Atomics <SReg_128_XNULL, SReg_128>; +defm S_BUFFER_ATOMIC_ADD_X2 : SM_Pseudo_Atomics <SReg_128_XNULL, SReg_64_XEXEC>; +defm S_BUFFER_ATOMIC_SUB_X2 : SM_Pseudo_Atomics <SReg_128_XNULL, SReg_64_XEXEC>; +defm S_BUFFER_ATOMIC_SMIN_X2 : SM_Pseudo_Atomics <SReg_128_XNULL, SReg_64_XEXEC>; +defm S_BUFFER_ATOMIC_UMIN_X2 : SM_Pseudo_Atomics <SReg_128_XNULL, SReg_64_XEXEC>; +defm S_BUFFER_ATOMIC_SMAX_X2 : SM_Pseudo_Atomics <SReg_128_XNULL, SReg_64_XEXEC>; +defm S_BUFFER_ATOMIC_UMAX_X2 : SM_Pseudo_Atomics <SReg_128_XNULL, SReg_64_XEXEC>; +defm S_BUFFER_ATOMIC_AND_X2 : SM_Pseudo_Atomics <SReg_128_XNULL, SReg_64_XEXEC>; +defm S_BUFFER_ATOMIC_OR_X2 : SM_Pseudo_Atomics <SReg_128_XNULL, SReg_64_XEXEC>; +defm S_BUFFER_ATOMIC_XOR_X2 : SM_Pseudo_Atomics <SReg_128_XNULL, SReg_64_XEXEC>; +defm S_BUFFER_ATOMIC_INC_X2 : SM_Pseudo_Atomics <SReg_128_XNULL, SReg_64_XEXEC>; +defm S_BUFFER_ATOMIC_DEC_X2 : SM_Pseudo_Atomics <SReg_128_XNULL, SReg_64_XEXEC>; } defm S_ATOMIC_SWAP : SM_Pseudo_Atomics <SReg_64, SReg_32_XM0_XEXEC>; @@ -470,7 +470,7 @@ def S_PREFETCH_INST : SM_Prefetch_Pseudo <"s_prefetch_inst", SReg_64, 1>; def S_PREFETCH_INST_PC_REL : SM_Prefetch_Pseudo <"s_prefetch_inst_pc_rel", SReg_64, 0>; def S_PREFETCH_DATA : SM_Prefetch_Pseudo <"s_prefetch_data", SReg_64, 1>; def S_PREFETCH_DATA_PC_REL : SM_Prefetch_Pseudo <"s_prefetch_data_pc_rel", SReg_64, 0>; -def S_BUFFER_PREFETCH_DATA : SM_Prefetch_Pseudo <"s_buffer_prefetch_data", SReg_128, 1> { +def S_BUFFER_PREFETCH_DATA : SM_Prefetch_Pseudo <"s_buffer_prefetch_data", SReg_128_XNULL, 1> { let is_buffer = 1; } } // end let SubtargetPredicate = isGFX12Plus diff --git a/llvm/lib/Target/AMDGPU/VOP1Instructions.td b/llvm/lib/Target/AMDGPU/VOP1Instructions.td index 1dd39be..b9c73e6 100644 --- a/llvm/lib/Target/AMDGPU/VOP1Instructions.td +++ b/llvm/lib/Target/AMDGPU/VOP1Instructions.td @@ -1018,9 +1018,9 @@ defm V_CLS_I32 : VOP1_Real_FULL_with_name_gfx11_gfx12<0x03b, defm V_SWAP_B16 : VOP1Only_Real_gfx11_gfx12<0x066>; defm V_PERMLANE64_B32 : VOP1Only_Real_gfx11_gfx12<0x067>; defm V_MOV_B16_t16 : VOP1_Real_FULL_t16_gfx11_gfx12<0x01c, "v_mov_b16">; -defm V_NOT_B16_fake16 : VOP1_Real_FULL_t16_gfx11_gfx12<0x069, "v_not_b16">; -defm V_CVT_I32_I16_fake16 : VOP1_Real_FULL_t16_gfx11_gfx12<0x06a, "v_cvt_i32_i16">; -defm V_CVT_U32_U16_fake16 : VOP1_Real_FULL_t16_gfx11_gfx12<0x06b, "v_cvt_u32_u16">; +defm V_NOT_B16 : VOP1_Real_FULL_t16_and_fake16_gfx11_gfx12<0x069, "v_not_b16">; +defm V_CVT_I32_I16 : VOP1_Real_FULL_t16_and_fake16_gfx11_gfx12<0x06a, "v_cvt_i32_i16">; +defm V_CVT_U32_U16 : VOP1_Real_FULL_t16_and_fake16_gfx11_gfx12<0x06b, "v_cvt_u32_u16">; defm V_CVT_F16_U16 : VOP1_Real_FULL_t16_and_fake16_gfx11_gfx12<0x050, "v_cvt_f16_u16">; defm V_CVT_F16_I16 : VOP1_Real_FULL_t16_and_fake16_gfx11_gfx12<0x051, "v_cvt_f16_i16">; @@ -1036,18 +1036,18 @@ defm V_LOG_F16_t16 : VOP1_Real_FULL_t16_gfx11_gfx12<0x057, "v_log_f16" defm V_LOG_F16_fake16 : VOP1_Real_FULL_t16_gfx11_gfx12<0x057, "v_log_f16">; defm V_EXP_F16_t16 : VOP1_Real_FULL_t16_gfx11_gfx12<0x058, "v_exp_f16">; defm V_EXP_F16_fake16 : VOP1_Real_FULL_t16_gfx11_gfx12<0x058, "v_exp_f16">; -defm V_FREXP_MANT_F16_fake16 : 
VOP1_Real_FULL_t16_gfx11_gfx12<0x059, "v_frexp_mant_f16">; +defm V_FREXP_MANT_F16 : VOP1_Real_FULL_t16_and_fake16_gfx11_gfx12<0x059, "v_frexp_mant_f16">; defm V_FREXP_EXP_I16_F16 : VOP1_Real_FULL_t16_and_fake16_gfx11_gfx12<0x05a, "v_frexp_exp_i16_f16">; defm V_FLOOR_F16_t16 : VOP1_Real_FULL_t16_gfx11_gfx12<0x05b, "v_floor_f16">; defm V_FLOOR_F16_fake16 : VOP1_Real_FULL_t16_gfx11_gfx12<0x05b, "v_floor_f16">; defm V_CEIL_F16_t16 : VOP1_Real_FULL_t16_gfx11_gfx12<0x05c, "v_ceil_f16">; defm V_CEIL_F16_fake16 : VOP1_Real_FULL_t16_gfx11_gfx12<0x05c, "v_ceil_f16">; -defm V_TRUNC_F16_fake16 : VOP1_Real_FULL_t16_gfx11_gfx12<0x05d, "v_trunc_f16">; -defm V_RNDNE_F16_fake16 : VOP1_Real_FULL_t16_gfx11_gfx12<0x05e, "v_rndne_f16">; -defm V_FRACT_F16_fake16 : VOP1_Real_FULL_t16_gfx11_gfx12<0x05f, "v_fract_f16">; -defm V_SIN_F16_fake16 : VOP1_Real_FULL_t16_gfx11_gfx12<0x060, "v_sin_f16">; -defm V_COS_F16_fake16 : VOP1_Real_FULL_t16_gfx11_gfx12<0x061, "v_cos_f16">; -defm V_SAT_PK_U8_I16_fake16 : VOP1_Real_FULL_t16_gfx11_gfx12<0x062, "v_sat_pk_u8_i16">; +defm V_TRUNC_F16 : VOP1_Real_FULL_t16_and_fake16_gfx11_gfx12<0x05d, "v_trunc_f16">; +defm V_RNDNE_F16 : VOP1_Real_FULL_t16_and_fake16_gfx11_gfx12<0x05e, "v_rndne_f16">; +defm V_FRACT_F16 : VOP1_Real_FULL_t16_and_fake16_gfx11_gfx12<0x05f, "v_fract_f16">; +defm V_SIN_F16 : VOP1_Real_FULL_t16_and_fake16_gfx11_gfx12<0x060, "v_sin_f16">; +defm V_COS_F16 : VOP1_Real_FULL_t16_and_fake16_gfx11_gfx12<0x061, "v_cos_f16">; +defm V_SAT_PK_U8_I16 : VOP1_Real_FULL_t16_and_fake16_gfx11_gfx12<0x062, "v_sat_pk_u8_i16">; defm V_CVT_NORM_I16_F16 : VOP1_Real_FULL_t16_and_fake16_gfx11_gfx12<0x063, "v_cvt_norm_i16_f16">; defm V_CVT_NORM_U16_F16 : VOP1_Real_FULL_t16_and_fake16_gfx11_gfx12<0x064, "v_cvt_norm_u16_f16">; diff --git a/llvm/lib/Target/AMDGPU/VOP3Instructions.td b/llvm/lib/Target/AMDGPU/VOP3Instructions.td index 22e4576..24a2eed 100644 --- a/llvm/lib/Target/AMDGPU/VOP3Instructions.td +++ b/llvm/lib/Target/AMDGPU/VOP3Instructions.td @@ -340,7 +340,7 @@ let FPDPRounding = 1 in { let SubtargetPredicate = isGFX9Plus in { defm V_DIV_FIXUP_F16_gfx9 : VOP3Inst_t16 <"v_div_fixup_f16_gfx9", VOP_F16_F16_F16_F16, AMDGPUdiv_fixup>; - defm V_FMA_F16_gfx9 : VOP3Inst <"v_fma_f16_gfx9", VOP3_Profile<VOP_F16_F16_F16_F16, VOP3_OPSEL>, any_fma>; + defm V_FMA_F16_gfx9 : VOP3Inst_t16 <"v_fma_f16_gfx9", VOP_F16_F16_F16_F16, any_fma>; } // End SubtargetPredicate = isGFX9Plus } // End FPDPRounding = 1 @@ -1374,8 +1374,8 @@ class VOP3_DOT_Profile_fake16<VOPProfile P, VOP3Features Features = VOP3_REGULAR let SubtargetPredicate = isGFX11Plus in { defm V_MAXMIN_F32 : VOP3Inst<"v_maxmin_f32", VOP3_Profile<VOP_F32_F32_F32_F32>>; defm V_MINMAX_F32 : VOP3Inst<"v_minmax_f32", VOP3_Profile<VOP_F32_F32_F32_F32>>; - defm V_MAXMIN_F16 : VOP3Inst<"v_maxmin_f16", VOP3_Profile<VOP_F16_F16_F16_F16>>; - defm V_MINMAX_F16 : VOP3Inst<"v_minmax_f16", VOP3_Profile<VOP_F16_F16_F16_F16>>; + defm V_MAXMIN_F16 : VOP3Inst_t16<"v_maxmin_f16", VOP_F16_F16_F16_F16>; + defm V_MINMAX_F16 : VOP3Inst_t16<"v_minmax_f16", VOP_F16_F16_F16_F16>; defm V_MAXMIN_U32 : VOP3Inst<"v_maxmin_u32", VOP3_Profile<VOP_I32_I32_I32_I32>>; defm V_MINMAX_U32 : VOP3Inst<"v_minmax_u32", VOP3_Profile<VOP_I32_I32_I32_I32>>; defm V_MAXMIN_I32 : VOP3Inst<"v_maxmin_i32", VOP3_Profile<VOP_I32_I32_I32_I32>>; @@ -1578,8 +1578,8 @@ def : MinimumMaximumByMinimum3Maximum3<fmaximum, f32, V_MAXIMUM3_F32_e64>; defm V_MIN3_NUM_F32 : VOP3_Realtriple_with_name_gfx12<0x229, "V_MIN3_F32", "v_min3_num_f32">; defm V_MAX3_NUM_F32 : VOP3_Realtriple_with_name_gfx12<0x22a, 
"V_MAX3_F32", "v_max3_num_f32">; -defm V_MIN3_NUM_F16 : VOP3_Realtriple_with_name_gfx12<0x22b, "V_MIN3_F16", "v_min3_num_f16">; -defm V_MAX3_NUM_F16 : VOP3_Realtriple_with_name_gfx12<0x22c, "V_MAX3_F16", "v_max3_num_f16">; +defm V_MIN3_NUM_F16 : VOP3_Realtriple_t16_and_fake16_gfx12<0x22b, "v_min3_num_f16", "V_MIN3_F16", "v_min3_f16">; +defm V_MAX3_NUM_F16 : VOP3_Realtriple_t16_and_fake16_gfx12<0x22c, "v_max3_num_f16", "V_MAX3_F16", "v_max3_f16">; defm V_MINIMUM3_F32 : VOP3Only_Realtriple_gfx12<0x22d>; defm V_MAXIMUM3_F32 : VOP3Only_Realtriple_gfx12<0x22e>; defm V_MINIMUM3_F16 : VOP3Only_Realtriple_t16_gfx12<0x22f>; @@ -1588,8 +1588,8 @@ defm V_MED3_NUM_F32 : VOP3_Realtriple_with_name_gfx12<0x231, "V_MED3_F32", defm V_MED3_NUM_F16 : VOP3_Realtriple_t16_and_fake16_gfx12<0x232, "v_med3_num_f16", "V_MED3_F16", "v_med3_f16">; defm V_MINMAX_NUM_F32 : VOP3_Realtriple_with_name_gfx12<0x268, "V_MINMAX_F32", "v_minmax_num_f32">; defm V_MAXMIN_NUM_F32 : VOP3_Realtriple_with_name_gfx12<0x269, "V_MAXMIN_F32", "v_maxmin_num_f32">; -defm V_MINMAX_NUM_F16 : VOP3_Realtriple_with_name_gfx12<0x26a, "V_MINMAX_F16", "v_minmax_num_f16">; -defm V_MAXMIN_NUM_F16 : VOP3_Realtriple_with_name_gfx12<0x26b, "V_MAXMIN_F16", "v_maxmin_num_f16">; +defm V_MINMAX_NUM_F16 : VOP3_Realtriple_t16_and_fake16_gfx12<0x26a, "v_minmax_num_f16", "V_MINMAX_F16", "v_minmax_f16">; +defm V_MAXMIN_NUM_F16 : VOP3_Realtriple_t16_and_fake16_gfx12<0x26b, "v_maxmin_num_f16", "V_MAXMIN_F16", "v_maxmin_f16">; defm V_MINIMUMMAXIMUM_F32 : VOP3Only_Realtriple_gfx12<0x26c>; defm V_MAXIMUMMINIMUM_F32 : VOP3Only_Realtriple_gfx12<0x26d>; defm V_MINIMUMMAXIMUM_F16 : VOP3Only_Realtriple_t16_gfx12<0x26e>; @@ -1708,7 +1708,7 @@ defm V_PERM_B32 : VOP3_Realtriple_gfx11_gfx12<0x244>; defm V_XAD_U32 : VOP3_Realtriple_gfx11_gfx12<0x245>; defm V_LSHL_ADD_U32 : VOP3_Realtriple_gfx11_gfx12<0x246>; defm V_ADD_LSHL_U32 : VOP3_Realtriple_gfx11_gfx12<0x247>; -defm V_FMA_F16 : VOP3_Realtriple_with_name_gfx11_gfx12<0x248, "V_FMA_F16_gfx9", "v_fma_f16">; +defm V_FMA_F16 : VOP3_Realtriple_t16_and_fake16_gfx11_gfx12<0x248, "v_fma_f16", "V_FMA_F16_gfx9">; defm V_MIN3_F16 : VOP3Only_Realtriple_t16_and_fake16_gfx11<0x249, "v_min3_f16">; defm V_MIN3_I16 : VOP3_Realtriple_t16_and_fake16_gfx11_gfx12<0x24a, "v_min3_i16">; defm V_MIN3_U16 : VOP3_Realtriple_t16_and_fake16_gfx11_gfx12<0x24b, "v_min3_u16">; @@ -1730,8 +1730,8 @@ defm V_PERMLANE16_B32 : VOP3_Real_Base_gfx11_gfx12<0x25b>; defm V_PERMLANEX16_B32 : VOP3_Real_Base_gfx11_gfx12<0x25c>; defm V_MAXMIN_F32 : VOP3_Realtriple_gfx11<0x25e>; defm V_MINMAX_F32 : VOP3_Realtriple_gfx11<0x25f>; -defm V_MAXMIN_F16 : VOP3_Realtriple_gfx11<0x260>; -defm V_MINMAX_F16 : VOP3_Realtriple_gfx11<0x261>; +defm V_MAXMIN_F16 : VOP3_Realtriple_t16_and_fake16_gfx11<0x260, "v_maxmin_f16">; +defm V_MINMAX_F16 : VOP3_Realtriple_t16_and_fake16_gfx11<0x261, "v_minmax_f16">; defm V_MAXMIN_U32 : VOP3_Realtriple_gfx11_gfx12<0x262>; defm V_MINMAX_U32 : VOP3_Realtriple_gfx11_gfx12<0x263>; defm V_MAXMIN_I32 : VOP3_Realtriple_gfx11_gfx12<0x264>; diff --git a/llvm/lib/Target/AMDGPU/VOPCInstructions.td b/llvm/lib/Target/AMDGPU/VOPCInstructions.td index 9bf043e..8589d59 100644 --- a/llvm/lib/Target/AMDGPU/VOPCInstructions.td +++ b/llvm/lib/Target/AMDGPU/VOPCInstructions.td @@ -1130,20 +1130,20 @@ defm : ICMP_Pattern <COND_SGE, V_CMP_GE_I64_e64, i64>; defm : ICMP_Pattern <COND_SLT, V_CMP_LT_I64_e64, i64>; defm : ICMP_Pattern <COND_SLE, V_CMP_LE_I64_e64, i64>; -let OtherPredicates = [HasTrue16BitInsts] in { -defm : ICMP_Pattern <COND_EQ, V_CMP_EQ_U16_t16_e64, 
i16>; -defm : ICMP_Pattern <COND_NE, V_CMP_NE_U16_t16_e64, i16>; -defm : ICMP_Pattern <COND_UGT, V_CMP_GT_U16_t16_e64, i16>; -defm : ICMP_Pattern <COND_UGE, V_CMP_GE_U16_t16_e64, i16>; -defm : ICMP_Pattern <COND_ULT, V_CMP_LT_U16_t16_e64, i16>; -defm : ICMP_Pattern <COND_ULE, V_CMP_LE_U16_t16_e64, i16>; -defm : ICMP_Pattern <COND_SGT, V_CMP_GT_I16_t16_e64, i16>; -defm : ICMP_Pattern <COND_SGE, V_CMP_GE_I16_t16_e64, i16>; -defm : ICMP_Pattern <COND_SLT, V_CMP_LT_I16_t16_e64, i16>; -defm : ICMP_Pattern <COND_SLE, V_CMP_LE_I16_t16_e64, i16>; -} // End OtherPredicates = [HasTrue16BitInsts] - -let OtherPredicates = [NotHasTrue16BitInsts] in { +let True16Predicate = UseFakeTrue16Insts in { +defm : ICMP_Pattern <COND_EQ, V_CMP_EQ_U16_fake16_e64, i16>; +defm : ICMP_Pattern <COND_NE, V_CMP_NE_U16_fake16_e64, i16>; +defm : ICMP_Pattern <COND_UGT, V_CMP_GT_U16_fake16_e64, i16>; +defm : ICMP_Pattern <COND_UGE, V_CMP_GE_U16_fake16_e64, i16>; +defm : ICMP_Pattern <COND_ULT, V_CMP_LT_U16_fake16_e64, i16>; +defm : ICMP_Pattern <COND_ULE, V_CMP_LE_U16_fake16_e64, i16>; +defm : ICMP_Pattern <COND_SGT, V_CMP_GT_I16_fake16_e64, i16>; +defm : ICMP_Pattern <COND_SGE, V_CMP_GE_I16_fake16_e64, i16>; +defm : ICMP_Pattern <COND_SLT, V_CMP_LT_I16_fake16_e64, i16>; +defm : ICMP_Pattern <COND_SLE, V_CMP_LE_I16_fake16_e64, i16>; +} // End True16Predicate = UseFakeTrue16Insts + +let True16Predicate = NotHasTrue16BitInsts in { defm : ICMP_Pattern <COND_EQ, V_CMP_EQ_U16_e64, i16>; defm : ICMP_Pattern <COND_NE, V_CMP_NE_U16_e64, i16>; defm : ICMP_Pattern <COND_UGT, V_CMP_GT_U16_e64, i16>; @@ -1154,7 +1154,7 @@ defm : ICMP_Pattern <COND_SGT, V_CMP_GT_I16_e64, i16>; defm : ICMP_Pattern <COND_SGE, V_CMP_GE_I16_e64, i16>; defm : ICMP_Pattern <COND_SLT, V_CMP_LT_I16_e64, i16>; defm : ICMP_Pattern <COND_SLE, V_CMP_LE_I16_e64, i16>; -} // End OtherPredicates = [NotHasTrue16BitInsts] +} // End True16Predicate = NotHasTrue16BitInsts multiclass FCMP_Pattern <PatFrags cond, Instruction inst, ValueType vt> { let WaveSizePredicate = isWave64 in @@ -1215,25 +1215,25 @@ defm : FCMP_Pattern <COND_UGE, V_CMP_NLT_F64_e64, f64>; defm : FCMP_Pattern <COND_ULT, V_CMP_NGE_F64_e64, f64>; defm : FCMP_Pattern <COND_ULE, V_CMP_NGT_F64_e64, f64>; -let OtherPredicates = [HasTrue16BitInsts] in { -defm : FCMP_Pattern <COND_O, V_CMP_O_F16_t16_e64, f16>; -defm : FCMP_Pattern <COND_UO, V_CMP_U_F16_t16_e64, f16>; -defm : FCMP_Pattern <COND_OEQ, V_CMP_EQ_F16_t16_e64, f16>; -defm : FCMP_Pattern <COND_ONE, V_CMP_NEQ_F16_t16_e64, f16>; -defm : FCMP_Pattern <COND_OGT, V_CMP_GT_F16_t16_e64, f16>; -defm : FCMP_Pattern <COND_OGE, V_CMP_GE_F16_t16_e64, f16>; -defm : FCMP_Pattern <COND_OLT, V_CMP_LT_F16_t16_e64, f16>; -defm : FCMP_Pattern <COND_OLE, V_CMP_LE_F16_t16_e64, f16>; - -defm : FCMP_Pattern <COND_UEQ, V_CMP_NLG_F16_t16_e64, f16>; -defm : FCMP_Pattern <COND_UNE, V_CMP_NEQ_F16_t16_e64, f16>; -defm : FCMP_Pattern <COND_UGT, V_CMP_NLE_F16_t16_e64, f16>; -defm : FCMP_Pattern <COND_UGE, V_CMP_NLT_F16_t16_e64, f16>; -defm : FCMP_Pattern <COND_ULT, V_CMP_NGE_F16_t16_e64, f16>; -defm : FCMP_Pattern <COND_ULE, V_CMP_NGT_F16_t16_e64, f16>; -} // End OtherPredicates = [HasTrue16BitInsts] - -let OtherPredicates = [NotHasTrue16BitInsts] in { +let True16Predicate = UseFakeTrue16Insts in { +defm : FCMP_Pattern <COND_O, V_CMP_O_F16_fake16_e64, f16>; +defm : FCMP_Pattern <COND_UO, V_CMP_U_F16_fake16_e64, f16>; +defm : FCMP_Pattern <COND_OEQ, V_CMP_EQ_F16_fake16_e64, f16>; +defm : FCMP_Pattern <COND_ONE, V_CMP_NEQ_F16_fake16_e64, f16>; +defm : FCMP_Pattern <COND_OGT, 
V_CMP_GT_F16_fake16_e64, f16>; +defm : FCMP_Pattern <COND_OGE, V_CMP_GE_F16_fake16_e64, f16>; +defm : FCMP_Pattern <COND_OLT, V_CMP_LT_F16_fake16_e64, f16>; +defm : FCMP_Pattern <COND_OLE, V_CMP_LE_F16_fake16_e64, f16>; + +defm : FCMP_Pattern <COND_UEQ, V_CMP_NLG_F16_fake16_e64, f16>; +defm : FCMP_Pattern <COND_UNE, V_CMP_NEQ_F16_fake16_e64, f16>; +defm : FCMP_Pattern <COND_UGT, V_CMP_NLE_F16_fake16_e64, f16>; +defm : FCMP_Pattern <COND_UGE, V_CMP_NLT_F16_fake16_e64, f16>; +defm : FCMP_Pattern <COND_ULT, V_CMP_NGE_F16_fake16_e64, f16>; +defm : FCMP_Pattern <COND_ULE, V_CMP_NGT_F16_fake16_e64, f16>; +} // End True16Predicate = UseFakeTrue16Insts + +let True16Predicate = NotHasTrue16BitInsts in { defm : FCMP_Pattern <COND_O, V_CMP_O_F16_e64, f16>; defm : FCMP_Pattern <COND_UO, V_CMP_U_F16_e64, f16>; defm : FCMP_Pattern <COND_OEQ, V_CMP_EQ_F16_e64, f16>; @@ -1249,7 +1249,7 @@ defm : FCMP_Pattern <COND_UGT, V_CMP_NLE_F16_e64, f16>; defm : FCMP_Pattern <COND_UGE, V_CMP_NLT_F16_e64, f16>; defm : FCMP_Pattern <COND_ULT, V_CMP_NGE_F16_e64, f16>; defm : FCMP_Pattern <COND_ULE, V_CMP_NGT_F16_e64, f16>; -} // End OtherPredicates = [NotHasTrue16BitInsts] +} // End True16Predicate = NotHasTrue16BitInsts //===----------------------------------------------------------------------===// // DPP Encodings @@ -1707,23 +1707,6 @@ multiclass VOPCX_Real_t16_gfx11_gfx12<bits<9> op, string asm_name, VOPCX_Real_t16<GFX11Gen, op, asm_name, OpName, pseudo_mnemonic>, VOPCX_Real_t16<GFX12Gen, op, asm_name, OpName, pseudo_mnemonic>; -defm V_CMP_F_F16_t16 : VOPC_Real_t16_gfx11<0x000, "v_cmp_f_f16">; -defm V_CMP_LT_F16_t16 : VOPC_Real_t16_gfx11_gfx12<0x001, "v_cmp_lt_f16">; -defm V_CMP_EQ_F16_t16 : VOPC_Real_t16_gfx11_gfx12<0x002, "v_cmp_eq_f16">; -defm V_CMP_LE_F16_t16 : VOPC_Real_t16_gfx11_gfx12<0x003, "v_cmp_le_f16">; -defm V_CMP_GT_F16_t16 : VOPC_Real_t16_gfx11_gfx12<0x004, "v_cmp_gt_f16">; -defm V_CMP_LG_F16_t16 : VOPC_Real_t16_gfx11_gfx12<0x005, "v_cmp_lg_f16">; -defm V_CMP_GE_F16_t16 : VOPC_Real_t16_gfx11_gfx12<0x006, "v_cmp_ge_f16">; -defm V_CMP_O_F16_t16 : VOPC_Real_t16_gfx11_gfx12<0x007, "v_cmp_o_f16">; -defm V_CMP_U_F16_t16 : VOPC_Real_t16_gfx11_gfx12<0x008, "v_cmp_u_f16">; -defm V_CMP_NGE_F16_t16 : VOPC_Real_t16_gfx11_gfx12<0x009, "v_cmp_nge_f16">; -defm V_CMP_NLG_F16_t16 : VOPC_Real_t16_gfx11_gfx12<0x00a, "v_cmp_nlg_f16">; -defm V_CMP_NGT_F16_t16 : VOPC_Real_t16_gfx11_gfx12<0x00b, "v_cmp_ngt_f16">; -defm V_CMP_NLE_F16_t16 : VOPC_Real_t16_gfx11_gfx12<0x00c, "v_cmp_nle_f16">; -defm V_CMP_NEQ_F16_t16 : VOPC_Real_t16_gfx11_gfx12<0x00d, "v_cmp_neq_f16">; -defm V_CMP_NLT_F16_t16 : VOPC_Real_t16_gfx11_gfx12<0x00e, "v_cmp_nlt_f16">; -defm V_CMP_T_F16_t16 : VOPC_Real_t16_gfx11<0x00f, "v_cmp_t_f16", "V_CMP_TRU_F16_t16", "v_cmp_tru_f16">; - defm V_CMP_F_F16_fake16 : VOPC_Real_t16_gfx11<0x000, "v_cmp_f_f16">; defm V_CMP_LT_F16_fake16 : VOPC_Real_t16_gfx11_gfx12<0x001, "v_cmp_lt_f16">; defm V_CMP_EQ_F16_fake16 : VOPC_Real_t16_gfx11_gfx12<0x002, "v_cmp_eq_f16">; @@ -1759,19 +1742,6 @@ defm V_CMP_NLT_F32 : VOPC_Real_gfx11_gfx12<0x01e>; defm V_CMP_T_F32 : VOPC_Real_with_name_gfx11<0x01f, "V_CMP_TRU_F32", "v_cmp_t_f32">; defm V_CMP_T_F64 : VOPC_Real_with_name_gfx11<0x02f, "V_CMP_TRU_F64", "v_cmp_t_f64">; -defm V_CMP_LT_I16_t16 : VOPC_Real_t16_gfx11_gfx12<0x031, "v_cmp_lt_i16">; -defm V_CMP_EQ_I16_t16 : VOPC_Real_t16_gfx11_gfx12<0x032, "v_cmp_eq_i16">; -defm V_CMP_LE_I16_t16 : VOPC_Real_t16_gfx11_gfx12<0x033, "v_cmp_le_i16">; -defm V_CMP_GT_I16_t16 : VOPC_Real_t16_gfx11_gfx12<0x034, "v_cmp_gt_i16">; -defm V_CMP_NE_I16_t16 : 
VOPC_Real_t16_gfx11_gfx12<0x035, "v_cmp_ne_i16">; -defm V_CMP_GE_I16_t16 : VOPC_Real_t16_gfx11_gfx12<0x036, "v_cmp_ge_i16">; -defm V_CMP_LT_U16_t16 : VOPC_Real_t16_gfx11_gfx12<0x039, "v_cmp_lt_u16">; -defm V_CMP_EQ_U16_t16 : VOPC_Real_t16_gfx11_gfx12<0x03a, "v_cmp_eq_u16">; -defm V_CMP_LE_U16_t16 : VOPC_Real_t16_gfx11_gfx12<0x03b, "v_cmp_le_u16">; -defm V_CMP_GT_U16_t16 : VOPC_Real_t16_gfx11_gfx12<0x03c, "v_cmp_gt_u16">; -defm V_CMP_NE_U16_t16 : VOPC_Real_t16_gfx11_gfx12<0x03d, "v_cmp_ne_u16">; -defm V_CMP_GE_U16_t16 : VOPC_Real_t16_gfx11_gfx12<0x03e, "v_cmp_ge_u16">; - defm V_CMP_LT_I16_fake16 : VOPC_Real_t16_gfx11_gfx12<0x031, "v_cmp_lt_i16">; defm V_CMP_EQ_I16_fake16 : VOPC_Real_t16_gfx11_gfx12<0x032, "v_cmp_eq_i16">; defm V_CMP_LE_I16_fake16 : VOPC_Real_t16_gfx11_gfx12<0x033, "v_cmp_le_i16">; @@ -1819,28 +1789,10 @@ defm V_CMP_NE_U64 : VOPC_Real_gfx11_gfx12<0x05d>; defm V_CMP_GE_U64 : VOPC_Real_gfx11_gfx12<0x05e>; defm V_CMP_T_U64 : VOPC_Real_gfx11<0x05f>; -defm V_CMP_CLASS_F16_t16 : VOPC_Real_t16_gfx11_gfx12<0x07d, "v_cmp_class_f16">; defm V_CMP_CLASS_F16_fake16 : VOPC_Real_t16_gfx11_gfx12<0x07d, "v_cmp_class_f16">; defm V_CMP_CLASS_F32 : VOPC_Real_gfx11_gfx12<0x07e>; defm V_CMP_CLASS_F64 : VOPC_Real_gfx11_gfx12<0x07f>; -defm V_CMPX_F_F16_t16 : VOPCX_Real_t16_gfx11<0x080, "v_cmpx_f_f16">; -defm V_CMPX_LT_F16_t16 : VOPCX_Real_t16_gfx11_gfx12<0x081, "v_cmpx_lt_f16">; -defm V_CMPX_EQ_F16_t16 : VOPCX_Real_t16_gfx11_gfx12<0x082, "v_cmpx_eq_f16">; -defm V_CMPX_LE_F16_t16 : VOPCX_Real_t16_gfx11_gfx12<0x083, "v_cmpx_le_f16">; -defm V_CMPX_GT_F16_t16 : VOPCX_Real_t16_gfx11_gfx12<0x084, "v_cmpx_gt_f16">; -defm V_CMPX_LG_F16_t16 : VOPCX_Real_t16_gfx11_gfx12<0x085, "v_cmpx_lg_f16">; -defm V_CMPX_GE_F16_t16 : VOPCX_Real_t16_gfx11_gfx12<0x086, "v_cmpx_ge_f16">; -defm V_CMPX_O_F16_t16 : VOPCX_Real_t16_gfx11_gfx12<0x087, "v_cmpx_o_f16">; -defm V_CMPX_U_F16_t16 : VOPCX_Real_t16_gfx11_gfx12<0x088, "v_cmpx_u_f16">; -defm V_CMPX_NGE_F16_t16 : VOPCX_Real_t16_gfx11_gfx12<0x089, "v_cmpx_nge_f16">; -defm V_CMPX_NLG_F16_t16 : VOPCX_Real_t16_gfx11_gfx12<0x08a, "v_cmpx_nlg_f16">; -defm V_CMPX_NGT_F16_t16 : VOPCX_Real_t16_gfx11_gfx12<0x08b, "v_cmpx_ngt_f16">; -defm V_CMPX_NLE_F16_t16 : VOPCX_Real_t16_gfx11_gfx12<0x08c, "v_cmpx_nle_f16">; -defm V_CMPX_NEQ_F16_t16 : VOPCX_Real_t16_gfx11_gfx12<0x08d, "v_cmpx_neq_f16">; -defm V_CMPX_NLT_F16_t16 : VOPCX_Real_t16_gfx11_gfx12<0x08e, "v_cmpx_nlt_f16">; -defm V_CMPX_T_F16_t16 : VOPCX_Real_with_name_gfx11<0x08f, "V_CMPX_TRU_F16_t16", "v_cmpx_t_f16", "v_cmpx_tru_f16">; - defm V_CMPX_F_F16_fake16 : VOPCX_Real_t16_gfx11<0x080, "v_cmpx_f_f16">; defm V_CMPX_LT_F16_fake16 : VOPCX_Real_t16_gfx11_gfx12<0x081, "v_cmpx_lt_f16">; defm V_CMPX_EQ_F16_fake16 : VOPCX_Real_t16_gfx11_gfx12<0x082, "v_cmpx_eq_f16">; @@ -1892,19 +1844,6 @@ defm V_CMPX_NEQ_F64 : VOPCX_Real_gfx11_gfx12<0x0ad>; defm V_CMPX_NLT_F64 : VOPCX_Real_gfx11_gfx12<0x0ae>; defm V_CMPX_T_F64 : VOPCX_Real_with_name_gfx11<0x0af, "V_CMPX_TRU_F64", "v_cmpx_t_f64">; -defm V_CMPX_LT_I16_t16 : VOPCX_Real_t16_gfx11_gfx12<0x0b1, "v_cmpx_lt_i16">; -defm V_CMPX_EQ_I16_t16 : VOPCX_Real_t16_gfx11_gfx12<0x0b2, "v_cmpx_eq_i16">; -defm V_CMPX_LE_I16_t16 : VOPCX_Real_t16_gfx11_gfx12<0x0b3, "v_cmpx_le_i16">; -defm V_CMPX_GT_I16_t16 : VOPCX_Real_t16_gfx11_gfx12<0x0b4, "v_cmpx_gt_i16">; -defm V_CMPX_NE_I16_t16 : VOPCX_Real_t16_gfx11_gfx12<0x0b5, "v_cmpx_ne_i16">; -defm V_CMPX_GE_I16_t16 : VOPCX_Real_t16_gfx11_gfx12<0x0b6, "v_cmpx_ge_i16">; -defm V_CMPX_LT_U16_t16 : VOPCX_Real_t16_gfx11_gfx12<0x0b9, "v_cmpx_lt_u16">; -defm V_CMPX_EQ_U16_t16 : 
VOPCX_Real_t16_gfx11_gfx12<0x0ba, "v_cmpx_eq_u16">; -defm V_CMPX_LE_U16_t16 : VOPCX_Real_t16_gfx11_gfx12<0x0bb, "v_cmpx_le_u16">; -defm V_CMPX_GT_U16_t16 : VOPCX_Real_t16_gfx11_gfx12<0x0bc, "v_cmpx_gt_u16">; -defm V_CMPX_NE_U16_t16 : VOPCX_Real_t16_gfx11_gfx12<0x0bd, "v_cmpx_ne_u16">; -defm V_CMPX_GE_U16_t16 : VOPCX_Real_t16_gfx11_gfx12<0x0be, "v_cmpx_ge_u16">; - defm V_CMPX_LT_I16_fake16 : VOPCX_Real_t16_gfx11_gfx12<0x0b1, "v_cmpx_lt_i16">; defm V_CMPX_EQ_I16_fake16 : VOPCX_Real_t16_gfx11_gfx12<0x0b2, "v_cmpx_eq_i16">; defm V_CMPX_LE_I16_fake16 : VOPCX_Real_t16_gfx11_gfx12<0x0b3, "v_cmpx_le_i16">; @@ -1951,7 +1890,6 @@ defm V_CMPX_GT_U64 : VOPCX_Real_gfx11_gfx12<0x0dc>; defm V_CMPX_NE_U64 : VOPCX_Real_gfx11_gfx12<0x0dd>; defm V_CMPX_GE_U64 : VOPCX_Real_gfx11_gfx12<0x0de>; defm V_CMPX_T_U64 : VOPCX_Real_gfx11<0x0df>; -defm V_CMPX_CLASS_F16_t16 : VOPCX_Real_t16_gfx11_gfx12<0x0fd, "v_cmpx_class_f16">; defm V_CMPX_CLASS_F16_fake16 : VOPCX_Real_t16_gfx11_gfx12<0x0fd, "v_cmpx_class_f16">; defm V_CMPX_CLASS_F32 : VOPCX_Real_gfx11_gfx12<0x0fe>; defm V_CMPX_CLASS_F64 : VOPCX_Real_gfx11_gfx12<0x0ff>; diff --git a/llvm/lib/Target/AMDGPU/VOPInstructions.td b/llvm/lib/Target/AMDGPU/VOPInstructions.td index d236907..930ed9a 100644 --- a/llvm/lib/Target/AMDGPU/VOPInstructions.td +++ b/llvm/lib/Target/AMDGPU/VOPInstructions.td @@ -1909,8 +1909,8 @@ multiclass VOP3_Realtriple_t16_gfx11<bits<10> op, string asmName, string opName multiclass VOP3_Realtriple_t16_and_fake16_gfx11<bits<10> op, string asmName, string opName = NAME, string pseudo_mnemonic = "", bit isSingle = 0> { - defm _t16: VOP3_Realtriple_t16_gfx11<op, opName#"_t16", asmName, pseudo_mnemonic, isSingle>; - defm _fake16: VOP3_Realtriple_t16_gfx11<op, opName#"_fake16", asmName, pseudo_mnemonic, isSingle>; + defm _t16: VOP3_Realtriple_t16_gfx11<op, asmName, opName#"_t16", pseudo_mnemonic, isSingle>; + defm _fake16: VOP3_Realtriple_t16_gfx11<op, asmName, opName#"_fake16", pseudo_mnemonic, isSingle>; } multiclass VOP3Only_Realtriple_t16_gfx11<bits<10> op, string asmName, diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp index 5ec2d83..2e517c2 100644 --- a/llvm/lib/Target/ARM/ARMISelLowering.cpp +++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp @@ -806,7 +806,9 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM, setOperationAction(ISD::BITCAST, MVT::bf16, Custom); } else { setOperationAction(ISD::BF16_TO_FP, MVT::f32, Expand); + setOperationAction(ISD::BF16_TO_FP, MVT::f64, Expand); setOperationAction(ISD::FP_TO_BF16, MVT::f32, Custom); + setOperationAction(ISD::FP_TO_BF16, MVT::f64, Custom); } for (MVT VT : MVT::fixedlen_vector_valuetypes()) { diff --git a/llvm/lib/Target/ARM/ARMInstrInfo.td b/llvm/lib/Target/ARM/ARMInstrInfo.td index c67177c..009b60c 100644 --- a/llvm/lib/Target/ARM/ARMInstrInfo.td +++ b/llvm/lib/Target/ARM/ARMInstrInfo.td @@ -3320,7 +3320,7 @@ def STRH_preidx: ARMPseudoInst<(outs GPR:$Rn_wb), } - +let mayStore = 1, hasSideEffects = 0 in { def STRH_PRE : AI3ldstidx<0b1011, 0, 1, (outs GPR:$Rn_wb), (ins GPR:$Rt, addrmode3_pre:$addr), IndexModePre, StMiscFrm, IIC_iStore_bh_ru, @@ -3352,6 +3352,7 @@ def STRH_POST : AI3ldstidx<0b1011, 0, 0, (outs GPR:$Rn_wb), let Inst{3-0} = offset{3-0}; // imm3_0/Rm let DecoderMethod = "DecodeAddrMode3Instruction"; } +} // mayStore = 1, hasSideEffects = 0 let mayStore = 1, hasSideEffects = 0, hasExtraSrcRegAllocReq = 1 in { def STRD_PRE : AI3ldstidx<0b1111, 0, 1, (outs GPR:$Rn_wb), diff --git a/llvm/lib/Target/ARM/ARMSystemRegister.td 
b/llvm/lib/Target/ARM/ARMSystemRegister.td index c03db15..3afc410 100644 --- a/llvm/lib/Target/ARM/ARMSystemRegister.td +++ b/llvm/lib/Target/ARM/ARMSystemRegister.td @@ -19,17 +19,13 @@ class MClassSysReg<bits<1> UniqMask1, bits<1> UniqMask2, bits<1> UniqMask3, bits<12> Enc12, - string name> : SearchableTable { - let SearchableFields = ["Name", "M1Encoding12", "M2M3Encoding8", "Encoding"]; + string name> { string Name; bits<13> M1Encoding12; bits<10> M2M3Encoding8; bits<12> Encoding; let Name = name; - let EnumValueField = "M1Encoding12"; - let EnumValueField = "M2M3Encoding8"; - let EnumValueField = "Encoding"; let M1Encoding12{12} = UniqMask1; let M1Encoding12{11-00} = Enc12; @@ -41,6 +37,27 @@ class MClassSysReg<bits<1> UniqMask1, code Requires = [{ {} }]; } +def MClassSysRegsList : GenericTable { + let FilterClass = "MClassSysReg"; + let Fields = ["Name", "M1Encoding12", "M2M3Encoding8", "Encoding", + "Requires"]; +} + +def lookupMClassSysRegByName : SearchIndex { + let Table = MClassSysRegsList; + let Key = ["Name"]; +} + +def lookupMClassSysRegByM1Encoding12 : SearchIndex { + let Table = MClassSysRegsList; + let Key = ["M1Encoding12"]; +} + +def lookupMClassSysRegByM2M3Encoding8 : SearchIndex { + let Table = MClassSysRegsList; + let Key = ["M2M3Encoding8"]; +} + // [|i|e|x]apsr_nzcvq has alias [|i|e|x]apsr. // Mask1 Mask2 Mask3 Enc12, Name let Requires = [{ {ARM::FeatureDSP} }] in { @@ -127,15 +144,29 @@ def : MClassSysReg<0, 0, 1, 0x8a7, "pac_key_u_3_ns">; // Banked Registers // -class BankedReg<string name, bits<8> enc> - : SearchableTable { +class BankedReg<string name, bits<8> enc> { string Name; bits<8> Encoding; let Name = name; let Encoding = enc; - let SearchableFields = ["Name", "Encoding"]; } +def BankedRegsList : GenericTable { + let FilterClass = "BankedReg"; + let Fields = ["Name", "Encoding"]; +} + +def lookupBankedRegByName : SearchIndex { + let Table = BankedRegsList; + let Key = ["Name"]; +} + +def lookupBankedRegByEncoding : SearchIndex { + let Table = BankedRegsList; + let Key = ["Encoding"]; +} + + // The values here come from B9.2.3 of the ARM ARM, where bits 4-0 are SysM // and bit 5 is R. def : BankedReg<"r8_usr", 0x00>; diff --git a/llvm/lib/Target/ARM/Utils/ARMBaseInfo.cpp b/llvm/lib/Target/ARM/Utils/ARMBaseInfo.cpp index 494c67d..e76a70b 100644 --- a/llvm/lib/Target/ARM/Utils/ARMBaseInfo.cpp +++ b/llvm/lib/Target/ARM/Utils/ARMBaseInfo.cpp @@ -62,13 +62,13 @@ const MClassSysReg *lookupMClassSysRegBy8bitSYSmValue(unsigned SYSm) { return ARMSysReg::lookupMClassSysRegByM2M3Encoding8((1<<8)|(SYSm & 0xFF)); } -#define GET_MCLASSSYSREG_IMPL +#define GET_MClassSysRegsList_IMPL #include "ARMGenSystemRegister.inc" } // end namespace ARMSysReg namespace ARMBankedReg { -#define GET_BANKEDREG_IMPL +#define GET_BankedRegsList_IMPL #include "ARMGenSystemRegister.inc" } // end namespce ARMSysReg } // end namespace llvm diff --git a/llvm/lib/Target/ARM/Utils/ARMBaseInfo.h b/llvm/lib/Target/ARM/Utils/ARMBaseInfo.h index 5562572..dc4f811 100644 --- a/llvm/lib/Target/ARM/Utils/ARMBaseInfo.h +++ b/llvm/lib/Target/ARM/Utils/ARMBaseInfo.h @@ -206,8 +206,8 @@ namespace ARMSysReg { } }; - #define GET_MCLASSSYSREG_DECL - #include "ARMGenSystemRegister.inc" +#define GET_MClassSysRegsList_DECL +#include "ARMGenSystemRegister.inc" // lookup system register using 12-bit SYSm value. 
// Note: the search is uniqued using M1 mask @@ -228,8 +228,8 @@ namespace ARMBankedReg { const char *Name; uint16_t Encoding; }; - #define GET_BANKEDREG_DECL - #include "ARMGenSystemRegister.inc" +#define GET_BankedRegsList_DECL +#include "ARMGenSystemRegister.inc" } // end namespace ARMBankedReg } // end namespace llvm diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index 5d865a3..62b5b70 100644 --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -42,8 +42,10 @@ def FloatTy : DXILOpParamType; def DoubleTy : DXILOpParamType; def ResRetHalfTy : DXILOpParamType; def ResRetFloatTy : DXILOpParamType; +def ResRetDoubleTy : DXILOpParamType; def ResRetInt16Ty : DXILOpParamType; def ResRetInt32Ty : DXILOpParamType; +def ResRetInt64Ty : DXILOpParamType; def HandleTy : DXILOpParamType; def ResBindTy : DXILOpParamType; def ResPropsTy : DXILOpParamType; @@ -890,6 +892,23 @@ def SplitDouble : DXILOp<102, splitDouble> { let attributes = [Attributes<DXIL1_0, [ReadNone]>]; } +def RawBufferLoad : DXILOp<139, rawBufferLoad> { + let Doc = "reads from a raw buffer and structured buffer"; + // Handle, Coord0, Coord1, Mask, Alignment + let arguments = [HandleTy, Int32Ty, Int32Ty, Int8Ty, Int32Ty]; + let result = OverloadTy; + let overloads = [ + Overloads<DXIL1_2, + [ResRetHalfTy, ResRetFloatTy, ResRetInt16Ty, ResRetInt32Ty]>, + Overloads<DXIL1_3, + [ + ResRetHalfTy, ResRetFloatTy, ResRetDoubleTy, ResRetInt16Ty, + ResRetInt32Ty, ResRetInt64Ty + ]> + ]; + let stages = [Stages<DXIL1_2, [all_stages]>]; +} + def Dot4AddI8Packed : DXILOp<163, dot4AddPacked> { let Doc = "signed dot product of 4 x i8 vectors packed into i32, with " "accumulate to i32"; diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp index 5d5bb3e..9f88ccd 100644 --- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp +++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp @@ -263,10 +263,14 @@ static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx, return getResRetType(Type::getHalfTy(Ctx)); case OpParamType::ResRetFloatTy: return getResRetType(Type::getFloatTy(Ctx)); + case OpParamType::ResRetDoubleTy: + return getResRetType(Type::getDoubleTy(Ctx)); case OpParamType::ResRetInt16Ty: return getResRetType(Type::getInt16Ty(Ctx)); case OpParamType::ResRetInt32Ty: return getResRetType(Type::getInt32Ty(Ctx)); + case OpParamType::ResRetInt64Ty: + return getResRetType(Type::getInt64Ty(Ctx)); case OpParamType::HandleTy: return getHandleType(Ctx); case OpParamType::ResBindTy: diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp index 4e01dd1..f43815b 100644 --- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp +++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp @@ -415,8 +415,16 @@ public: } } - OldResult = cast<Instruction>( - IRB.CreateExtractValue(Op, 0, OldResult->getName())); + if (OldResult->use_empty()) { + // Only the check bit was used, so we're done here. 
+ OldResult->eraseFromParent(); + return Error::success(); + } + + assert(OldResult->hasOneUse() && + isa<ExtractValueInst>(*OldResult->user_begin()) && + "Expected only use to be extract of first element"); + OldResult = cast<Instruction>(*OldResult->user_begin()); OldTy = ST->getElementType(0); } @@ -534,6 +542,48 @@ public: }); } + [[nodiscard]] bool lowerRawBufferLoad(Function &F) { + Triple TT(Triple(M.getTargetTriple())); + VersionTuple DXILVersion = TT.getDXILVersion(); + const DataLayout &DL = F.getDataLayout(); + IRBuilder<> &IRB = OpBuilder.getIRB(); + Type *Int8Ty = IRB.getInt8Ty(); + Type *Int32Ty = IRB.getInt32Ty(); + + return replaceFunction(F, [&](CallInst *CI) -> Error { + IRB.SetInsertPoint(CI); + + Type *OldTy = cast<StructType>(CI->getType())->getElementType(0); + Type *ScalarTy = OldTy->getScalarType(); + Type *NewRetTy = OpBuilder.getResRetType(ScalarTy); + + Value *Handle = + createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType()); + Value *Index0 = CI->getArgOperand(1); + Value *Index1 = CI->getArgOperand(2); + uint64_t NumElements = + DL.getTypeSizeInBits(OldTy) / DL.getTypeSizeInBits(ScalarTy); + Value *Mask = ConstantInt::get(Int8Ty, ~(~0U << NumElements)); + Value *Align = + ConstantInt::get(Int32Ty, DL.getPrefTypeAlign(ScalarTy).value()); + + Expected<CallInst *> OpCall = + DXILVersion >= VersionTuple(1, 2) + ? OpBuilder.tryCreateOp(OpCode::RawBufferLoad, + {Handle, Index0, Index1, Mask, Align}, + CI->getName(), NewRetTy) + : OpBuilder.tryCreateOp(OpCode::BufferLoad, + {Handle, Index0, Index1}, CI->getName(), + NewRetTy); + if (Error E = OpCall.takeError()) + return E; + if (Error E = replaceResRetUses(CI, *OpCall, /*HasCheckBit=*/true)) + return E; + + return Error::success(); + }); + } + [[nodiscard]] bool lowerUpdateCounter(Function &F) { IRBuilder<> &IRB = OpBuilder.getIRB(); Type *Int32Ty = IRB.getInt32Ty(); @@ -723,14 +773,14 @@ public: HasErrors |= lowerGetPointer(F); break; case Intrinsic::dx_resource_load_typedbuffer: - HasErrors |= lowerTypedBufferLoad(F, /*HasCheckBit=*/false); - break; - case Intrinsic::dx_resource_loadchecked_typedbuffer: HasErrors |= lowerTypedBufferLoad(F, /*HasCheckBit=*/true); break; case Intrinsic::dx_resource_store_typedbuffer: HasErrors |= lowerTypedBufferStore(F); break; + case Intrinsic::dx_resource_load_rawbuffer: + HasErrors |= lowerRawBufferLoad(F); + break; case Intrinsic::dx_resource_updatecounter: HasErrors |= lowerUpdateCounter(F); break; diff --git a/llvm/lib/Target/DirectX/DXILResourceAccess.cpp b/llvm/lib/Target/DirectX/DXILResourceAccess.cpp index 1ff8f09..8376249 100644 --- a/llvm/lib/Target/DirectX/DXILResourceAccess.cpp +++ b/llvm/lib/Target/DirectX/DXILResourceAccess.cpp @@ -30,6 +30,9 @@ static void replaceTypedBufferAccess(IntrinsicInst *II, "Unexpected typed buffer type"); Type *ContainedType = HandleType->getTypeParameter(0); + Type *LoadType = + StructType::get(ContainedType, Type::getInt1Ty(II->getContext())); + // We need the size of an element in bytes so that we can calculate the offset // in elements given a total offset in bytes later. Type *ScalarType = ContainedType->getScalarType(); @@ -81,13 +84,15 @@ static void replaceTypedBufferAccess(IntrinsicInst *II, // We're storing a scalar, so we need to load the current value and only // replace the relevant part. 
auto *Load = Builder.CreateIntrinsic( - ContainedType, Intrinsic::dx_resource_load_typedbuffer, + LoadType, Intrinsic::dx_resource_load_typedbuffer, {II->getOperand(0), II->getOperand(1)}); + auto *Struct = Builder.CreateExtractValue(Load, {0}); + // If we have an offset from seeing a GEP earlier, use it. Value *IndexOp = Current.Index ? Current.Index : ConstantInt::get(Builder.getInt32Ty(), 0); - V = Builder.CreateInsertElement(Load, V, IndexOp); + V = Builder.CreateInsertElement(Struct, V, IndexOp); } else { llvm_unreachable("Store to typed resource has invalid type"); } @@ -101,8 +106,10 @@ static void replaceTypedBufferAccess(IntrinsicInst *II, } else if (auto *LI = dyn_cast<LoadInst>(Current.Access)) { IRBuilder<> Builder(LI); Value *V = Builder.CreateIntrinsic( - ContainedType, Intrinsic::dx_resource_load_typedbuffer, + LoadType, Intrinsic::dx_resource_load_typedbuffer, {II->getOperand(0), II->getOperand(1)}); + V = Builder.CreateExtractValue(V, {0}); + if (Current.Index) V = Builder.CreateExtractElement(V, Current.Index); diff --git a/llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp b/llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp index 45aadac..be68d46 100644 --- a/llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp +++ b/llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp @@ -749,8 +749,8 @@ uint64_t DXILBitcodeWriter::getOptimizationFlags(const Value *V) { if (PEO->isExact()) Flags |= 1 << bitc::PEO_EXACT; } else if (const auto *FPMO = dyn_cast<FPMathOperator>(V)) { - if (FPMO->hasAllowReassoc()) - Flags |= bitc::AllowReassoc; + if (FPMO->hasAllowReassoc() || FPMO->hasAllowContract()) + Flags |= bitc::UnsafeAlgebra; if (FPMO->hasNoNaNs()) Flags |= bitc::NoNaNs; if (FPMO->hasNoInfs()) @@ -759,10 +759,6 @@ uint64_t DXILBitcodeWriter::getOptimizationFlags(const Value *V) { Flags |= bitc::NoSignedZeros; if (FPMO->hasAllowReciprocal()) Flags |= bitc::AllowReciprocal; - if (FPMO->hasAllowContract()) - Flags |= bitc::AllowContract; - if (FPMO->hasApproxFunc()) - Flags |= bitc::ApproxFunc; } return Flags; diff --git a/llvm/lib/Target/Hexagon/HexagonLoopIdiomRecognition.cpp b/llvm/lib/Target/Hexagon/HexagonLoopIdiomRecognition.cpp index 46a8ab3..991ee5b 100644 --- a/llvm/lib/Target/Hexagon/HexagonLoopIdiomRecognition.cpp +++ b/llvm/lib/Target/Hexagon/HexagonLoopIdiomRecognition.cpp @@ -1796,6 +1796,8 @@ bool PolynomialMultiplyRecognize::recognize() { IterCount = CV->getValue()->getZExtValue() + 1; Value *CIV = getCountIV(LoopB); + if (CIV == nullptr) + return false; ParsedValues PV; Simplifier PreSimp; PV.IterCount = IterCount; diff --git a/llvm/lib/Target/LoongArch/LoongArchExpandPseudoInsts.cpp b/llvm/lib/Target/LoongArch/LoongArchExpandPseudoInsts.cpp index 30742c7..0218934 100644 --- a/llvm/lib/Target/LoongArch/LoongArchExpandPseudoInsts.cpp +++ b/llvm/lib/Target/LoongArch/LoongArchExpandPseudoInsts.cpp @@ -352,11 +352,13 @@ bool LoongArchPreRAExpandPseudo::expandLoadAddressTLSLE( MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, MachineBasicBlock::iterator &NextMBBI) { // Code Sequence: + // lu12i.w $rd, %le_hi20_r(sym) + // add.w/d $rd, $rd, $tp, %le_add_r(sym) + // addi.w/d $rd, $rd, %le_lo12_r(sym) + // + // Code Sequence while using the large code model: // lu12i.w $rd, %le_hi20(sym) // ori $rd, $rd, %le_lo12(sym) - // - // And additionally if generating code using the large code model: - // // lu32i.d $rd, %le64_lo20(sym) // lu52i.d $rd, $rd, %le64_hi12(sym) MachineFunction *MF = MBB.getParent(); @@ -366,20 +368,35 @@ bool 
LoongArchPreRAExpandPseudo::expandLoadAddressTLSLE( bool Large = MF->getTarget().getCodeModel() == CodeModel::Large; Register DestReg = MI.getOperand(0).getReg(); Register Parts01 = - Large ? MF->getRegInfo().createVirtualRegister(&LoongArch::GPRRegClass) - : DestReg; + MF->getRegInfo().createVirtualRegister(&LoongArch::GPRRegClass); Register Part1 = MF->getRegInfo().createVirtualRegister(&LoongArch::GPRRegClass); MachineOperand &Symbol = MI.getOperand(1); - BuildMI(MBB, MBBI, DL, TII->get(LoongArch::LU12I_W), Part1) - .addDisp(Symbol, 0, LoongArchII::MO_LE_HI); + if (!Large) { + BuildMI(MBB, MBBI, DL, TII->get(LoongArch::LU12I_W), Part1) + .addDisp(Symbol, 0, LoongArchII::MO_LE_HI_R); - BuildMI(MBB, MBBI, DL, TII->get(LoongArch::ORI), Parts01) - .addReg(Part1, RegState::Kill) - .addDisp(Symbol, 0, LoongArchII::MO_LE_LO); + const auto &STI = MF->getSubtarget<LoongArchSubtarget>(); + unsigned AddOp = STI.is64Bit() ? LoongArch::PseudoAddTPRel_D + : LoongArch::PseudoAddTPRel_W; + BuildMI(MBB, MBBI, DL, TII->get(AddOp), Parts01) + .addReg(Part1, RegState::Kill) + .addReg(LoongArch::R2) + .addDisp(Symbol, 0, LoongArchII::MO_LE_ADD_R); + + unsigned AddiOp = STI.is64Bit() ? LoongArch::ADDI_D : LoongArch::ADDI_W; + BuildMI(MBB, MBBI, DL, TII->get(AddiOp), DestReg) + .addReg(Parts01, RegState::Kill) + .addDisp(Symbol, 0, LoongArchII::MO_LE_LO_R); + } else { + BuildMI(MBB, MBBI, DL, TII->get(LoongArch::LU12I_W), Part1) + .addDisp(Symbol, 0, LoongArchII::MO_LE_HI); + + BuildMI(MBB, MBBI, DL, TII->get(LoongArch::ORI), Parts01) + .addReg(Part1, RegState::Kill) + .addDisp(Symbol, 0, LoongArchII::MO_LE_LO); - if (Large) { Register Parts012 = MF->getRegInfo().createVirtualRegister(&LoongArch::GPRRegClass); diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp index 7f67def..96e6f71 100644 --- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp +++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp @@ -1866,9 +1866,17 @@ SDValue LoongArchTargetLowering::getStaticTLSAddr(GlobalAddressSDNode *N, // PseudoLA_*_LARGE nodes. SDValue Tmp = DAG.getConstant(0, DL, Ty); SDValue Addr = DAG.getTargetGlobalAddress(N->getGlobal(), DL, Ty, 0, 0); - SDValue Offset = Large + + // Only IE needs an extra argument for large code model. + SDValue Offset = Opc == LoongArch::PseudoLA_TLS_IE_LARGE ? SDValue(DAG.getMachineNode(Opc, DL, Ty, Tmp, Addr), 0) : SDValue(DAG.getMachineNode(Opc, DL, Ty, Addr), 0); + + // If it is LE for normal/medium code model, the add tp operation will occur + // during the pseudo-instruction expansion. + if (Opc == LoongArch::PseudoLA_TLS_LE && !Large) + return Offset; + if (UseGOT) { // Mark the load instruction as invariant to enable hoisting in MachineLICM. MachineFunction &MF = DAG.getMachineFunction(); @@ -1989,7 +1997,7 @@ LoongArchTargetLowering::lowerGlobalTLSAddress(SDValue Op, // // This node doesn't need an extra argument for the large code model. 
return getStaticTLSAddr(N, DAG, LoongArch::PseudoLA_TLS_LE, - /*UseGOT=*/false); + /*UseGOT=*/false, Large); } return getTLSDescAddr(N, DAG, diff --git a/llvm/lib/Target/LoongArch/LoongArchInstrInfo.cpp b/llvm/lib/Target/LoongArch/LoongArchInstrInfo.cpp index 363cacf..32bc8bb 100644 --- a/llvm/lib/Target/LoongArch/LoongArchInstrInfo.cpp +++ b/llvm/lib/Target/LoongArch/LoongArchInstrInfo.cpp @@ -154,6 +154,9 @@ void LoongArchInstrInfo::loadRegFromStackSlot(MachineBasicBlock &MBB, Register VReg) const { MachineFunction *MF = MBB.getParent(); MachineFrameInfo &MFI = MF->getFrameInfo(); + DebugLoc DL; + if (I != MBB.end()) + DL = I->getDebugLoc(); unsigned Opcode; if (LoongArch::GPRRegClass.hasSubClassEq(RC)) @@ -177,7 +180,7 @@ void LoongArchInstrInfo::loadRegFromStackSlot(MachineBasicBlock &MBB, MachinePointerInfo::getFixedStack(*MF, FI), MachineMemOperand::MOLoad, MFI.getObjectSize(FI), MFI.getObjectAlign(FI)); - BuildMI(MBB, I, DebugLoc(), get(Opcode), DstReg) + BuildMI(MBB, I, DL, get(Opcode), DstReg) .addFrameIndex(FI) .addImm(0) .addMemOperand(MMO); @@ -406,6 +409,11 @@ bool LoongArchInstrInfo::isSchedulingBoundary(const MachineInstr &MI, // lu32i.d $a1, %ie64_pc_lo20(s) // lu52i.d $a1, $a1, %ie64_pc_hi12(s) // + // * pcalau12i $a0, %desc_pc_hi20(s) + // addi.d $a1, $zero, %desc_pc_lo12(s) + // lu32i.d $a1, %desc64_pc_lo20(s) + // lu52i.d $a1, $a1, %desc64_pc_hi12(s) + // // For simplicity, only pcalau12i and lu52i.d are marked as scheduling // boundaries, and the instructions between them are guaranteed to be // ordered according to data dependencies. @@ -430,12 +438,16 @@ bool LoongArchInstrInfo::isSchedulingBoundary(const MachineInstr &MI, if (MO0 == LoongArchII::MO_IE_PC_HI && MO1 == LoongArchII::MO_IE_PC_LO && MO2 == LoongArchII::MO_IE_PC64_LO) return true; + if (MO0 == LoongArchII::MO_DESC_PC_HI && + MO1 == LoongArchII::MO_DESC_PC_LO && + MO2 == LoongArchII::MO_DESC64_PC_LO) + return true; break; } case LoongArch::LU52I_D: { auto MO = MI.getOperand(2).getTargetFlags(); if (MO == LoongArchII::MO_PCREL64_HI || MO == LoongArchII::MO_GOT_PC64_HI || - MO == LoongArchII::MO_IE_PC64_HI) + MO == LoongArchII::MO_IE_PC64_HI || MO == LoongArchII::MO_DESC64_PC_HI) return true; break; } @@ -651,7 +663,10 @@ LoongArchInstrInfo::getSerializableDirectMachineOperandTargetFlags() const { {MO_DESC_LD, "loongarch-desc-ld"}, {MO_DESC_CALL, "loongarch-desc-call"}, {MO_LD_PC_HI, "loongarch-ld-pc-hi"}, - {MO_GD_PC_HI, "loongarch-gd-pc-hi"}}; + {MO_GD_PC_HI, "loongarch-gd-pc-hi"}, + {MO_LE_HI_R, "loongarch-le-hi-r"}, + {MO_LE_ADD_R, "loongarch-le-add-r"}, + {MO_LE_LO_R, "loongarch-le-lo-r"}}; return ArrayRef(TargetFlags); } diff --git a/llvm/lib/Target/LoongArch/LoongArchMCInstLower.cpp b/llvm/lib/Target/LoongArch/LoongArchMCInstLower.cpp index 2bacc12..d1de060 100644 --- a/llvm/lib/Target/LoongArch/LoongArchMCInstLower.cpp +++ b/llvm/lib/Target/LoongArch/LoongArchMCInstLower.cpp @@ -114,6 +114,15 @@ static MCOperand lowerSymbolOperand(const MachineOperand &MO, MCSymbol *Sym, case LoongArchII::MO_DESC_CALL: Kind = LoongArchMCExpr::VK_LoongArch_TLS_DESC_CALL; break; + case LoongArchII::MO_LE_HI_R: + Kind = LoongArchMCExpr::VK_LoongArch_TLS_LE_HI20_R; + break; + case LoongArchII::MO_LE_ADD_R: + Kind = LoongArchMCExpr::VK_LoongArch_TLS_LE_ADD_R; + break; + case LoongArchII::MO_LE_LO_R: + Kind = LoongArchMCExpr::VK_LoongArch_TLS_LE_LO12_R; + break; // TODO: Handle more target-flags. 
} diff --git a/llvm/lib/Target/LoongArch/MCTargetDesc/LoongArchBaseInfo.h b/llvm/lib/Target/LoongArch/MCTargetDesc/LoongArchBaseInfo.h index bd63c5e..2369904 100644 --- a/llvm/lib/Target/LoongArch/MCTargetDesc/LoongArchBaseInfo.h +++ b/llvm/lib/Target/LoongArch/MCTargetDesc/LoongArchBaseInfo.h @@ -54,6 +54,9 @@ enum { MO_DESC64_PC_LO, MO_DESC_LD, MO_DESC_CALL, + MO_LE_HI_R, + MO_LE_ADD_R, + MO_LE_LO_R, // TODO: Add more flags. }; diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp index 65e1893..d34f45f 100644 --- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp +++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp @@ -14,7 +14,7 @@ #include "NVPTX.h" #include "NVPTXUtilities.h" #include "llvm/ADT/StringRef.h" -#include "llvm/IR/NVVMIntrinsicFlags.h" +#include "llvm/IR/NVVMIntrinsicUtils.h" #include "llvm/MC/MCExpr.h" #include "llvm/MC/MCInst.h" #include "llvm/MC/MCInstrInfo.h" diff --git a/llvm/lib/Target/NVPTX/NVPTXCtorDtorLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXCtorDtorLowering.cpp index f940dc0..c03ef8d 100644 --- a/llvm/lib/Target/NVPTX/NVPTXCtorDtorLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXCtorDtorLowering.cpp @@ -14,6 +14,7 @@ #include "MCTargetDesc/NVPTXBaseInfo.h" #include "NVPTX.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/IR/CallingConv.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" #include "llvm/IR/GlobalVariable.h" @@ -49,39 +50,34 @@ static std::string getHash(StringRef Str) { return llvm::utohexstr(Hash.low(), /*LowerCase=*/true); } -static void addKernelMetadata(Module &M, GlobalValue *GV) { +static void addKernelMetadata(Module &M, Function *F) { llvm::LLVMContext &Ctx = M.getContext(); // Get "nvvm.annotations" metadata node. llvm::NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations"); - llvm::Metadata *KernelMDVals[] = { - llvm::ConstantAsMetadata::get(GV), llvm::MDString::get(Ctx, "kernel"), - llvm::ConstantAsMetadata::get( - llvm::ConstantInt::get(llvm::Type::getInt32Ty(Ctx), 1))}; - // This kernel is only to be called single-threaded. llvm::Metadata *ThreadXMDVals[] = { - llvm::ConstantAsMetadata::get(GV), llvm::MDString::get(Ctx, "maxntidx"), + llvm::ConstantAsMetadata::get(F), llvm::MDString::get(Ctx, "maxntidx"), llvm::ConstantAsMetadata::get( llvm::ConstantInt::get(llvm::Type::getInt32Ty(Ctx), 1))}; llvm::Metadata *ThreadYMDVals[] = { - llvm::ConstantAsMetadata::get(GV), llvm::MDString::get(Ctx, "maxntidy"), + llvm::ConstantAsMetadata::get(F), llvm::MDString::get(Ctx, "maxntidy"), llvm::ConstantAsMetadata::get( llvm::ConstantInt::get(llvm::Type::getInt32Ty(Ctx), 1))}; llvm::Metadata *ThreadZMDVals[] = { - llvm::ConstantAsMetadata::get(GV), llvm::MDString::get(Ctx, "maxntidz"), + llvm::ConstantAsMetadata::get(F), llvm::MDString::get(Ctx, "maxntidz"), llvm::ConstantAsMetadata::get( llvm::ConstantInt::get(llvm::Type::getInt32Ty(Ctx), 1))}; llvm::Metadata *BlockMDVals[] = { - llvm::ConstantAsMetadata::get(GV), + llvm::ConstantAsMetadata::get(F), llvm::MDString::get(Ctx, "maxclusterrank"), llvm::ConstantAsMetadata::get( llvm::ConstantInt::get(llvm::Type::getInt32Ty(Ctx), 1))}; // Append metadata to nvvm.annotations. 
- MD->addOperand(llvm::MDNode::get(Ctx, KernelMDVals)); + F->setCallingConv(CallingConv::PTX_Kernel); MD->addOperand(llvm::MDNode::get(Ctx, ThreadXMDVals)); MD->addOperand(llvm::MDNode::get(Ctx, ThreadYMDVals)); MD->addOperand(llvm::MDNode::get(Ctx, ThreadZMDVals)); diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp index c51729e..ef97844 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp @@ -14,10 +14,11 @@ #include "NVPTXUtilities.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/CodeGen/ISDOpcodes.h" +#include "llvm/CodeGen/SelectionDAGNodes.h" #include "llvm/IR/GlobalValue.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicsNVPTX.h" -#include "llvm/IR/NVVMIntrinsicFlags.h" +#include "llvm/IR/NVVMIntrinsicUtils.h" #include "llvm/Support/AtomicOrdering.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/ErrorHandling.h" @@ -2449,6 +2450,11 @@ bool NVPTXDAGToDAGISel::tryBFE(SDNode *N) { return true; } +static inline bool isAddLike(const SDValue V) { + return V.getOpcode() == ISD::ADD || + (V->getOpcode() == ISD::OR && V->getFlags().hasDisjoint()); +} + // SelectDirectAddr - Match a direct address for DAG. // A direct address could be a globaladdress or externalsymbol. bool NVPTXDAGToDAGISel::SelectDirectAddr(SDValue N, SDValue &Address) { @@ -2475,7 +2481,7 @@ bool NVPTXDAGToDAGISel::SelectDirectAddr(SDValue N, SDValue &Address) { // symbol+offset bool NVPTXDAGToDAGISel::SelectADDRsi_imp( SDNode *OpNode, SDValue Addr, SDValue &Base, SDValue &Offset, MVT mvt) { - if (Addr.getOpcode() == ISD::ADD) { + if (isAddLike(Addr)) { if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(Addr.getOperand(1))) { SDValue base = Addr.getOperand(0); if (SelectDirectAddr(base, Base)) { @@ -2512,7 +2518,7 @@ bool NVPTXDAGToDAGISel::SelectADDRri_imp( Addr.getOpcode() == ISD::TargetGlobalAddress) return false; // direct calls. 
- if (Addr.getOpcode() == ISD::ADD) { + if (isAddLike(Addr)) { if (SelectDirectAddr(Addr.getOperand(0), Addr)) { return false; } diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h index 4a98fe2..c9b7e87 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h @@ -261,6 +261,9 @@ public: return true; } + bool isFAbsFree(EVT VT) const override { return true; } + bool isFNegFree(EVT VT) const override { return true; } + private: const NVPTXSubtarget &STI; // cache the subtarget here SDValue getParamSymbol(SelectionDAG &DAG, int idx, EVT) const; diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index 711cd67..c3e72d6 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -733,12 +733,12 @@ def fpround_oneuse : PatFrag<(ops node:$a), (fpround node:$a), [{ def : Pat<(v2bf16 (build_vector (bf16 (fpround_oneuse f32:$lo)), (bf16 (fpround_oneuse f32:$hi)))), - (CVT_bf16x2_f32 Float32Regs:$hi, Float32Regs:$lo, CvtRN)>, + (CVT_bf16x2_f32 $hi, $lo, CvtRN)>, Requires<[hasPTX<70>, hasSM<80>, hasBF16Math]>; def : Pat<(v2f16 (build_vector (f16 (fpround_oneuse f32:$lo)), (f16 (fpround_oneuse f32:$hi)))), - (CVT_f16x2_f32 Float32Regs:$hi, Float32Regs:$lo, CvtRN)>, + (CVT_f16x2_f32 $hi, $lo, CvtRN)>, Requires<[hasPTX<70>, hasSM<80>, useFP16Math]>; //----------------------------------- @@ -813,7 +813,7 @@ defm SELP_f64 : SELP_PATTERN<"f64", f64, Float64Regs, f64imm, fpimm>; foreach vt = [v2f16, v2bf16, v2i16, v4i8] in { def : Pat<(vt (select i1:$p, vt:$a, vt:$b)), - (SELP_b32rr Int32Regs:$a, Int32Regs:$b, Int1Regs:$p)>; + (SELP_b32rr $a, $b, $p)>; } //----------------------------------- @@ -952,29 +952,29 @@ def mul_wide_unsigned : SDNode<"NVPTXISD::MUL_WIDE_UNSIGNED", SDTMulWide>; // Matchers for signed, unsigned mul.wide ISD nodes. def : Pat<(i32 (mul_wide_signed i16:$a, i16:$b)), - (MULWIDES32 i16:$a, i16:$b)>, + (MULWIDES32 $a, $b)>, Requires<[doMulWide]>; def : Pat<(i32 (mul_wide_signed i16:$a, imm:$b)), - (MULWIDES32Imm Int16Regs:$a, imm:$b)>, + (MULWIDES32Imm $a, imm:$b)>, Requires<[doMulWide]>; def : Pat<(i32 (mul_wide_unsigned i16:$a, i16:$b)), - (MULWIDEU32 Int16Regs:$a, Int16Regs:$b)>, + (MULWIDEU32 $a, $b)>, Requires<[doMulWide]>; def : Pat<(i32 (mul_wide_unsigned i16:$a, imm:$b)), - (MULWIDEU32Imm Int16Regs:$a, imm:$b)>, + (MULWIDEU32Imm $a, imm:$b)>, Requires<[doMulWide]>; def : Pat<(i64 (mul_wide_signed i32:$a, i32:$b)), - (MULWIDES64 Int32Regs:$a, Int32Regs:$b)>, + (MULWIDES64 $a, $b)>, Requires<[doMulWide]>; def : Pat<(i64 (mul_wide_signed i32:$a, imm:$b)), - (MULWIDES64Imm Int32Regs:$a, imm:$b)>, + (MULWIDES64Imm $a, imm:$b)>, Requires<[doMulWide]>; def : Pat<(i64 (mul_wide_unsigned i32:$a, i32:$b)), - (MULWIDEU64 Int32Regs:$a, Int32Regs:$b)>, + (MULWIDEU64 $a, $b)>, Requires<[doMulWide]>; def : Pat<(i64 (mul_wide_unsigned i32:$a, imm:$b)), - (MULWIDEU64Imm Int32Regs:$a, imm:$b)>, + (MULWIDEU64Imm $a, imm:$b)>, Requires<[doMulWide]>; // Predicates used for converting some patterns to mul.wide. @@ -1024,46 +1024,46 @@ def SHL2MUL16 : SDNodeXForm<imm, [{ // Convert "sign/zero-extend, then shift left by an immediate" to mul.wide. 
def : Pat<(shl (sext i32:$a), (i32 IntConst_0_30:$b)), - (MULWIDES64Imm Int32Regs:$a, (SHL2MUL32 node:$b))>, + (MULWIDES64Imm $a, (SHL2MUL32 $b))>, Requires<[doMulWide]>; def : Pat<(shl (zext i32:$a), (i32 IntConst_0_30:$b)), - (MULWIDEU64Imm Int32Regs:$a, (SHL2MUL32 node:$b))>, + (MULWIDEU64Imm $a, (SHL2MUL32 $b))>, Requires<[doMulWide]>; def : Pat<(shl (sext i16:$a), (i16 IntConst_0_14:$b)), - (MULWIDES32Imm Int16Regs:$a, (SHL2MUL16 node:$b))>, + (MULWIDES32Imm $a, (SHL2MUL16 $b))>, Requires<[doMulWide]>; def : Pat<(shl (zext i16:$a), (i16 IntConst_0_14:$b)), - (MULWIDEU32Imm Int16Regs:$a, (SHL2MUL16 node:$b))>, + (MULWIDEU32Imm $a, (SHL2MUL16 $b))>, Requires<[doMulWide]>; // Convert "sign/zero-extend then multiply" to mul.wide. def : Pat<(mul (sext i32:$a), (sext i32:$b)), - (MULWIDES64 Int32Regs:$a, Int32Regs:$b)>, + (MULWIDES64 $a, $b)>, Requires<[doMulWide]>; def : Pat<(mul (sext i32:$a), (i64 SInt32Const:$b)), - (MULWIDES64Imm64 Int32Regs:$a, (i64 SInt32Const:$b))>, + (MULWIDES64Imm64 $a, (i64 SInt32Const:$b))>, Requires<[doMulWide]>; def : Pat<(mul (zext i32:$a), (zext i32:$b)), - (MULWIDEU64 Int32Regs:$a, Int32Regs:$b)>, + (MULWIDEU64 $a, $b)>, Requires<[doMulWide]>; def : Pat<(mul (zext i32:$a), (i64 UInt32Const:$b)), - (MULWIDEU64Imm64 Int32Regs:$a, (i64 UInt32Const:$b))>, + (MULWIDEU64Imm64 $a, (i64 UInt32Const:$b))>, Requires<[doMulWide]>; def : Pat<(mul (sext i16:$a), (sext i16:$b)), - (MULWIDES32 Int16Regs:$a, Int16Regs:$b)>, + (MULWIDES32 $a, $b)>, Requires<[doMulWide]>; def : Pat<(mul (sext i16:$a), (i32 SInt16Const:$b)), - (MULWIDES32Imm32 Int16Regs:$a, (i32 SInt16Const:$b))>, + (MULWIDES32Imm32 $a, (i32 SInt16Const:$b))>, Requires<[doMulWide]>; def : Pat<(mul (zext i16:$a), (zext i16:$b)), - (MULWIDEU32 Int16Regs:$a, Int16Regs:$b)>, + (MULWIDEU32 $a, $b)>, Requires<[doMulWide]>; def : Pat<(mul (zext i16:$a), (i32 UInt16Const:$b)), - (MULWIDEU32Imm32 Int16Regs:$a, (i32 UInt16Const:$b))>, + (MULWIDEU32Imm32 $a, (i32 UInt16Const:$b))>, Requires<[doMulWide]>; // @@ -1242,7 +1242,7 @@ def FDIV64ri : // fdiv will be converted to rcp // fneg (fdiv 1.0, X) => fneg (rcp.rn X) def : Pat<(fdiv DoubleConstNeg1:$a, f64:$b), - (FNEGf64 (FDIV641r (NegDoubleConst node:$a), Float64Regs:$b))>; + (FNEGf64 (FDIV641r (NegDoubleConst node:$a), $b))>; // // F32 Approximate reciprocal @@ -1436,83 +1436,83 @@ def COSF: NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$src), // frem - f32 FTZ def : Pat<(frem f32:$x, f32:$y), - (FSUBf32rr_ftz Float32Regs:$x, (FMULf32rr_ftz (CVT_f32_f32 - (FDIV32rr_prec_ftz Float32Regs:$x, Float32Regs:$y), CvtRZI_FTZ), - Float32Regs:$y))>, + (FSUBf32rr_ftz $x, (FMULf32rr_ftz (CVT_f32_f32 + (FDIV32rr_prec_ftz $x, $y), CvtRZI_FTZ), + $y))>, Requires<[doF32FTZ, allowUnsafeFPMath]>; def : Pat<(frem f32:$x, fpimm:$y), - (FSUBf32rr_ftz Float32Regs:$x, (FMULf32ri_ftz (CVT_f32_f32 - (FDIV32ri_prec_ftz Float32Regs:$x, fpimm:$y), CvtRZI_FTZ), + (FSUBf32rr_ftz $x, (FMULf32ri_ftz (CVT_f32_f32 + (FDIV32ri_prec_ftz $x, fpimm:$y), CvtRZI_FTZ), fpimm:$y))>, Requires<[doF32FTZ, allowUnsafeFPMath]>; -def : Pat<(frem f32:$x, Float32Regs:$y), - (SELP_f32rr Float32Regs:$x, - (FSUBf32rr_ftz Float32Regs:$x, (FMULf32rr_ftz (CVT_f32_f32 - (FDIV32rr_prec_ftz Float32Regs:$x, Float32Regs:$y), CvtRZI_FTZ), - Float32Regs:$y)), - (TESTINF_f32r Float32Regs:$y))>, +def : Pat<(frem f32:$x, f32:$y), + (SELP_f32rr $x, + (FSUBf32rr_ftz $x, (FMULf32rr_ftz (CVT_f32_f32 + (FDIV32rr_prec_ftz $x, $y), CvtRZI_FTZ), + $y)), + (TESTINF_f32r $y))>, Requires<[doF32FTZ, noUnsafeFPMath]>; def : Pat<(frem 
f32:$x, fpimm:$y), - (SELP_f32rr Float32Regs:$x, - (FSUBf32rr_ftz Float32Regs:$x, (FMULf32ri_ftz (CVT_f32_f32 - (FDIV32ri_prec_ftz Float32Regs:$x, fpimm:$y), CvtRZI_FTZ), + (SELP_f32rr $x, + (FSUBf32rr_ftz $x, (FMULf32ri_ftz (CVT_f32_f32 + (FDIV32ri_prec_ftz $x, fpimm:$y), CvtRZI_FTZ), fpimm:$y)), (TESTINF_f32i fpimm:$y))>, Requires<[doF32FTZ, noUnsafeFPMath]>; // frem - f32 def : Pat<(frem f32:$x, f32:$y), - (FSUBf32rr Float32Regs:$x, (FMULf32rr (CVT_f32_f32 - (FDIV32rr_prec Float32Regs:$x, Float32Regs:$y), CvtRZI), - Float32Regs:$y))>, + (FSUBf32rr $x, (FMULf32rr (CVT_f32_f32 + (FDIV32rr_prec $x, $y), CvtRZI), + $y))>, Requires<[allowUnsafeFPMath]>; def : Pat<(frem f32:$x, fpimm:$y), - (FSUBf32rr Float32Regs:$x, (FMULf32ri (CVT_f32_f32 - (FDIV32ri_prec Float32Regs:$x, fpimm:$y), CvtRZI), + (FSUBf32rr $x, (FMULf32ri (CVT_f32_f32 + (FDIV32ri_prec $x, fpimm:$y), CvtRZI), fpimm:$y))>, Requires<[allowUnsafeFPMath]>; def : Pat<(frem f32:$x, f32:$y), - (SELP_f32rr Float32Regs:$x, - (FSUBf32rr Float32Regs:$x, (FMULf32rr (CVT_f32_f32 - (FDIV32rr_prec Float32Regs:$x, Float32Regs:$y), CvtRZI), - Float32Regs:$y)), + (SELP_f32rr $x, + (FSUBf32rr $x, (FMULf32rr (CVT_f32_f32 + (FDIV32rr_prec $x, $y), CvtRZI), + $y)), (TESTINF_f32r Float32Regs:$y))>, Requires<[noUnsafeFPMath]>; def : Pat<(frem f32:$x, fpimm:$y), - (SELP_f32rr Float32Regs:$x, - (FSUBf32rr Float32Regs:$x, (FMULf32ri (CVT_f32_f32 - (FDIV32ri_prec Float32Regs:$x, fpimm:$y), CvtRZI), + (SELP_f32rr $x, + (FSUBf32rr $x, (FMULf32ri (CVT_f32_f32 + (FDIV32ri_prec $x, fpimm:$y), CvtRZI), fpimm:$y)), (TESTINF_f32i fpimm:$y))>, Requires<[noUnsafeFPMath]>; // frem - f64 def : Pat<(frem f64:$x, f64:$y), - (FSUBf64rr Float64Regs:$x, (FMULf64rr (CVT_f64_f64 - (FDIV64rr Float64Regs:$x, Float64Regs:$y), CvtRZI), - Float64Regs:$y))>, + (FSUBf64rr $x, (FMULf64rr (CVT_f64_f64 + (FDIV64rr $x, $y), CvtRZI), + $y))>, Requires<[allowUnsafeFPMath]>; def : Pat<(frem f64:$x, fpimm:$y), - (FSUBf64rr Float64Regs:$x, (FMULf64ri (CVT_f64_f64 - (FDIV64ri Float64Regs:$x, fpimm:$y), CvtRZI), + (FSUBf64rr $x, (FMULf64ri (CVT_f64_f64 + (FDIV64ri $x, fpimm:$y), CvtRZI), fpimm:$y))>, Requires<[allowUnsafeFPMath]>; def : Pat<(frem f64:$x, f64:$y), - (SELP_f64rr Float64Regs:$x, - (FSUBf64rr Float64Regs:$x, (FMULf64rr (CVT_f64_f64 - (FDIV64rr Float64Regs:$x, Float64Regs:$y), CvtRZI), - Float64Regs:$y)), + (SELP_f64rr $x, + (FSUBf64rr $x, (FMULf64rr (CVT_f64_f64 + (FDIV64rr $x, $y), CvtRZI), + $y)), (TESTINF_f64r Float64Regs:$y))>, Requires<[noUnsafeFPMath]>; def : Pat<(frem f64:$x, fpimm:$y), - (SELP_f64rr Float64Regs:$x, - (FSUBf64rr Float64Regs:$x, (FMULf64ri (CVT_f64_f64 - (FDIV64ri Float64Regs:$x, fpimm:$y), CvtRZI), + (SELP_f64rr $x, + (FSUBf64rr $x, (FMULf64ri (CVT_f64_f64 + (FDIV64ri $x, fpimm:$y), CvtRZI), fpimm:$y)), - (TESTINF_f64r Float64Regs:$y))>, + (TESTINF_f64r $y))>, Requires<[noUnsafeFPMath]>; //----------------------------------- @@ -1561,32 +1561,32 @@ defm AND : BITWISE<"and", and>; defm XOR : BITWISE<"xor", xor>; // PTX does not support mul on predicates, convert to and instructions -def : Pat<(mul i1:$a, i1:$b), (ANDb1rr Int1Regs:$a, Int1Regs:$b)>; -def : Pat<(mul i1:$a, imm:$b), (ANDb1ri Int1Regs:$a, imm:$b)>; +def : Pat<(mul i1:$a, i1:$b), (ANDb1rr $a, $b)>; +def : Pat<(mul i1:$a, imm:$b), (ANDb1ri $a, imm:$b)>; // These transformations were once reliably performed by instcombine, but thanks // to poison semantics they are no longer safe for LLVM IR, perform them here // instead. 
-def : Pat<(select i1:$a, i1:$b, 0), (ANDb1rr Int1Regs:$a, Int1Regs:$b)>; -def : Pat<(select i1:$a, 1, i1:$b), (ORb1rr Int1Regs:$a, Int1Regs:$b)>; +def : Pat<(select i1:$a, i1:$b, 0), (ANDb1rr $a, $b)>; +def : Pat<(select i1:$a, 1, i1:$b), (ORb1rr $a, $b)>; // Lower logical v2i16/v4i8 ops as bitwise ops on b32. foreach vt = [v2i16, v4i8] in { def: Pat<(or vt:$a, vt:$b), - (ORb32rr Int32Regs:$a, Int32Regs:$b)>; + (ORb32rr $a, $b)>; def: Pat<(xor vt:$a, vt:$b), - (XORb32rr Int32Regs:$a, Int32Regs:$b)>; + (XORb32rr $a, $b)>; def: Pat<(and vt:$a, vt:$b), - (ANDb32rr Int32Regs:$a, Int32Regs:$b)>; + (ANDb32rr $a, $b)>; // The constants get legalized into a bitcast from i32, so that's what we need // to match here. def: Pat<(or vt:$a, (vt (bitconvert (i32 imm:$b)))), - (ORb32ri Int32Regs:$a, imm:$b)>; + (ORb32ri $a, imm:$b)>; def: Pat<(xor vt:$a, (vt (bitconvert (i32 imm:$b)))), - (XORb32ri Int32Regs:$a, imm:$b)>; + (XORb32ri $a, imm:$b)>; def: Pat<(and vt:$a, (vt (bitconvert (i32 imm:$b)))), - (ANDb32ri Int32Regs:$a, imm:$b)>; + (ANDb32ri $a, imm:$b)>; } def NOT1 : NVPTXInst<(outs Int1Regs:$dst), (ins Int1Regs:$src), @@ -1770,34 +1770,34 @@ let hasSideEffects = false in { // byte extraction + signed/unsigned extension to i32. def : Pat<(i32 (sext_inreg (bfe i32:$s, i32:$o, 8), i8)), - (BFE_S32rri Int32Regs:$s, Int32Regs:$o, 8)>; + (BFE_S32rri $s, $o, 8)>; def : Pat<(i32 (sext_inreg (bfe i32:$s, imm:$o, 8), i8)), - (BFE_S32rii Int32Regs:$s, imm:$o, 8)>; + (BFE_S32rii $s, imm:$o, 8)>; def : Pat<(i32 (and (bfe i32:$s, i32:$o, 8), 255)), - (BFE_U32rri Int32Regs:$s, Int32Regs:$o, 8)>; + (BFE_U32rri $s, $o, 8)>; def : Pat<(i32 (and (bfe i32:$s, imm:$o, 8), 255)), - (BFE_U32rii Int32Regs:$s, imm:$o, 8)>; + (BFE_U32rii $s, imm:$o, 8)>; // byte extraction + signed extension to i16 def : Pat<(i16 (sext_inreg (trunc (bfe i32:$s, imm:$o, 8)), i8)), - (CVT_s8_s32 (BFE_S32rii i32:$s, imm:$o, 8), CvtNONE)>; + (CVT_s8_s32 (BFE_S32rii $s, imm:$o, 8), CvtNONE)>; // Byte extraction via shift/trunc/sext def : Pat<(i16 (sext_inreg (trunc i32:$s), i8)), - (CVT_s8_s32 Int32Regs:$s, CvtNONE)>; + (CVT_s8_s32 $s, CvtNONE)>; def : Pat<(i16 (sext_inreg (trunc (srl i32:$s, (i32 imm:$o))), i8)), - (CVT_s8_s32 (BFE_S32rii Int32Regs:$s, imm:$o, 8), CvtNONE)>; + (CVT_s8_s32 (BFE_S32rii $s, imm:$o, 8), CvtNONE)>; def : Pat<(sext_inreg (srl i32:$s, (i32 imm:$o)), i8), - (BFE_S32rii Int32Regs:$s, imm:$o, 8)>; + (BFE_S32rii $s, imm:$o, 8)>; def : Pat<(i16 (sra (i16 (trunc i32:$s)), (i32 8))), - (CVT_s8_s32 (BFE_S32rii Int32Regs:$s, 8, 8), CvtNONE)>; + (CVT_s8_s32 (BFE_S32rii $s, 8, 8), CvtNONE)>; def : Pat<(sext_inreg (srl i64:$s, (i32 imm:$o)), i8), - (BFE_S64rii Int64Regs:$s, imm:$o, 8)>; + (BFE_S64rii $s, imm:$o, 8)>; def : Pat<(i16 (sext_inreg (trunc i64:$s), i8)), - (CVT_s8_s64 Int64Regs:$s, CvtNONE)>; + (CVT_s8_s64 $s, CvtNONE)>; def : Pat<(i16 (sext_inreg (trunc (srl i64:$s, (i32 imm:$o))), i8)), - (CVT_s8_s64 (BFE_S64rii Int64Regs:$s, imm:$o, 8), CvtNONE)>; + (CVT_s8_s64 (BFE_S64rii $s, imm:$o, 8), CvtNONE)>; //----------------------------------- // Comparison instructions (setp, set) @@ -2032,47 +2032,47 @@ multiclass ISET_FORMAT<PatFrag OpNode, PatLeaf Mode, Instruction set_64ir> { // i16 -> pred def : Pat<(i1 (OpNode i16:$a, i16:$b)), - (setp_16rr Int16Regs:$a, Int16Regs:$b, Mode)>; + (setp_16rr $a, $b, Mode)>; def : Pat<(i1 (OpNode i16:$a, imm:$b)), - (setp_16ri Int16Regs:$a, imm:$b, Mode)>; + (setp_16ri $a, imm:$b, Mode)>; def : Pat<(i1 (OpNode imm:$a, i16:$b)), - (setp_16ir imm:$a, Int16Regs:$b, Mode)>; + (setp_16ir 
imm:$a, $b, Mode)>; // i32 -> pred def : Pat<(i1 (OpNode i32:$a, i32:$b)), - (setp_32rr Int32Regs:$a, Int32Regs:$b, Mode)>; + (setp_32rr $a, $b, Mode)>; def : Pat<(i1 (OpNode i32:$a, imm:$b)), - (setp_32ri Int32Regs:$a, imm:$b, Mode)>; + (setp_32ri $a, imm:$b, Mode)>; def : Pat<(i1 (OpNode imm:$a, i32:$b)), - (setp_32ir imm:$a, Int32Regs:$b, Mode)>; + (setp_32ir imm:$a, $b, Mode)>; // i64 -> pred def : Pat<(i1 (OpNode i64:$a, i64:$b)), - (setp_64rr Int64Regs:$a, Int64Regs:$b, Mode)>; + (setp_64rr $a, $b, Mode)>; def : Pat<(i1 (OpNode i64:$a, imm:$b)), - (setp_64ri Int64Regs:$a, imm:$b, Mode)>; + (setp_64ri $a, imm:$b, Mode)>; def : Pat<(i1 (OpNode imm:$a, i64:$b)), - (setp_64ir imm:$a, Int64Regs:$b, Mode)>; + (setp_64ir imm:$a, $b, Mode)>; // i16 -> i32 def : Pat<(i32 (OpNode i16:$a, i16:$b)), - (set_16rr Int16Regs:$a, Int16Regs:$b, Mode)>; + (set_16rr $a, $b, Mode)>; def : Pat<(i32 (OpNode i16:$a, imm:$b)), - (set_16ri Int16Regs:$a, imm:$b, Mode)>; + (set_16ri $a, imm:$b, Mode)>; def : Pat<(i32 (OpNode imm:$a, i16:$b)), - (set_16ir imm:$a, Int16Regs:$b, Mode)>; + (set_16ir imm:$a, $b, Mode)>; // i32 -> i32 def : Pat<(i32 (OpNode i32:$a, i32:$b)), - (set_32rr Int32Regs:$a, Int32Regs:$b, Mode)>; + (set_32rr $a, $b, Mode)>; def : Pat<(i32 (OpNode i32:$a, imm:$b)), - (set_32ri Int32Regs:$a, imm:$b, Mode)>; + (set_32ri $a, imm:$b, Mode)>; def : Pat<(i32 (OpNode imm:$a, i32:$b)), - (set_32ir imm:$a, Int32Regs:$b, Mode)>; + (set_32ir imm:$a, $b, Mode)>; // i64 -> i32 def : Pat<(i32 (OpNode i64:$a, Int64Regs:$b)), - (set_64rr Int64Regs:$a, Int64Regs:$b, Mode)>; + (set_64rr $a, $b, Mode)>; def : Pat<(i32 (OpNode i64:$a, imm:$b)), - (set_64ri Int64Regs:$a, imm:$b, Mode)>; + (set_64ri $a, imm:$b, Mode)>; def : Pat<(i32 (OpNode imm:$a, i64:$b)), - (set_64ir imm:$a, Int64Regs:$b, Mode)>; + (set_64ir imm:$a, $b, Mode)>; } multiclass ISET_FORMAT_SIGNED<PatFrag OpNode, PatLeaf Mode> @@ -2179,94 +2179,94 @@ def: Pat<(setne (i16 (and (trunc (bfe Int32Regs:$a, imm:$oa, 8)), 255)), // i1 compare -> i32 def : Pat<(i32 (setne i1:$a, i1:$b)), - (SELP_u32ii -1, 0, (XORb1rr Int1Regs:$a, Int1Regs:$b))>; + (SELP_u32ii -1, 0, (XORb1rr $a, $b))>; def : Pat<(i32 (setne i1:$a, i1:$b)), - (SELP_u32ii 0, -1, (XORb1rr Int1Regs:$a, Int1Regs:$b))>; + (SELP_u32ii 0, -1, (XORb1rr $a, $b))>; multiclass FSET_FORMAT<PatFrag OpNode, PatLeaf Mode, PatLeaf ModeFTZ> { // f16 -> pred def : Pat<(i1 (OpNode f16:$a, f16:$b)), - (SETP_f16rr Int16Regs:$a, Int16Regs:$b, ModeFTZ)>, + (SETP_f16rr $a, $b, ModeFTZ)>, Requires<[useFP16Math,doF32FTZ]>; def : Pat<(i1 (OpNode f16:$a, f16:$b)), - (SETP_f16rr Int16Regs:$a, Int16Regs:$b, Mode)>, + (SETP_f16rr $a, $b, Mode)>, Requires<[useFP16Math]>; // bf16 -> pred def : Pat<(i1 (OpNode bf16:$a, bf16:$b)), - (SETP_bf16rr Int16Regs:$a, Int16Regs:$b, ModeFTZ)>, + (SETP_bf16rr $a, $b, ModeFTZ)>, Requires<[hasBF16Math,doF32FTZ]>; def : Pat<(i1 (OpNode bf16:$a, bf16:$b)), - (SETP_bf16rr Int16Regs:$a, Int16Regs:$b, Mode)>, + (SETP_bf16rr $a, $b, Mode)>, Requires<[hasBF16Math]>; // f32 -> pred def : Pat<(i1 (OpNode f32:$a, f32:$b)), - (SETP_f32rr Float32Regs:$a, Float32Regs:$b, ModeFTZ)>, + (SETP_f32rr $a, $b, ModeFTZ)>, Requires<[doF32FTZ]>; def : Pat<(i1 (OpNode f32:$a, f32:$b)), - (SETP_f32rr Float32Regs:$a, Float32Regs:$b, Mode)>; + (SETP_f32rr $a, $b, Mode)>; def : Pat<(i1 (OpNode Float32Regs:$a, fpimm:$b)), - (SETP_f32ri Float32Regs:$a, fpimm:$b, ModeFTZ)>, + (SETP_f32ri $a, fpimm:$b, ModeFTZ)>, Requires<[doF32FTZ]>; def : Pat<(i1 (OpNode f32:$a, fpimm:$b)), - (SETP_f32ri Float32Regs:$a, fpimm:$b, 
Mode)>; + (SETP_f32ri $a, fpimm:$b, Mode)>; def : Pat<(i1 (OpNode fpimm:$a, f32:$b)), - (SETP_f32ir fpimm:$a, Float32Regs:$b, ModeFTZ)>, + (SETP_f32ir fpimm:$a, $b, ModeFTZ)>, Requires<[doF32FTZ]>; def : Pat<(i1 (OpNode fpimm:$a, f32:$b)), - (SETP_f32ir fpimm:$a, Float32Regs:$b, Mode)>; + (SETP_f32ir fpimm:$a, $b, Mode)>; // f64 -> pred def : Pat<(i1 (OpNode f64:$a, f64:$b)), - (SETP_f64rr Float64Regs:$a, Float64Regs:$b, Mode)>; + (SETP_f64rr $a, $b, Mode)>; def : Pat<(i1 (OpNode f64:$a, fpimm:$b)), - (SETP_f64ri Float64Regs:$a, fpimm:$b, Mode)>; + (SETP_f64ri $a, fpimm:$b, Mode)>; def : Pat<(i1 (OpNode fpimm:$a, f64:$b)), - (SETP_f64ir fpimm:$a, Float64Regs:$b, Mode)>; + (SETP_f64ir fpimm:$a, $b, Mode)>; // f16 -> i32 def : Pat<(i32 (OpNode f16:$a, f16:$b)), - (SET_f16rr Int16Regs:$a, Int16Regs:$b, ModeFTZ)>, + (SET_f16rr $a, $b, ModeFTZ)>, Requires<[useFP16Math, doF32FTZ]>; def : Pat<(i32 (OpNode f16:$a, f16:$b)), - (SET_f16rr Int16Regs:$a, Int16Regs:$b, Mode)>, + (SET_f16rr $a, $b, Mode)>, Requires<[useFP16Math]>; // bf16 -> i32 def : Pat<(i32 (OpNode bf16:$a, bf16:$b)), - (SET_bf16rr Int16Regs:$a, Int16Regs:$b, ModeFTZ)>, + (SET_bf16rr $a, $b, ModeFTZ)>, Requires<[hasBF16Math, doF32FTZ]>; def : Pat<(i32 (OpNode bf16:$a, bf16:$b)), - (SET_bf16rr Int16Regs:$a, Int16Regs:$b, Mode)>, + (SET_bf16rr $a, $b, Mode)>, Requires<[hasBF16Math]>; // f32 -> i32 def : Pat<(i32 (OpNode f32:$a, f32:$b)), - (SET_f32rr Float32Regs:$a, Float32Regs:$b, ModeFTZ)>, + (SET_f32rr $a, $b, ModeFTZ)>, Requires<[doF32FTZ]>; def : Pat<(i32 (OpNode f32:$a, f32:$b)), - (SET_f32rr Float32Regs:$a, Float32Regs:$b, Mode)>; + (SET_f32rr $a, $b, Mode)>; def : Pat<(i32 (OpNode f32:$a, fpimm:$b)), - (SET_f32ri Float32Regs:$a, fpimm:$b, ModeFTZ)>, + (SET_f32ri $a, fpimm:$b, ModeFTZ)>, Requires<[doF32FTZ]>; def : Pat<(i32 (OpNode f32:$a, fpimm:$b)), - (SET_f32ri Float32Regs:$a, fpimm:$b, Mode)>; + (SET_f32ri $a, fpimm:$b, Mode)>; def : Pat<(i32 (OpNode fpimm:$a, f32:$b)), - (SET_f32ir fpimm:$a, Float32Regs:$b, ModeFTZ)>, + (SET_f32ir fpimm:$a, $b, ModeFTZ)>, Requires<[doF32FTZ]>; def : Pat<(i32 (OpNode fpimm:$a, f32:$b)), - (SET_f32ir fpimm:$a, Float32Regs:$b, Mode)>; + (SET_f32ir fpimm:$a, $b, Mode)>; // f64 -> i32 def : Pat<(i32 (OpNode f64:$a, f64:$b)), - (SET_f64rr Float64Regs:$a, Float64Regs:$b, Mode)>; + (SET_f64rr $a, $b, Mode)>; def : Pat<(i32 (OpNode f64:$a, fpimm:$b)), - (SET_f64ri Float64Regs:$a, fpimm:$b, Mode)>; + (SET_f64ri $a, fpimm:$b, Mode)>; def : Pat<(i32 (OpNode fpimm:$a, f64:$b)), - (SET_f64ir fpimm:$a, Float64Regs:$b, Mode)>; + (SET_f64ir fpimm:$a, $b, Mode)>; } defm FSetOGT : FSET_FORMAT<setogt, CmpGT, CmpGT_FTZ>; @@ -2722,11 +2722,11 @@ def ProxyRegF32 : ProxyRegInst<"f32", f32, Float32Regs>; def ProxyRegF64 : ProxyRegInst<"f64", f64, Float64Regs>; foreach vt = [f16, bf16] in { - def: Pat<(vt (ProxyReg vt:$src)), (ProxyRegI16 Int16Regs:$src)>; + def: Pat<(vt (ProxyReg vt:$src)), (ProxyRegI16 $src)>; } foreach vt = [v2f16, v2bf16, v2i16, v4i8] in { - def: Pat<(vt (ProxyReg vt:$src)), (ProxyRegI32 Int32Regs:$src)>; + def: Pat<(vt (ProxyReg vt:$src)), (ProxyRegI32 $src)>; } // @@ -3029,9 +3029,9 @@ def BITCONVERT_64_F2I : F_BITCONVERT<"64", f64, i64>; foreach vt = [v2f16, v2bf16, v2i16, v4i8] in { def: Pat<(vt (bitconvert (f32 Float32Regs:$a))), - (BITCONVERT_32_F2I Float32Regs:$a)>; + (BITCONVERT_32_F2I $a)>; def: Pat<(f32 (bitconvert vt:$a)), - (BITCONVERT_32_I2F Int32Regs:$a)>; + (BITCONVERT_32_I2F $a)>; } foreach vt = [f16, bf16] in { def: Pat<(vt (bitconvert i16:$a)), @@ -3056,280 +3056,280 @@ foreach 
ta = [v2f16, v2bf16, v2i16, v4i8, i32] in { // sint -> f16 def : Pat<(f16 (sint_to_fp i1:$a)), - (CVT_f16_s32 (SELP_s32ii -1, 0, Int1Regs:$a), CvtRN)>; + (CVT_f16_s32 (SELP_s32ii -1, 0, $a), CvtRN)>; def : Pat<(f16 (sint_to_fp Int16Regs:$a)), - (CVT_f16_s16 i16:$a, CvtRN)>; + (CVT_f16_s16 $a, CvtRN)>; def : Pat<(f16 (sint_to_fp i32:$a)), - (CVT_f16_s32 i32:$a, CvtRN)>; + (CVT_f16_s32 $a, CvtRN)>; def : Pat<(f16 (sint_to_fp i64:$a)), - (CVT_f16_s64 i64:$a, CvtRN)>; + (CVT_f16_s64 $a, CvtRN)>; // uint -> f16 def : Pat<(f16 (uint_to_fp i1:$a)), - (CVT_f16_u32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>; + (CVT_f16_u32 (SELP_u32ii 1, 0, $a), CvtRN)>; def : Pat<(f16 (uint_to_fp Int16Regs:$a)), - (CVT_f16_u16 i16:$a, CvtRN)>; + (CVT_f16_u16 $a, CvtRN)>; def : Pat<(f16 (uint_to_fp i32:$a)), - (CVT_f16_u32 i32:$a, CvtRN)>; + (CVT_f16_u32 $a, CvtRN)>; def : Pat<(f16 (uint_to_fp i64:$a)), - (CVT_f16_u64 i64:$a, CvtRN)>; + (CVT_f16_u64 $a, CvtRN)>; // sint -> bf16 def : Pat<(bf16 (sint_to_fp i1:$a)), - (CVT_bf16_s32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>; + (CVT_bf16_s32 (SELP_u32ii 1, 0, $a), CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>; def : Pat<(bf16 (sint_to_fp i16:$a)), - (CVT_bf16_s16 i16:$a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>; + (CVT_bf16_s16 $a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>; def : Pat<(bf16 (sint_to_fp i32:$a)), - (CVT_bf16_s32 i32:$a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>; + (CVT_bf16_s32 $a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>; def : Pat<(bf16 (sint_to_fp i64:$a)), - (CVT_bf16_s64 i64:$a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>; + (CVT_bf16_s64 $a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>; // uint -> bf16 def : Pat<(bf16 (uint_to_fp i1:$a)), - (CVT_bf16_u32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>; + (CVT_bf16_u32 (SELP_u32ii 1, 0, $a), CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>; def : Pat<(bf16 (uint_to_fp i16:$a)), - (CVT_bf16_u16 i16:$a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>; + (CVT_bf16_u16 $a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>; def : Pat<(bf16 (uint_to_fp i32:$a)), - (CVT_bf16_u32 i32:$a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>; + (CVT_bf16_u32 $a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>; def : Pat<(bf16 (uint_to_fp i64:$a)), - (CVT_bf16_u64 i64:$a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>; + (CVT_bf16_u64 $a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>; // sint -> f32 def : Pat<(f32 (sint_to_fp i1:$a)), - (CVT_f32_s32 (SELP_s32ii -1, 0, Int1Regs:$a), CvtRN)>; + (CVT_f32_s32 (SELP_s32ii -1, 0, $a), CvtRN)>; def : Pat<(f32 (sint_to_fp i16:$a)), - (CVT_f32_s16 i16:$a, CvtRN)>; + (CVT_f32_s16 $a, CvtRN)>; def : Pat<(f32 (sint_to_fp i32:$a)), - (CVT_f32_s32 i32:$a, CvtRN)>; + (CVT_f32_s32 $a, CvtRN)>; def : Pat<(f32 (sint_to_fp i64:$a)), - (CVT_f32_s64 i64:$a, CvtRN)>; + (CVT_f32_s64 $a, CvtRN)>; // uint -> f32 def : Pat<(f32 (uint_to_fp i1:$a)), - (CVT_f32_u32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>; + (CVT_f32_u32 (SELP_u32ii 1, 0, $a), CvtRN)>; def : Pat<(f32 (uint_to_fp i16:$a)), - (CVT_f32_u16 Int16Regs:$a, CvtRN)>; + (CVT_f32_u16 $a, CvtRN)>; def : Pat<(f32 (uint_to_fp i32:$a)), - (CVT_f32_u32 i32:$a, CvtRN)>; + (CVT_f32_u32 $a, CvtRN)>; def : Pat<(f32 (uint_to_fp i64:$a)), - (CVT_f32_u64 i64:$a, CvtRN)>; + (CVT_f32_u64 $a, CvtRN)>; // sint -> f64 def : Pat<(f64 (sint_to_fp i1:$a)), - (CVT_f64_s32 (SELP_s32ii -1, 0, Int1Regs:$a), CvtRN)>; + (CVT_f64_s32 (SELP_s32ii -1, 0, $a), CvtRN)>; def : Pat<(f64 (sint_to_fp i16:$a)), - (CVT_f64_s16 Int16Regs:$a, CvtRN)>; + 
(CVT_f64_s16 $a, CvtRN)>; def : Pat<(f64 (sint_to_fp i32:$a)), - (CVT_f64_s32 i32:$a, CvtRN)>; + (CVT_f64_s32 $a, CvtRN)>; def : Pat<(f64 (sint_to_fp i64:$a)), - (CVT_f64_s64 i64:$a, CvtRN)>; + (CVT_f64_s64 $a, CvtRN)>; // uint -> f64 def : Pat<(f64 (uint_to_fp i1:$a)), - (CVT_f64_u32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>; + (CVT_f64_u32 (SELP_u32ii 1, 0, $a), CvtRN)>; def : Pat<(f64 (uint_to_fp i16:$a)), - (CVT_f64_u16 Int16Regs:$a, CvtRN)>; + (CVT_f64_u16 $a, CvtRN)>; def : Pat<(f64 (uint_to_fp i32:$a)), - (CVT_f64_u32 i32:$a, CvtRN)>; + (CVT_f64_u32 $a, CvtRN)>; def : Pat<(f64 (uint_to_fp i64:$a)), - (CVT_f64_u64 i64:$a, CvtRN)>; + (CVT_f64_u64 $a, CvtRN)>; // f16 -> sint def : Pat<(i1 (fp_to_sint f16:$a)), - (SETP_b16ri Int16Regs:$a, 0, CmpEQ)>; + (SETP_b16ri $a, 0, CmpEQ)>; def : Pat<(i16 (fp_to_sint f16:$a)), - (CVT_s16_f16 Int16Regs:$a, CvtRZI)>; + (CVT_s16_f16 $a, CvtRZI)>; def : Pat<(i32 (fp_to_sint f16:$a)), - (CVT_s32_f16 Int16Regs:$a, CvtRZI)>; + (CVT_s32_f16 $a, CvtRZI)>; def : Pat<(i64 (fp_to_sint f16:$a)), - (CVT_s64_f16 Int16Regs:$a, CvtRZI)>; + (CVT_s64_f16 $a, CvtRZI)>; // f16 -> uint def : Pat<(i1 (fp_to_uint f16:$a)), - (SETP_b16ri Int16Regs:$a, 0, CmpEQ)>; + (SETP_b16ri $a, 0, CmpEQ)>; def : Pat<(i16 (fp_to_uint f16:$a)), - (CVT_u16_f16 Int16Regs:$a, CvtRZI)>; + (CVT_u16_f16 $a, CvtRZI)>; def : Pat<(i32 (fp_to_uint f16:$a)), - (CVT_u32_f16 Int16Regs:$a, CvtRZI)>; + (CVT_u32_f16 $a, CvtRZI)>; def : Pat<(i64 (fp_to_uint f16:$a)), - (CVT_u64_f16 Int16Regs:$a, CvtRZI)>; + (CVT_u64_f16 $a, CvtRZI)>; // bf16 -> sint def : Pat<(i1 (fp_to_sint bf16:$a)), - (SETP_b16ri Int16Regs:$a, 0, CmpEQ)>; + (SETP_b16ri $a, 0, CmpEQ)>; def : Pat<(i16 (fp_to_sint bf16:$a)), - (CVT_s16_bf16 Int16Regs:$a, CvtRZI)>; + (CVT_s16_bf16 $a, CvtRZI)>; def : Pat<(i32 (fp_to_sint bf16:$a)), - (CVT_s32_bf16 Int16Regs:$a, CvtRZI)>; + (CVT_s32_bf16 $a, CvtRZI)>; def : Pat<(i64 (fp_to_sint bf16:$a)), - (CVT_s64_bf16 Int16Regs:$a, CvtRZI)>; + (CVT_s64_bf16 $a, CvtRZI)>; // bf16 -> uint def : Pat<(i1 (fp_to_uint bf16:$a)), - (SETP_b16ri Int16Regs:$a, 0, CmpEQ)>; + (SETP_b16ri $a, 0, CmpEQ)>; def : Pat<(i16 (fp_to_uint bf16:$a)), - (CVT_u16_bf16 Int16Regs:$a, CvtRZI)>; + (CVT_u16_bf16 $a, CvtRZI)>; def : Pat<(i32 (fp_to_uint bf16:$a)), - (CVT_u32_bf16 Int16Regs:$a, CvtRZI)>; + (CVT_u32_bf16 $a, CvtRZI)>; def : Pat<(i64 (fp_to_uint bf16:$a)), - (CVT_u64_bf16 Int16Regs:$a, CvtRZI)>; + (CVT_u64_bf16 $a, CvtRZI)>; // f32 -> sint def : Pat<(i1 (fp_to_sint f32:$a)), - (SETP_b32ri (BITCONVERT_32_F2I Float32Regs:$a), 0, CmpEQ)>; + (SETP_b32ri (BITCONVERT_32_F2I $a), 0, CmpEQ)>; def : Pat<(i16 (fp_to_sint f32:$a)), - (CVT_s16_f32 Float32Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>; + (CVT_s16_f32 $a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>; def : Pat<(i16 (fp_to_sint f32:$a)), - (CVT_s16_f32 Float32Regs:$a, CvtRZI)>; + (CVT_s16_f32 $a, CvtRZI)>; def : Pat<(i32 (fp_to_sint f32:$a)), - (CVT_s32_f32 Float32Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>; + (CVT_s32_f32 $a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>; def : Pat<(i32 (fp_to_sint f32:$a)), - (CVT_s32_f32 Float32Regs:$a, CvtRZI)>; + (CVT_s32_f32 $a, CvtRZI)>; def : Pat<(i64 (fp_to_sint f32:$a)), - (CVT_s64_f32 Float32Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>; + (CVT_s64_f32 $a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>; def : Pat<(i64 (fp_to_sint f32:$a)), - (CVT_s64_f32 Float32Regs:$a, CvtRZI)>; + (CVT_s64_f32 $a, CvtRZI)>; // f32 -> uint def : Pat<(i1 (fp_to_uint f32:$a)), - (SETP_b32ri (BITCONVERT_32_F2I Float32Regs:$a), 0, CmpEQ)>; + (SETP_b32ri (BITCONVERT_32_F2I $a), 0, 
CmpEQ)>; def : Pat<(i16 (fp_to_uint f32:$a)), - (CVT_u16_f32 Float32Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>; + (CVT_u16_f32 $a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>; def : Pat<(i16 (fp_to_uint f32:$a)), - (CVT_u16_f32 Float32Regs:$a, CvtRZI)>; + (CVT_u16_f32 $a, CvtRZI)>; def : Pat<(i32 (fp_to_uint f32:$a)), - (CVT_u32_f32 Float32Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>; + (CVT_u32_f32 $a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>; def : Pat<(i32 (fp_to_uint f32:$a)), - (CVT_u32_f32 Float32Regs:$a, CvtRZI)>; + (CVT_u32_f32 $a, CvtRZI)>; def : Pat<(i64 (fp_to_uint f32:$a)), - (CVT_u64_f32 Float32Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>; + (CVT_u64_f32 $a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>; def : Pat<(i64 (fp_to_uint f32:$a)), - (CVT_u64_f32 Float32Regs:$a, CvtRZI)>; + (CVT_u64_f32 $a, CvtRZI)>; // f64 -> sint def : Pat<(i1 (fp_to_sint f64:$a)), - (SETP_b64ri (BITCONVERT_64_F2I Float64Regs:$a), 0, CmpEQ)>; + (SETP_b64ri (BITCONVERT_64_F2I $a), 0, CmpEQ)>; def : Pat<(i16 (fp_to_sint f64:$a)), - (CVT_s16_f64 Float64Regs:$a, CvtRZI)>; + (CVT_s16_f64 $a, CvtRZI)>; def : Pat<(i32 (fp_to_sint f64:$a)), - (CVT_s32_f64 Float64Regs:$a, CvtRZI)>; + (CVT_s32_f64 $a, CvtRZI)>; def : Pat<(i64 (fp_to_sint f64:$a)), - (CVT_s64_f64 Float64Regs:$a, CvtRZI)>; + (CVT_s64_f64 $a, CvtRZI)>; // f64 -> uint def : Pat<(i1 (fp_to_uint f64:$a)), - (SETP_b64ri (BITCONVERT_64_F2I Float64Regs:$a), 0, CmpEQ)>; + (SETP_b64ri (BITCONVERT_64_F2I $a), 0, CmpEQ)>; def : Pat<(i16 (fp_to_uint f64:$a)), - (CVT_u16_f64 Float64Regs:$a, CvtRZI)>; + (CVT_u16_f64 $a, CvtRZI)>; def : Pat<(i32 (fp_to_uint f64:$a)), - (CVT_u32_f64 Float64Regs:$a, CvtRZI)>; + (CVT_u32_f64 $a, CvtRZI)>; def : Pat<(i64 (fp_to_uint f64:$a)), - (CVT_u64_f64 Float64Regs:$a, CvtRZI)>; + (CVT_u64_f64 $a, CvtRZI)>; // sext i1 def : Pat<(i16 (sext i1:$a)), - (SELP_s16ii -1, 0, Int1Regs:$a)>; + (SELP_s16ii -1, 0, $a)>; def : Pat<(i32 (sext i1:$a)), - (SELP_s32ii -1, 0, Int1Regs:$a)>; + (SELP_s32ii -1, 0, $a)>; def : Pat<(i64 (sext i1:$a)), - (SELP_s64ii -1, 0, Int1Regs:$a)>; + (SELP_s64ii -1, 0, $a)>; // zext i1 def : Pat<(i16 (zext i1:$a)), - (SELP_u16ii 1, 0, Int1Regs:$a)>; + (SELP_u16ii 1, 0, $a)>; def : Pat<(i32 (zext i1:$a)), - (SELP_u32ii 1, 0, Int1Regs:$a)>; + (SELP_u32ii 1, 0, $a)>; def : Pat<(i64 (zext i1:$a)), - (SELP_u64ii 1, 0, Int1Regs:$a)>; + (SELP_u64ii 1, 0, $a)>; // anyext i1 def : Pat<(i16 (anyext i1:$a)), - (SELP_u16ii -1, 0, Int1Regs:$a)>; + (SELP_u16ii -1, 0, $a)>; def : Pat<(i32 (anyext i1:$a)), - (SELP_u32ii -1, 0, Int1Regs:$a)>; + (SELP_u32ii -1, 0, $a)>; def : Pat<(i64 (anyext i1:$a)), - (SELP_u64ii -1, 0, Int1Regs:$a)>; + (SELP_u64ii -1, 0, $a)>; // sext i16 def : Pat<(i32 (sext i16:$a)), - (CVT_s32_s16 Int16Regs:$a, CvtNONE)>; + (CVT_s32_s16 $a, CvtNONE)>; def : Pat<(i64 (sext i16:$a)), - (CVT_s64_s16 Int16Regs:$a, CvtNONE)>; + (CVT_s64_s16 $a, CvtNONE)>; // zext i16 def : Pat<(i32 (zext i16:$a)), - (CVT_u32_u16 Int16Regs:$a, CvtNONE)>; + (CVT_u32_u16 $a, CvtNONE)>; def : Pat<(i64 (zext i16:$a)), - (CVT_u64_u16 Int16Regs:$a, CvtNONE)>; + (CVT_u64_u16 $a, CvtNONE)>; // anyext i16 def : Pat<(i32 (anyext i16:$a)), - (CVT_u32_u16 Int16Regs:$a, CvtNONE)>; + (CVT_u32_u16 $a, CvtNONE)>; def : Pat<(i64 (anyext i16:$a)), - (CVT_u64_u16 Int16Regs:$a, CvtNONE)>; + (CVT_u64_u16 $a, CvtNONE)>; // sext i32 def : Pat<(i64 (sext i32:$a)), - (CVT_s64_s32 Int32Regs:$a, CvtNONE)>; + (CVT_s64_s32 $a, CvtNONE)>; // zext i32 def : Pat<(i64 (zext i32:$a)), - (CVT_u64_u32 Int32Regs:$a, CvtNONE)>; + (CVT_u64_u32 $a, CvtNONE)>; // anyext i32 def : Pat<(i64 
(anyext i32:$a)), - (CVT_u64_u32 Int32Regs:$a, CvtNONE)>; + (CVT_u64_u32 $a, CvtNONE)>; // truncate i64 def : Pat<(i32 (trunc i64:$a)), - (CVT_u32_u64 Int64Regs:$a, CvtNONE)>; + (CVT_u32_u64 $a, CvtNONE)>; def : Pat<(i16 (trunc i64:$a)), - (CVT_u16_u64 Int64Regs:$a, CvtNONE)>; + (CVT_u16_u64 $a, CvtNONE)>; def : Pat<(i1 (trunc i64:$a)), - (SETP_b64ri (ANDb64ri Int64Regs:$a, 1), 1, CmpEQ)>; + (SETP_b64ri (ANDb64ri $a, 1), 1, CmpEQ)>; // truncate i32 def : Pat<(i16 (trunc i32:$a)), - (CVT_u16_u32 Int32Regs:$a, CvtNONE)>; + (CVT_u16_u32 $a, CvtNONE)>; def : Pat<(i1 (trunc i32:$a)), - (SETP_b32ri (ANDb32ri Int32Regs:$a, 1), 1, CmpEQ)>; + (SETP_b32ri (ANDb32ri $a, 1), 1, CmpEQ)>; // truncate i16 def : Pat<(i1 (trunc i16:$a)), - (SETP_b16ri (ANDb16ri Int16Regs:$a, 1), 1, CmpEQ)>; + (SETP_b16ri (ANDb16ri $a, 1), 1, CmpEQ)>; // sext_inreg -def : Pat<(sext_inreg i16:$a, i8), (CVT_INREG_s16_s8 Int16Regs:$a)>; -def : Pat<(sext_inreg i32:$a, i8), (CVT_INREG_s32_s8 Int32Regs:$a)>; -def : Pat<(sext_inreg i32:$a, i16), (CVT_INREG_s32_s16 Int32Regs:$a)>; -def : Pat<(sext_inreg i64:$a, i8), (CVT_INREG_s64_s8 Int64Regs:$a)>; -def : Pat<(sext_inreg i64:$a, i16), (CVT_INREG_s64_s16 Int64Regs:$a)>; -def : Pat<(sext_inreg i64:$a, i32), (CVT_INREG_s64_s32 Int64Regs:$a)>; +def : Pat<(sext_inreg i16:$a, i8), (CVT_INREG_s16_s8 $a)>; +def : Pat<(sext_inreg i32:$a, i8), (CVT_INREG_s32_s8 $a)>; +def : Pat<(sext_inreg i32:$a, i16), (CVT_INREG_s32_s16 $a)>; +def : Pat<(sext_inreg i64:$a, i8), (CVT_INREG_s64_s8 $a)>; +def : Pat<(sext_inreg i64:$a, i16), (CVT_INREG_s64_s16 $a)>; +def : Pat<(sext_inreg i64:$a, i32), (CVT_INREG_s64_s32 $a)>; // Select instructions with 32-bit predicates def : Pat<(select i32:$pred, i16:$a, i16:$b), - (SELP_b16rr Int16Regs:$a, Int16Regs:$b, - (SETP_b32ri (ANDb32ri Int32Regs:$pred, 1), 1, CmpEQ))>; + (SELP_b16rr $a, $b, + (SETP_b32ri (ANDb32ri $pred, 1), 1, CmpEQ))>; def : Pat<(select i32:$pred, i32:$a, i32:$b), - (SELP_b32rr Int32Regs:$a, Int32Regs:$b, - (SETP_b32ri (ANDb32ri Int32Regs:$pred, 1), 1, CmpEQ))>; + (SELP_b32rr $a, $b, + (SETP_b32ri (ANDb32ri $pred, 1), 1, CmpEQ))>; def : Pat<(select i32:$pred, i64:$a, i64:$b), - (SELP_b64rr Int64Regs:$a, Int64Regs:$b, - (SETP_b32ri (ANDb32ri Int32Regs:$pred, 1), 1, CmpEQ))>; + (SELP_b64rr $a, $b, + (SETP_b32ri (ANDb32ri $pred, 1), 1, CmpEQ))>; def : Pat<(select i32:$pred, f16:$a, f16:$b), - (SELP_f16rr Int16Regs:$a, Int16Regs:$b, - (SETP_b32ri (ANDb32ri Int32Regs:$pred, 1), 1, CmpEQ))>; + (SELP_f16rr $a, $b, + (SETP_b32ri (ANDb32ri $pred, 1), 1, CmpEQ))>; def : Pat<(select i32:$pred, bf16:$a, bf16:$b), - (SELP_bf16rr Int16Regs:$a, Int16Regs:$b, - (SETP_b32ri (ANDb32ri Int32Regs:$pred, 1), 1, CmpEQ))>; + (SELP_bf16rr $a, $b, + (SETP_b32ri (ANDb32ri $pred, 1), 1, CmpEQ))>; def : Pat<(select i32:$pred, f32:$a, f32:$b), - (SELP_f32rr Float32Regs:$a, Float32Regs:$b, - (SETP_b32ri (ANDb32ri Int32Regs:$pred, 1), 1, CmpEQ))>; + (SELP_f32rr $a, $b, + (SETP_b32ri (ANDb32ri $pred, 1), 1, CmpEQ))>; def : Pat<(select i32:$pred, f64:$a, f64:$b), - (SELP_f64rr Float64Regs:$a, Float64Regs:$b, - (SETP_b32ri (ANDb32ri Int32Regs:$pred, 1), 1, CmpEQ))>; + (SELP_f64rr $a, $b, + (SETP_b32ri (ANDb32ri $pred, 1), 1, CmpEQ))>; let hasSideEffects = false in { @@ -3391,32 +3391,32 @@ let hasSideEffects = false in { // Using partial vectorized move produces better SASS code for extraction of // upper/lower parts of an integer. 
def : Pat<(i16 (trunc (srl i32:$s, (i32 16)))), - (I32toI16H Int32Regs:$s)>; + (I32toI16H $s)>; def : Pat<(i16 (trunc (sra i32:$s, (i32 16)))), - (I32toI16H Int32Regs:$s)>; + (I32toI16H $s)>; def : Pat<(i32 (trunc (srl i64:$s, (i32 32)))), - (I64toI32H Int64Regs:$s)>; + (I64toI32H $s)>; def : Pat<(i32 (trunc (sra i64:$s, (i32 32)))), - (I64toI32H Int64Regs:$s)>; + (I64toI32H $s)>; def: Pat<(i32 (sext (extractelt v2i16:$src, 0))), - (CVT_INREG_s32_s16 Int32Regs:$src)>; + (CVT_INREG_s32_s16 $src)>; foreach vt = [v2f16, v2bf16, v2i16] in { def : Pat<(extractelt vt:$src, 0), - (I32toI16L Int32Regs:$src)>; + (I32toI16L $src)>; def : Pat<(extractelt vt:$src, 1), - (I32toI16H Int32Regs:$src)>; + (I32toI16H $src)>; } def : Pat<(v2f16 (build_vector f16:$a, f16:$b)), - (V2I16toI32 Int16Regs:$a, Int16Regs:$b)>; + (V2I16toI32 $a, $b)>; def : Pat<(v2bf16 (build_vector bf16:$a, bf16:$b)), - (V2I16toI32 Int16Regs:$a, Int16Regs:$b)>; + (V2I16toI32 $a, $b)>; def : Pat<(v2i16 (build_vector i16:$a, i16:$b)), - (V2I16toI32 Int16Regs:$a, Int16Regs:$b)>; + (V2I16toI32 $a, $b)>; def: Pat<(v2i16 (scalar_to_vector i16:$a)), - (CVT_u32_u16 Int16Regs:$a, CvtNONE)>; + (CVT_u32_u16 $a, CvtNONE)>; // // Funnel-Shift @@ -3455,13 +3455,13 @@ let hasSideEffects = false in { } def : Pat<(i32 (int_nvvm_fshl_clamp i32:$hi, i32:$lo, i32:$amt)), - (SHF_L_CLAMP_r Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt)>; + (SHF_L_CLAMP_r $lo, $hi, $amt)>; def : Pat<(i32 (int_nvvm_fshl_clamp i32:$hi, i32:$lo, (i32 imm:$amt))), - (SHF_L_CLAMP_i Int32Regs:$lo, Int32Regs:$hi, imm:$amt)>; + (SHF_L_CLAMP_i $lo, $hi, imm:$amt)>; def : Pat<(i32 (int_nvvm_fshr_clamp i32:$hi, i32:$lo, i32:$amt)), - (SHF_R_CLAMP_r Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt)>; + (SHF_R_CLAMP_r $lo, $hi, $amt)>; def : Pat<(i32 (int_nvvm_fshr_clamp i32:$hi, i32:$lo, (i32 imm:$amt))), - (SHF_R_CLAMP_i Int32Regs:$lo, Int32Regs:$hi, imm:$amt)>; + (SHF_R_CLAMP_i $lo, $hi, imm:$amt)>; // Count leading zeros let hasSideEffects = false in { @@ -3472,14 +3472,14 @@ let hasSideEffects = false in { } // 32-bit has a direct PTX instruction -def : Pat<(i32 (ctlz i32:$a)), (CLZr32 i32:$a)>; +def : Pat<(i32 (ctlz i32:$a)), (CLZr32 $a)>; // The return type of the ctlz ISD node is the same as its input, but the PTX // ctz instruction always returns a 32-bit value. For ctlz.i64, convert the // ptx value to 64 bits to match the ISD node's semantics, unless we know we're // truncating back down to 32 bits. -def : Pat<(i64 (ctlz i64:$a)), (CVT_u64_u32 (CLZr64 Int64Regs:$a), CvtNONE)>; -def : Pat<(i32 (trunc (i64 (ctlz i64:$a)))), (CLZr64 Int64Regs:$a)>; +def : Pat<(i64 (ctlz i64:$a)), (CVT_u64_u32 (CLZr64 $a), CvtNONE)>; +def : Pat<(i32 (trunc (i64 (ctlz i64:$a)))), (CLZr64 $a)>; // For 16-bit ctlz, we zero-extend to 32-bit, perform the count, then trunc the // result back to 16-bits if necessary. We also need to subtract 16 because @@ -3497,9 +3497,9 @@ def : Pat<(i32 (trunc (i64 (ctlz i64:$a)))), (CLZr64 Int64Regs:$a)>; // "mov b32reg, {b16imm, b16reg}", so we don't do this optimization. 
def : Pat<(i16 (ctlz i16:$a)), (SUBi16ri (CVT_u16_u32 - (CLZr32 (CVT_u32_u16 Int16Regs:$a, CvtNONE)), CvtNONE), 16)>; + (CLZr32 (CVT_u32_u16 $a, CvtNONE)), CvtNONE), 16)>; def : Pat<(i32 (zext (i16 (ctlz i16:$a)))), - (SUBi32ri (CLZr32 (CVT_u32_u16 Int16Regs:$a, CvtNONE)), 16)>; + (SUBi32ri (CLZr32 (CVT_u32_u16 $a, CvtNONE)), 16)>; // Population count let hasSideEffects = false in { @@ -3510,67 +3510,67 @@ let hasSideEffects = false in { } // 32-bit has a direct PTX instruction -def : Pat<(i32 (ctpop i32:$a)), (POPCr32 Int32Regs:$a)>; +def : Pat<(i32 (ctpop i32:$a)), (POPCr32 $a)>; // For 64-bit, the result in PTX is actually 32-bit so we zero-extend to 64-bit // to match the LLVM semantics. Just as with ctlz.i64, we provide a second // pattern that avoids the type conversion if we're truncating the result to // i32 anyway. -def : Pat<(ctpop i64:$a), (CVT_u64_u32 (POPCr64 Int64Regs:$a), CvtNONE)>; -def : Pat<(i32 (trunc (i64 (ctpop i64:$a)))), (POPCr64 Int64Regs:$a)>; +def : Pat<(ctpop i64:$a), (CVT_u64_u32 (POPCr64 $a), CvtNONE)>; +def : Pat<(i32 (trunc (i64 (ctpop i64:$a)))), (POPCr64 $a)>; // For 16-bit, we zero-extend to 32-bit, then trunc the result back to 16-bits. // If we know that we're storing into an i32, we can avoid the final trunc. def : Pat<(ctpop i16:$a), - (CVT_u16_u32 (POPCr32 (CVT_u32_u16 Int16Regs:$a, CvtNONE)), CvtNONE)>; + (CVT_u16_u32 (POPCr32 (CVT_u32_u16 $a, CvtNONE)), CvtNONE)>; def : Pat<(i32 (zext (i16 (ctpop i16:$a)))), - (POPCr32 (CVT_u32_u16 Int16Regs:$a, CvtNONE))>; + (POPCr32 (CVT_u32_u16 $a, CvtNONE))>; // fpround f32 -> f16 def : Pat<(f16 (fpround f32:$a)), - (CVT_f16_f32 Float32Regs:$a, CvtRN)>; + (CVT_f16_f32 $a, CvtRN)>; // fpround f32 -> bf16 def : Pat<(bf16 (fpround f32:$a)), - (CVT_bf16_f32 Float32Regs:$a, CvtRN)>, Requires<[hasPTX<70>, hasSM<80>]>; + (CVT_bf16_f32 $a, CvtRN)>, Requires<[hasPTX<70>, hasSM<80>]>; // fpround f64 -> f16 def : Pat<(f16 (fpround f64:$a)), - (CVT_f16_f64 Float64Regs:$a, CvtRN)>; + (CVT_f16_f64 $a, CvtRN)>; // fpround f64 -> bf16 def : Pat<(bf16 (fpround f64:$a)), - (CVT_bf16_f64 Float64Regs:$a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>; + (CVT_bf16_f64 $a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>; // fpround f64 -> f32 def : Pat<(f32 (fpround f64:$a)), - (CVT_f32_f64 Float64Regs:$a, CvtRN_FTZ)>, Requires<[doF32FTZ]>; + (CVT_f32_f64 $a, CvtRN_FTZ)>, Requires<[doF32FTZ]>; def : Pat<(f32 (fpround f64:$a)), - (CVT_f32_f64 Float64Regs:$a, CvtRN)>; + (CVT_f32_f64 $a, CvtRN)>; // fpextend f16 -> f32 def : Pat<(f32 (fpextend f16:$a)), - (CVT_f32_f16 Int16Regs:$a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>; + (CVT_f32_f16 $a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>; def : Pat<(f32 (fpextend f16:$a)), - (CVT_f32_f16 Int16Regs:$a, CvtNONE)>; + (CVT_f32_f16 $a, CvtNONE)>; // fpextend bf16 -> f32 def : Pat<(f32 (fpextend bf16:$a)), - (CVT_f32_bf16 Int16Regs:$a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>; + (CVT_f32_bf16 $a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>; def : Pat<(f32 (fpextend bf16:$a)), - (CVT_f32_bf16 Int16Regs:$a, CvtNONE)>, Requires<[hasPTX<71>, hasSM<80>]>; + (CVT_f32_bf16 $a, CvtNONE)>, Requires<[hasPTX<71>, hasSM<80>]>; // fpextend f16 -> f64 def : Pat<(f64 (fpextend f16:$a)), - (CVT_f64_f16 Int16Regs:$a, CvtNONE)>; + (CVT_f64_f16 $a, CvtNONE)>; // fpextend bf16 -> f64 def : Pat<(f64 (fpextend bf16:$a)), - (CVT_f64_bf16 Int16Regs:$a, CvtNONE)>, Requires<[hasPTX<78>, hasSM<90>]>; + (CVT_f64_bf16 $a, CvtNONE)>, Requires<[hasPTX<78>, hasSM<90>]>; // fpextend f32 -> f64 def : Pat<(f64 (fpextend f32:$a)), - (CVT_f64_f32 Float32Regs:$a, 
CvtNONE_FTZ)>, Requires<[doF32FTZ]>; + (CVT_f64_f32 $a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>; def : Pat<(f64 (fpextend f32:$a)), - (CVT_f64_f32 Float32Regs:$a, CvtNONE)>; + (CVT_f64_f32 $a, CvtNONE)>; def retglue : SDNode<"NVPTXISD::RET_GLUE", SDTNone, [SDNPHasChain, SDNPOptInGlue]>; @@ -3579,15 +3579,15 @@ def retglue : SDNode<"NVPTXISD::RET_GLUE", SDTNone, multiclass CVT_ROUND<SDNode OpNode, PatLeaf Mode, PatLeaf ModeFTZ> { def : Pat<(OpNode f16:$a), - (CVT_f16_f16 Int16Regs:$a, Mode)>; + (CVT_f16_f16 $a, Mode)>; def : Pat<(OpNode bf16:$a), - (CVT_bf16_bf16 Int16Regs:$a, Mode)>; + (CVT_bf16_bf16 $a, Mode)>; def : Pat<(OpNode f32:$a), - (CVT_f32_f32 Float32Regs:$a, ModeFTZ)>, Requires<[doF32FTZ]>; + (CVT_f32_f32 $a, ModeFTZ)>, Requires<[doF32FTZ]>; def : Pat<(OpNode f32:$a), - (CVT_f32_f32 Float32Regs:$a, Mode)>, Requires<[doNoF32FTZ]>; + (CVT_f32_f32 $a, Mode)>, Requires<[doNoF32FTZ]>; def : Pat<(OpNode f64:$a), - (CVT_f64_f64 Float64Regs:$a, Mode)>; + (CVT_f64_f64 $a, Mode)>; } defm : CVT_ROUND<fceil, CvtRPI, CvtRPI_FTZ>; @@ -3624,7 +3624,7 @@ let isTerminator=1 in { } def : Pat<(brcond i32:$a, bb:$target), - (CBranch (SETP_u32ri Int32Regs:$a, 0, CmpNE), bb:$target)>; + (CBranch (SETP_u32ri $a, 0, CmpNE), bb:$target)>; // SelectionDAGBuilder::visitSWitchCase() will invert the condition of a // conditional branch if the target block is the next block so that the code @@ -3632,7 +3632,7 @@ def : Pat<(brcond i32:$a, bb:$target), // condition, 1', which will be translated to (setne condition, -1). Since ptx // supports '@!pred bra target', we should use it. def : Pat<(brcond (i1 (setne i1:$a, -1)), bb:$target), - (CBranchOther i1:$a, bb:$target)>; + (CBranchOther $a, bb:$target)>; // Call def SDT_NVPTXCallSeqStart : SDCallSeqStart<[SDTCisVT<0, i32>, @@ -3830,17 +3830,17 @@ include "NVPTXIntrinsics.td" def : Pat < (i32 (bswap i32:$a)), - (INT_NVVM_PRMT Int32Regs:$a, (i32 0), (i32 0x0123))>; + (INT_NVVM_PRMT $a, (i32 0), (i32 0x0123))>; def : Pat < (v2i16 (bswap v2i16:$a)), - (INT_NVVM_PRMT Int32Regs:$a, (i32 0), (i32 0x2301))>; + (INT_NVVM_PRMT $a, (i32 0), (i32 0x2301))>; def : Pat < (i64 (bswap i64:$a)), (V2I32toI64 - (INT_NVVM_PRMT (I64toI32H Int64Regs:$a), (i32 0), (i32 0x0123)), - (INT_NVVM_PRMT (I64toI32L Int64Regs:$a), (i32 0), (i32 0x0123)))>; + (INT_NVVM_PRMT (I64toI32H $a), (i32 0), (i32 0x0123)), + (INT_NVVM_PRMT (I64toI32L $a), (i32 0), (i32 0x0123)))>; //////////////////////////////////////////////////////////////////////////////// @@ -3910,18 +3910,18 @@ def FMARELU_BF16X2 : NVPTXInst_rrr<Int32Regs, "fma.rn.relu.bf16x2", [hasBF16Math // FTZ def : Pat<(f16 (NVPTX_fmaxnum_nsz (NVPTX_fma_oneuse_and_nnan f16:$a, f16:$b, f16:$c), fpimm_any_zero)), - (FMARELU_F16_FTZ Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>, + (FMARELU_F16_FTZ $a, $b, $c)>, Requires<[doF32FTZ]>; def : Pat<(v2f16 (NVPTX_fmaxnum_nsz (NVPTX_fma_oneuse_and_nnan v2f16:$a, v2f16:$b, v2f16:$c), fpimm_positive_zero_v2f16)), - (FMARELU_F16X2_FTZ Int32Regs:$a, Int32Regs:$b, Int32Regs:$c)>, + (FMARELU_F16X2_FTZ $a, $b, $c)>, Requires<[doF32FTZ]>; // NO FTZ def : Pat<(f16 (NVPTX_fmaxnum_nsz (NVPTX_fma_oneuse_and_nnan f16:$a, f16:$b, f16:$c), fpimm_any_zero)), - (FMARELU_F16 Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>; + (FMARELU_F16 $a, $b, $c)>; def : Pat<(bf16 (NVPTX_fmaxnum_nsz (NVPTX_fma_oneuse_and_nnan bf16:$a, bf16:$b, bf16:$c), fpimm_any_zero)), - (FMARELU_BF16 Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>; + (FMARELU_BF16 $a, $b, $c)>; def : Pat<(v2f16 (NVPTX_fmaxnum_nsz (NVPTX_fma_oneuse_and_nnan v2f16:$a, v2f16:$b, 
v2f16:$c), fpimm_positive_zero_v2f16)), - (FMARELU_F16X2 Int32Regs:$a, Int32Regs:$b, Int32Regs:$c)>; + (FMARELU_F16X2 $a, $b, $c)>; def : Pat<(v2bf16 (NVPTX_fmaxnum_nsz (NVPTX_fma_oneuse_and_nnan v2bf16:$a, v2bf16:$b, v2bf16:$c), fpimm_positive_zero_v2bf16)), - (FMARELU_BF16X2 Int32Regs:$a, Int32Regs:$b, Int32Regs:$c)>; + (FMARELU_BF16X2 $a, $b, $c)>; diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index 0773c1b..8ede1ec 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -824,29 +824,29 @@ def MBARRIER_PENDING_COUNT : def : Pat<(int_nvvm_fmin_f immFloat1, (int_nvvm_fmax_f immFloat0, f32:$a)), - (CVT_f32_f32 Float32Regs:$a, CvtSAT)>; + (CVT_f32_f32 $a, CvtSAT)>; def : Pat<(int_nvvm_fmin_f immFloat1, (int_nvvm_fmax_f f32:$a, immFloat0)), - (CVT_f32_f32 Float32Regs:$a, CvtSAT)>; + (CVT_f32_f32 $a, CvtSAT)>; def : Pat<(int_nvvm_fmin_f (int_nvvm_fmax_f immFloat0, f32:$a), immFloat1), - (CVT_f32_f32 Float32Regs:$a, CvtSAT)>; + (CVT_f32_f32 $a, CvtSAT)>; def : Pat<(int_nvvm_fmin_f (int_nvvm_fmax_f f32:$a, immFloat0), immFloat1), - (CVT_f32_f32 Float32Regs:$a, CvtSAT)>; + (CVT_f32_f32 $a, CvtSAT)>; def : Pat<(int_nvvm_fmin_d immDouble1, (int_nvvm_fmax_d immDouble0, f64:$a)), - (CVT_f64_f64 Float64Regs:$a, CvtSAT)>; + (CVT_f64_f64 $a, CvtSAT)>; def : Pat<(int_nvvm_fmin_d immDouble1, (int_nvvm_fmax_d f64:$a, immDouble0)), - (CVT_f64_f64 Float64Regs:$a, CvtSAT)>; + (CVT_f64_f64 $a, CvtSAT)>; def : Pat<(int_nvvm_fmin_d (int_nvvm_fmax_d immDouble0, f64:$a), immDouble1), - (CVT_f64_f64 Float64Regs:$a, CvtSAT)>; + (CVT_f64_f64 $a, CvtSAT)>; def : Pat<(int_nvvm_fmin_d (int_nvvm_fmax_d f64:$a, immDouble0), immDouble1), - (CVT_f64_f64 Float64Regs:$a, CvtSAT)>; + (CVT_f64_f64 $a, CvtSAT)>; // We need a full string for OpcStr here because we need to deal with case like @@ -1125,16 +1125,16 @@ def INT_NVVM_DIV_RP_D : F_MATH_2<"div.rp.f64 \t$dst, $src0, $src1;", Float64Regs, Float64Regs, Float64Regs, int_nvvm_div_rp_d>; def : Pat<(int_nvvm_div_full f32:$a, f32:$b), - (FDIV32rr Float32Regs:$a, Float32Regs:$b)>; + (FDIV32rr $a, $b)>; def : Pat<(int_nvvm_div_full f32:$a, fpimm:$b), - (FDIV32ri Float32Regs:$a, f32imm:$b)>; + (FDIV32ri $a, f32imm:$b)>; def : Pat<(int_nvvm_div_full_ftz f32:$a, f32:$b), - (FDIV32rr_ftz Float32Regs:$a, Float32Regs:$b)>; + (FDIV32rr_ftz $a, $b)>; def : Pat<(int_nvvm_div_full_ftz f32:$a, fpimm:$b), - (FDIV32ri_ftz Float32Regs:$a, f32imm:$b)>; + (FDIV32ri_ftz $a, f32imm:$b)>; // // Sad @@ -1158,18 +1158,18 @@ def INT_NVVM_SAD_ULL : F_MATH_3<"sad.u64 \t$dst, $src0, $src1, $src2;", // def : Pat<(int_nvvm_floor_ftz_f f32:$a), - (CVT_f32_f32 Float32Regs:$a, CvtRMI_FTZ)>; + (CVT_f32_f32 $a, CvtRMI_FTZ)>; def : Pat<(int_nvvm_floor_f f32:$a), - (CVT_f32_f32 Float32Regs:$a, CvtRMI)>; + (CVT_f32_f32 $a, CvtRMI)>; def : Pat<(int_nvvm_floor_d f64:$a), - (CVT_f64_f64 Float64Regs:$a, CvtRMI)>; + (CVT_f64_f64 $a, CvtRMI)>; def : Pat<(int_nvvm_ceil_ftz_f f32:$a), - (CVT_f32_f32 Float32Regs:$a, CvtRPI_FTZ)>; + (CVT_f32_f32 $a, CvtRPI_FTZ)>; def : Pat<(int_nvvm_ceil_f f32:$a), - (CVT_f32_f32 Float32Regs:$a, CvtRPI)>; + (CVT_f32_f32 $a, CvtRPI)>; def : Pat<(int_nvvm_ceil_d f64:$a), - (CVT_f64_f64 Float64Regs:$a, CvtRPI)>; + (CVT_f64_f64 $a, CvtRPI)>; // // Abs @@ -1217,33 +1217,33 @@ def INT_NVVM_NEG_BF16X2 : F_MATH_1<"neg.bf16x2 \t$dst, $src0;", Int32Regs, // def : Pat<(int_nvvm_round_ftz_f f32:$a), - (CVT_f32_f32 Float32Regs:$a, CvtRNI_FTZ)>; + (CVT_f32_f32 $a, CvtRNI_FTZ)>; def : Pat<(int_nvvm_round_f 
f32:$a), - (CVT_f32_f32 Float32Regs:$a, CvtRNI)>; + (CVT_f32_f32 $a, CvtRNI)>; def : Pat<(int_nvvm_round_d f64:$a), - (CVT_f64_f64 Float64Regs:$a, CvtRNI)>; + (CVT_f64_f64 $a, CvtRNI)>; // // Trunc // def : Pat<(int_nvvm_trunc_ftz_f f32:$a), - (CVT_f32_f32 Float32Regs:$a, CvtRZI_FTZ)>; + (CVT_f32_f32 $a, CvtRZI_FTZ)>; def : Pat<(int_nvvm_trunc_f f32:$a), - (CVT_f32_f32 Float32Regs:$a, CvtRZI)>; + (CVT_f32_f32 $a, CvtRZI)>; def : Pat<(int_nvvm_trunc_d f64:$a), - (CVT_f64_f64 Float64Regs:$a, CvtRZI)>; + (CVT_f64_f64 $a, CvtRZI)>; // // Saturate // def : Pat<(int_nvvm_saturate_ftz_f f32:$a), - (CVT_f32_f32 Float32Regs:$a, CvtSAT_FTZ)>; + (CVT_f32_f32 $a, CvtSAT_FTZ)>; def : Pat<(int_nvvm_saturate_f f32:$a), - (CVT_f32_f32 Float32Regs:$a, CvtSAT)>; + (CVT_f32_f32 $a, CvtSAT)>; def : Pat<(int_nvvm_saturate_d f64:$a), - (CVT_f64_f64 Float64Regs:$a, CvtSAT)>; + (CVT_f64_f64 $a, CvtSAT)>; // // Exp2 Log2 @@ -1430,13 +1430,13 @@ def INT_NVVM_SQRT_RP_D : F_MATH_1<"sqrt.rp.f64 \t$dst, $src0;", Float64Regs, // nvvm_sqrt intrinsic def : Pat<(int_nvvm_sqrt_f f32:$a), - (INT_NVVM_SQRT_RN_FTZ_F Float32Regs:$a)>, Requires<[doF32FTZ, do_SQRTF32_RN]>; + (INT_NVVM_SQRT_RN_FTZ_F $a)>, Requires<[doF32FTZ, do_SQRTF32_RN]>; def : Pat<(int_nvvm_sqrt_f f32:$a), - (INT_NVVM_SQRT_RN_F Float32Regs:$a)>, Requires<[do_SQRTF32_RN]>; + (INT_NVVM_SQRT_RN_F $a)>, Requires<[do_SQRTF32_RN]>; def : Pat<(int_nvvm_sqrt_f f32:$a), - (INT_NVVM_SQRT_APPROX_FTZ_F Float32Regs:$a)>, Requires<[doF32FTZ]>; + (INT_NVVM_SQRT_APPROX_FTZ_F $a)>, Requires<[doF32FTZ]>; def : Pat<(int_nvvm_sqrt_f f32:$a), - (INT_NVVM_SQRT_APPROX_F Float32Regs:$a)>; + (INT_NVVM_SQRT_APPROX_F $a)>; // // Rsqrt @@ -1456,24 +1456,24 @@ def INT_NVVM_RSQRT_APPROX_D : F_MATH_1<"rsqrt.approx.f64 \t$dst, $src0;", // 1.0f / sqrt_approx -> rsqrt_approx def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_approx_f f32:$a)), - (INT_NVVM_RSQRT_APPROX_F Float32Regs:$a)>, + (INT_NVVM_RSQRT_APPROX_F $a)>, Requires<[doRsqrtOpt]>; def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_approx_ftz_f f32:$a)), - (INT_NVVM_RSQRT_APPROX_FTZ_F Float32Regs:$a)>, + (INT_NVVM_RSQRT_APPROX_FTZ_F $a)>, Requires<[doRsqrtOpt]>; // same for int_nvvm_sqrt_f when non-precision sqrt is requested def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_f f32:$a)), - (INT_NVVM_RSQRT_APPROX_F Float32Regs:$a)>, + (INT_NVVM_RSQRT_APPROX_F $a)>, Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doNoF32FTZ]>; def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_f f32:$a)), - (INT_NVVM_RSQRT_APPROX_FTZ_F Float32Regs:$a)>, + (INT_NVVM_RSQRT_APPROX_FTZ_F $a)>, Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doF32FTZ]>; def: Pat<(fdiv FloatConst1, (fsqrt f32:$a)), - (INT_NVVM_RSQRT_APPROX_F Float32Regs:$a)>, + (INT_NVVM_RSQRT_APPROX_F $a)>, Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doNoF32FTZ]>; def: Pat<(fdiv FloatConst1, (fsqrt f32:$a)), - (INT_NVVM_RSQRT_APPROX_FTZ_F Float32Regs:$a)>, + (INT_NVVM_RSQRT_APPROX_FTZ_F $a)>, Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doF32FTZ]>; // // Add @@ -1529,136 +1529,136 @@ foreach t = [I32RT, I64RT] in { // def : Pat<(int_nvvm_d2f_rn_ftz f64:$a), - (CVT_f32_f64 Float64Regs:$a, CvtRN_FTZ)>; + (CVT_f32_f64 $a, CvtRN_FTZ)>; def : Pat<(int_nvvm_d2f_rn f64:$a), - (CVT_f32_f64 Float64Regs:$a, CvtRN)>; + (CVT_f32_f64 $a, CvtRN)>; def : Pat<(int_nvvm_d2f_rz_ftz f64:$a), - (CVT_f32_f64 Float64Regs:$a, CvtRZ_FTZ)>; + (CVT_f32_f64 $a, CvtRZ_FTZ)>; def : Pat<(int_nvvm_d2f_rz f64:$a), - (CVT_f32_f64 Float64Regs:$a, CvtRZ)>; + (CVT_f32_f64 $a, CvtRZ)>; def : Pat<(int_nvvm_d2f_rm_ftz f64:$a), - (CVT_f32_f64 Float64Regs:$a, CvtRM_FTZ)>; + 
(CVT_f32_f64 $a, CvtRM_FTZ)>; def : Pat<(int_nvvm_d2f_rm f64:$a), - (CVT_f32_f64 Float64Regs:$a, CvtRM)>; + (CVT_f32_f64 $a, CvtRM)>; def : Pat<(int_nvvm_d2f_rp_ftz f64:$a), - (CVT_f32_f64 Float64Regs:$a, CvtRP_FTZ)>; + (CVT_f32_f64 $a, CvtRP_FTZ)>; def : Pat<(int_nvvm_d2f_rp f64:$a), - (CVT_f32_f64 Float64Regs:$a, CvtRP)>; + (CVT_f32_f64 $a, CvtRP)>; def : Pat<(int_nvvm_d2i_rn f64:$a), - (CVT_s32_f64 Float64Regs:$a, CvtRNI)>; + (CVT_s32_f64 $a, CvtRNI)>; def : Pat<(int_nvvm_d2i_rz f64:$a), - (CVT_s32_f64 Float64Regs:$a, CvtRZI)>; + (CVT_s32_f64 $a, CvtRZI)>; def : Pat<(int_nvvm_d2i_rm f64:$a), - (CVT_s32_f64 Float64Regs:$a, CvtRMI)>; + (CVT_s32_f64 $a, CvtRMI)>; def : Pat<(int_nvvm_d2i_rp f64:$a), - (CVT_s32_f64 Float64Regs:$a, CvtRPI)>; + (CVT_s32_f64 $a, CvtRPI)>; def : Pat<(int_nvvm_d2ui_rn f64:$a), - (CVT_u32_f64 Float64Regs:$a, CvtRNI)>; + (CVT_u32_f64 $a, CvtRNI)>; def : Pat<(int_nvvm_d2ui_rz f64:$a), - (CVT_u32_f64 Float64Regs:$a, CvtRZI)>; + (CVT_u32_f64 $a, CvtRZI)>; def : Pat<(int_nvvm_d2ui_rm f64:$a), - (CVT_u32_f64 Float64Regs:$a, CvtRMI)>; + (CVT_u32_f64 $a, CvtRMI)>; def : Pat<(int_nvvm_d2ui_rp f64:$a), - (CVT_u32_f64 Float64Regs:$a, CvtRPI)>; + (CVT_u32_f64 $a, CvtRPI)>; def : Pat<(int_nvvm_i2d_rn i32:$a), - (CVT_f64_s32 Int32Regs:$a, CvtRN)>; + (CVT_f64_s32 $a, CvtRN)>; def : Pat<(int_nvvm_i2d_rz i32:$a), - (CVT_f64_s32 Int32Regs:$a, CvtRZ)>; + (CVT_f64_s32 $a, CvtRZ)>; def : Pat<(int_nvvm_i2d_rm i32:$a), - (CVT_f64_s32 Int32Regs:$a, CvtRM)>; + (CVT_f64_s32 $a, CvtRM)>; def : Pat<(int_nvvm_i2d_rp i32:$a), - (CVT_f64_s32 Int32Regs:$a, CvtRP)>; + (CVT_f64_s32 $a, CvtRP)>; def : Pat<(int_nvvm_ui2d_rn i32:$a), - (CVT_f64_u32 Int32Regs:$a, CvtRN)>; + (CVT_f64_u32 $a, CvtRN)>; def : Pat<(int_nvvm_ui2d_rz i32:$a), - (CVT_f64_u32 Int32Regs:$a, CvtRZ)>; + (CVT_f64_u32 $a, CvtRZ)>; def : Pat<(int_nvvm_ui2d_rm i32:$a), - (CVT_f64_u32 Int32Regs:$a, CvtRM)>; + (CVT_f64_u32 $a, CvtRM)>; def : Pat<(int_nvvm_ui2d_rp i32:$a), - (CVT_f64_u32 Int32Regs:$a, CvtRP)>; + (CVT_f64_u32 $a, CvtRP)>; def : Pat<(int_nvvm_f2i_rn_ftz f32:$a), - (CVT_s32_f32 Float32Regs:$a, CvtRNI_FTZ)>; + (CVT_s32_f32 $a, CvtRNI_FTZ)>; def : Pat<(int_nvvm_f2i_rn f32:$a), - (CVT_s32_f32 Float32Regs:$a, CvtRNI)>; + (CVT_s32_f32 $a, CvtRNI)>; def : Pat<(int_nvvm_f2i_rz_ftz f32:$a), - (CVT_s32_f32 Float32Regs:$a, CvtRZI_FTZ)>; + (CVT_s32_f32 $a, CvtRZI_FTZ)>; def : Pat<(int_nvvm_f2i_rz f32:$a), - (CVT_s32_f32 Float32Regs:$a, CvtRZI)>; + (CVT_s32_f32 $a, CvtRZI)>; def : Pat<(int_nvvm_f2i_rm_ftz f32:$a), - (CVT_s32_f32 Float32Regs:$a, CvtRMI_FTZ)>; + (CVT_s32_f32 $a, CvtRMI_FTZ)>; def : Pat<(int_nvvm_f2i_rm f32:$a), - (CVT_s32_f32 Float32Regs:$a, CvtRMI)>; + (CVT_s32_f32 $a, CvtRMI)>; def : Pat<(int_nvvm_f2i_rp_ftz f32:$a), - (CVT_s32_f32 Float32Regs:$a, CvtRPI_FTZ)>; + (CVT_s32_f32 $a, CvtRPI_FTZ)>; def : Pat<(int_nvvm_f2i_rp f32:$a), - (CVT_s32_f32 Float32Regs:$a, CvtRPI)>; + (CVT_s32_f32 $a, CvtRPI)>; def : Pat<(int_nvvm_f2ui_rn_ftz f32:$a), - (CVT_u32_f32 Float32Regs:$a, CvtRNI_FTZ)>; + (CVT_u32_f32 $a, CvtRNI_FTZ)>; def : Pat<(int_nvvm_f2ui_rn f32:$a), - (CVT_u32_f32 Float32Regs:$a, CvtRNI)>; + (CVT_u32_f32 $a, CvtRNI)>; def : Pat<(int_nvvm_f2ui_rz_ftz f32:$a), - (CVT_u32_f32 Float32Regs:$a, CvtRZI_FTZ)>; + (CVT_u32_f32 $a, CvtRZI_FTZ)>; def : Pat<(int_nvvm_f2ui_rz f32:$a), - (CVT_u32_f32 Float32Regs:$a, CvtRZI)>; + (CVT_u32_f32 $a, CvtRZI)>; def : Pat<(int_nvvm_f2ui_rm_ftz f32:$a), - (CVT_u32_f32 Float32Regs:$a, CvtRMI_FTZ)>; + (CVT_u32_f32 $a, CvtRMI_FTZ)>; def : Pat<(int_nvvm_f2ui_rm f32:$a), - (CVT_u32_f32 
Float32Regs:$a, CvtRMI)>; + (CVT_u32_f32 $a, CvtRMI)>; def : Pat<(int_nvvm_f2ui_rp_ftz f32:$a), - (CVT_u32_f32 Float32Regs:$a, CvtRPI_FTZ)>; + (CVT_u32_f32 $a, CvtRPI_FTZ)>; def : Pat<(int_nvvm_f2ui_rp f32:$a), - (CVT_u32_f32 Float32Regs:$a, CvtRPI)>; + (CVT_u32_f32 $a, CvtRPI)>; def : Pat<(int_nvvm_i2f_rn i32:$a), - (CVT_f32_s32 Int32Regs:$a, CvtRN)>; + (CVT_f32_s32 $a, CvtRN)>; def : Pat<(int_nvvm_i2f_rz i32:$a), - (CVT_f32_s32 Int32Regs:$a, CvtRZ)>; + (CVT_f32_s32 $a, CvtRZ)>; def : Pat<(int_nvvm_i2f_rm i32:$a), - (CVT_f32_s32 Int32Regs:$a, CvtRM)>; + (CVT_f32_s32 $a, CvtRM)>; def : Pat<(int_nvvm_i2f_rp i32:$a), - (CVT_f32_s32 Int32Regs:$a, CvtRP)>; + (CVT_f32_s32 $a, CvtRP)>; def : Pat<(int_nvvm_ui2f_rn i32:$a), - (CVT_f32_u32 Int32Regs:$a, CvtRN)>; + (CVT_f32_u32 $a, CvtRN)>; def : Pat<(int_nvvm_ui2f_rz i32:$a), - (CVT_f32_u32 Int32Regs:$a, CvtRZ)>; + (CVT_f32_u32 $a, CvtRZ)>; def : Pat<(int_nvvm_ui2f_rm i32:$a), - (CVT_f32_u32 Int32Regs:$a, CvtRM)>; + (CVT_f32_u32 $a, CvtRM)>; def : Pat<(int_nvvm_ui2f_rp i32:$a), - (CVT_f32_u32 Int32Regs:$a, CvtRP)>; + (CVT_f32_u32 $a, CvtRP)>; def : Pat<(int_nvvm_ff2bf16x2_rn f32:$a, f32:$b), - (CVT_bf16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>; + (CVT_bf16x2_f32 $a, $b, CvtRN)>; def : Pat<(int_nvvm_ff2bf16x2_rn_relu f32:$a, f32:$b), - (CVT_bf16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN_RELU)>; + (CVT_bf16x2_f32 $a, $b, CvtRN_RELU)>; def : Pat<(int_nvvm_ff2bf16x2_rz f32:$a, f32:$b), - (CVT_bf16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ)>; + (CVT_bf16x2_f32 $a, $b, CvtRZ)>; def : Pat<(int_nvvm_ff2bf16x2_rz_relu f32:$a, f32:$b), - (CVT_bf16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ_RELU)>; + (CVT_bf16x2_f32 $a, $b, CvtRZ_RELU)>; def : Pat<(int_nvvm_ff2f16x2_rn f32:$a, f32:$b), - (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>; + (CVT_f16x2_f32 $a, $b, CvtRN)>; def : Pat<(int_nvvm_ff2f16x2_rn_relu f32:$a, f32:$b), - (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN_RELU)>; + (CVT_f16x2_f32 $a, $b, CvtRN_RELU)>; def : Pat<(int_nvvm_ff2f16x2_rz f32:$a, f32:$b), - (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ)>; + (CVT_f16x2_f32 $a, $b, CvtRZ)>; def : Pat<(int_nvvm_ff2f16x2_rz_relu f32:$a, f32:$b), - (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ_RELU)>; + (CVT_f16x2_f32 $a, $b, CvtRZ_RELU)>; def : Pat<(int_nvvm_f2bf16_rn f32:$a), - (CVT_bf16_f32 Float32Regs:$a, CvtRN)>; + (CVT_bf16_f32 $a, CvtRN)>; def : Pat<(int_nvvm_f2bf16_rn_relu f32:$a), - (CVT_bf16_f32 Float32Regs:$a, CvtRN_RELU)>; + (CVT_bf16_f32 $a, CvtRN_RELU)>; def : Pat<(int_nvvm_f2bf16_rz f32:$a), - (CVT_bf16_f32 Float32Regs:$a, CvtRZ)>; + (CVT_bf16_f32 $a, CvtRZ)>; def : Pat<(int_nvvm_f2bf16_rz_relu f32:$a), - (CVT_bf16_f32 Float32Regs:$a, CvtRZ_RELU)>; + (CVT_bf16_f32 $a, CvtRZ_RELU)>; def CVT_tf32_f32 : NVPTXInst<(outs Int32Regs:$dest), (ins Float32Regs:$a), @@ -1682,125 +1682,125 @@ def INT_NVVM_D2I_HI : F_MATH_1< Int32Regs, Float64Regs, int_nvvm_d2i_hi>; def : Pat<(int_nvvm_f2ll_rn_ftz f32:$a), - (CVT_s64_f32 Float32Regs:$a, CvtRNI_FTZ)>; + (CVT_s64_f32 $a, CvtRNI_FTZ)>; def : Pat<(int_nvvm_f2ll_rn f32:$a), - (CVT_s64_f32 Float32Regs:$a, CvtRNI)>; + (CVT_s64_f32 $a, CvtRNI)>; def : Pat<(int_nvvm_f2ll_rz_ftz f32:$a), - (CVT_s64_f32 Float32Regs:$a, CvtRZI_FTZ)>; + (CVT_s64_f32 $a, CvtRZI_FTZ)>; def : Pat<(int_nvvm_f2ll_rz f32:$a), - (CVT_s64_f32 Float32Regs:$a, CvtRZI)>; + (CVT_s64_f32 $a, CvtRZI)>; def : Pat<(int_nvvm_f2ll_rm_ftz f32:$a), - (CVT_s64_f32 Float32Regs:$a, CvtRMI_FTZ)>; + (CVT_s64_f32 $a, CvtRMI_FTZ)>; def : Pat<(int_nvvm_f2ll_rm f32:$a), - 
(CVT_s64_f32 Float32Regs:$a, CvtRMI)>; + (CVT_s64_f32 $a, CvtRMI)>; def : Pat<(int_nvvm_f2ll_rp_ftz f32:$a), - (CVT_s64_f32 Float32Regs:$a, CvtRPI_FTZ)>; + (CVT_s64_f32 $a, CvtRPI_FTZ)>; def : Pat<(int_nvvm_f2ll_rp f32:$a), - (CVT_s64_f32 Float32Regs:$a, CvtRPI)>; + (CVT_s64_f32 $a, CvtRPI)>; def : Pat<(int_nvvm_f2ull_rn_ftz f32:$a), - (CVT_u64_f32 Float32Regs:$a, CvtRNI_FTZ)>; + (CVT_u64_f32 $a, CvtRNI_FTZ)>; def : Pat<(int_nvvm_f2ull_rn f32:$a), - (CVT_u64_f32 Float32Regs:$a, CvtRNI)>; + (CVT_u64_f32 $a, CvtRNI)>; def : Pat<(int_nvvm_f2ull_rz_ftz f32:$a), - (CVT_u64_f32 Float32Regs:$a, CvtRZI_FTZ)>; + (CVT_u64_f32 $a, CvtRZI_FTZ)>; def : Pat<(int_nvvm_f2ull_rz f32:$a), - (CVT_u64_f32 Float32Regs:$a, CvtRZI)>; + (CVT_u64_f32 $a, CvtRZI)>; def : Pat<(int_nvvm_f2ull_rm_ftz f32:$a), - (CVT_u64_f32 Float32Regs:$a, CvtRMI_FTZ)>; + (CVT_u64_f32 $a, CvtRMI_FTZ)>; def : Pat<(int_nvvm_f2ull_rm f32:$a), - (CVT_u64_f32 Float32Regs:$a, CvtRMI)>; + (CVT_u64_f32 $a, CvtRMI)>; def : Pat<(int_nvvm_f2ull_rp_ftz f32:$a), - (CVT_u64_f32 Float32Regs:$a, CvtRPI_FTZ)>; + (CVT_u64_f32 $a, CvtRPI_FTZ)>; def : Pat<(int_nvvm_f2ull_rp f32:$a), - (CVT_u64_f32 Float32Regs:$a, CvtRPI)>; + (CVT_u64_f32 $a, CvtRPI)>; def : Pat<(int_nvvm_d2ll_rn f64:$a), - (CVT_s64_f64 Float64Regs:$a, CvtRNI)>; + (CVT_s64_f64 $a, CvtRNI)>; def : Pat<(int_nvvm_d2ll_rz f64:$a), - (CVT_s64_f64 Float64Regs:$a, CvtRZI)>; + (CVT_s64_f64 $a, CvtRZI)>; def : Pat<(int_nvvm_d2ll_rm f64:$a), - (CVT_s64_f64 Float64Regs:$a, CvtRMI)>; + (CVT_s64_f64 $a, CvtRMI)>; def : Pat<(int_nvvm_d2ll_rp f64:$a), - (CVT_s64_f64 Float64Regs:$a, CvtRPI)>; + (CVT_s64_f64 $a, CvtRPI)>; def : Pat<(int_nvvm_d2ull_rn f64:$a), - (CVT_u64_f64 Float64Regs:$a, CvtRNI)>; + (CVT_u64_f64 $a, CvtRNI)>; def : Pat<(int_nvvm_d2ull_rz f64:$a), - (CVT_u64_f64 Float64Regs:$a, CvtRZI)>; + (CVT_u64_f64 $a, CvtRZI)>; def : Pat<(int_nvvm_d2ull_rm f64:$a), - (CVT_u64_f64 Float64Regs:$a, CvtRMI)>; + (CVT_u64_f64 $a, CvtRMI)>; def : Pat<(int_nvvm_d2ull_rp f64:$a), - (CVT_u64_f64 Float64Regs:$a, CvtRPI)>; + (CVT_u64_f64 $a, CvtRPI)>; def : Pat<(int_nvvm_ll2f_rn i64:$a), - (CVT_f32_s64 Int64Regs:$a, CvtRN)>; + (CVT_f32_s64 $a, CvtRN)>; def : Pat<(int_nvvm_ll2f_rz i64:$a), - (CVT_f32_s64 Int64Regs:$a, CvtRZ)>; + (CVT_f32_s64 $a, CvtRZ)>; def : Pat<(int_nvvm_ll2f_rm i64:$a), - (CVT_f32_s64 Int64Regs:$a, CvtRM)>; + (CVT_f32_s64 $a, CvtRM)>; def : Pat<(int_nvvm_ll2f_rp i64:$a), - (CVT_f32_s64 Int64Regs:$a, CvtRP)>; + (CVT_f32_s64 $a, CvtRP)>; def : Pat<(int_nvvm_ull2f_rn i64:$a), - (CVT_f32_u64 Int64Regs:$a, CvtRN)>; + (CVT_f32_u64 $a, CvtRN)>; def : Pat<(int_nvvm_ull2f_rz i64:$a), - (CVT_f32_u64 Int64Regs:$a, CvtRZ)>; + (CVT_f32_u64 $a, CvtRZ)>; def : Pat<(int_nvvm_ull2f_rm i64:$a), - (CVT_f32_u64 Int64Regs:$a, CvtRM)>; + (CVT_f32_u64 $a, CvtRM)>; def : Pat<(int_nvvm_ull2f_rp i64:$a), - (CVT_f32_u64 Int64Regs:$a, CvtRP)>; + (CVT_f32_u64 $a, CvtRP)>; def : Pat<(int_nvvm_ll2d_rn i64:$a), - (CVT_f64_s64 Int64Regs:$a, CvtRN)>; + (CVT_f64_s64 $a, CvtRN)>; def : Pat<(int_nvvm_ll2d_rz i64:$a), - (CVT_f64_s64 Int64Regs:$a, CvtRZ)>; + (CVT_f64_s64 $a, CvtRZ)>; def : Pat<(int_nvvm_ll2d_rm i64:$a), - (CVT_f64_s64 Int64Regs:$a, CvtRM)>; + (CVT_f64_s64 $a, CvtRM)>; def : Pat<(int_nvvm_ll2d_rp i64:$a), - (CVT_f64_s64 Int64Regs:$a, CvtRP)>; + (CVT_f64_s64 $a, CvtRP)>; def : Pat<(int_nvvm_ull2d_rn i64:$a), - (CVT_f64_u64 Int64Regs:$a, CvtRN)>; + (CVT_f64_u64 $a, CvtRN)>; def : Pat<(int_nvvm_ull2d_rz i64:$a), - (CVT_f64_u64 Int64Regs:$a, CvtRZ)>; + (CVT_f64_u64 $a, CvtRZ)>; def : Pat<(int_nvvm_ull2d_rm i64:$a), - 
(CVT_f64_u64 Int64Regs:$a, CvtRM)>; + (CVT_f64_u64 $a, CvtRM)>; def : Pat<(int_nvvm_ull2d_rp i64:$a), - (CVT_f64_u64 Int64Regs:$a, CvtRP)>; + (CVT_f64_u64 $a, CvtRP)>; def : Pat<(int_nvvm_f2h_rn_ftz f32:$a), - (CVT_f16_f32 Float32Regs:$a, CvtRN_FTZ)>; + (CVT_f16_f32 $a, CvtRN_FTZ)>; def : Pat<(int_nvvm_f2h_rn f32:$a), - (CVT_f16_f32 Float32Regs:$a, CvtRN)>; + (CVT_f16_f32 $a, CvtRN)>; def : Pat<(int_nvvm_ff_to_e4m3x2_rn f32:$a, f32:$b), - (CVT_e4m3x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>; + (CVT_e4m3x2_f32 $a, $b, CvtRN)>; def : Pat<(int_nvvm_ff_to_e4m3x2_rn_relu f32:$a, f32:$b), - (CVT_e4m3x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN_RELU)>; + (CVT_e4m3x2_f32 $a, $b, CvtRN_RELU)>; def : Pat<(int_nvvm_ff_to_e5m2x2_rn f32:$a, f32:$b), - (CVT_e5m2x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>; + (CVT_e5m2x2_f32 $a, $b, CvtRN)>; def : Pat<(int_nvvm_ff_to_e5m2x2_rn_relu f32:$a, f32:$b), - (CVT_e5m2x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN_RELU)>; + (CVT_e5m2x2_f32 $a, $b, CvtRN_RELU)>; def : Pat<(int_nvvm_f16x2_to_e4m3x2_rn Int32Regs:$a), - (CVT_e4m3x2_f16x2 Int32Regs:$a, CvtRN)>; + (CVT_e4m3x2_f16x2 $a, CvtRN)>; def : Pat<(int_nvvm_f16x2_to_e4m3x2_rn_relu Int32Regs:$a), - (CVT_e4m3x2_f16x2 Int32Regs:$a, CvtRN_RELU)>; + (CVT_e4m3x2_f16x2 $a, CvtRN_RELU)>; def : Pat<(int_nvvm_f16x2_to_e5m2x2_rn Int32Regs:$a), - (CVT_e5m2x2_f16x2 Int32Regs:$a, CvtRN)>; + (CVT_e5m2x2_f16x2 $a, CvtRN)>; def : Pat<(int_nvvm_f16x2_to_e5m2x2_rn_relu Int32Regs:$a), - (CVT_e5m2x2_f16x2 Int32Regs:$a, CvtRN_RELU)>; + (CVT_e5m2x2_f16x2 $a, CvtRN_RELU)>; def : Pat<(int_nvvm_e4m3x2_to_f16x2_rn Int16Regs:$a), - (CVT_f16x2_e4m3x2 Int16Regs:$a, CvtRN)>; + (CVT_f16x2_e4m3x2 $a, CvtRN)>; def : Pat<(int_nvvm_e4m3x2_to_f16x2_rn_relu Int16Regs:$a), - (CVT_f16x2_e4m3x2 Int16Regs:$a, CvtRN_RELU)>; + (CVT_f16x2_e4m3x2 $a, CvtRN_RELU)>; def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn Int16Regs:$a), - (CVT_f16x2_e5m2x2 Int16Regs:$a, CvtRN)>; + (CVT_f16x2_e5m2x2 $a, CvtRN)>; def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn_relu Int16Regs:$a), - (CVT_f16x2_e5m2x2 Int16Regs:$a, CvtRN_RELU)>; + (CVT_f16x2_e5m2x2 $a, CvtRN_RELU)>; // // FNS @@ -1823,9 +1823,9 @@ def INT_FNS_rii : INT_FNS_MBO<(ins Int32Regs:$mask, i32imm:$base, i32imm:$ def INT_FNS_irr : INT_FNS_MBO<(ins i32imm:$mask, Int32Regs:$base, Int32Regs:$offset), (int_nvvm_fns imm:$mask, i32:$base, i32:$offset)>; def INT_FNS_iri : INT_FNS_MBO<(ins i32imm:$mask, Int32Regs:$base, i32imm:$offset), - (int_nvvm_fns imm:$mask, Int32Regs:$base, imm:$offset)>; + (int_nvvm_fns imm:$mask, i32:$base, imm:$offset)>; def INT_FNS_iir : INT_FNS_MBO<(ins i32imm:$mask, i32imm:$base, Int32Regs:$offset), - (int_nvvm_fns imm:$mask, imm:$base, Int32Regs:$offset)>; + (int_nvvm_fns imm:$mask, imm:$base, i32:$offset)>; def INT_FNS_iii : INT_FNS_MBO<(ins i32imm:$mask, i32imm:$base, i32imm:$offset), (int_nvvm_fns imm:$mask, imm:$base, imm:$offset)>; @@ -2796,10 +2796,10 @@ defm cvta_to_const : G_TO_NG<"const">; defm cvta_param : NG_TO_G<"param">; def : Pat<(int_nvvm_ptr_param_to_gen i32:$src), - (cvta_param Int32Regs:$src)>; + (cvta_param $src)>; def : Pat<(int_nvvm_ptr_param_to_gen i64:$src), - (cvta_param_64 Int64Regs:$src)>; + (cvta_param_64 $src)>; // nvvm.ptr.gen.to.param def : Pat<(int_nvvm_ptr_gen_to_param i32:$src), @@ -2933,8 +2933,8 @@ def : Pat<(int_nvvm_read_ptx_sreg_envreg31), (MOV_SPECIAL ENVREG31)>; def : Pat<(int_nvvm_swap_lo_hi_b64 i64:$src), - (V2I32toI64 (I64toI32H Int64Regs:$src), - (I64toI32L Int64Regs:$src))> ; + (V2I32toI64 (I64toI32H $src), + (I64toI32L $src))> ; 
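Aside (editorial note, not part of the patch): the int_nvvm_swap_lo_hi_b64 pattern just above splits the 64-bit source into its 32-bit halves (I64toI32L / I64toI32H) and repacks them in the opposite order (V2I32toI64), which is semantically a rotate of a 64-bit value by 32 bits. A minimal host-side C++ sketch of the same operation, for illustration only; the function name is made up and nothing here is taken from the patch itself:

#include <cstdint>

// Swap the low and high 32-bit halves of a 64-bit value (i.e. rotate by 32).
// Mirrors what the TableGen pattern above does with I64toI32L/H + V2I32toI64.
static uint64_t swap_lo_hi_b64(uint64_t x) {
  uint32_t lo = static_cast<uint32_t>(x);        // low half, like I64toI32L
  uint32_t hi = static_cast<uint32_t>(x >> 32);  // high half, like I64toI32H
  // Repack with the halves exchanged: old low half becomes the new high half.
  return (static_cast<uint64_t>(lo) << 32) | hi;
}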
//----------------------------------- // Texture Intrinsics @@ -5040,21 +5040,21 @@ def TXQ_NUM_MIPMAP_LEVELS_I } def : Pat<(int_nvvm_txq_channel_order i64:$a), - (TXQ_CHANNEL_ORDER_R i64:$a)>; + (TXQ_CHANNEL_ORDER_R $a)>; def : Pat<(int_nvvm_txq_channel_data_type i64:$a), - (TXQ_CHANNEL_DATA_TYPE_R i64:$a)>; + (TXQ_CHANNEL_DATA_TYPE_R $a)>; def : Pat<(int_nvvm_txq_width i64:$a), - (TXQ_WIDTH_R i64:$a)>; + (TXQ_WIDTH_R $a)>; def : Pat<(int_nvvm_txq_height i64:$a), - (TXQ_HEIGHT_R i64:$a)>; + (TXQ_HEIGHT_R $a)>; def : Pat<(int_nvvm_txq_depth i64:$a), - (TXQ_DEPTH_R i64:$a)>; + (TXQ_DEPTH_R $a)>; def : Pat<(int_nvvm_txq_array_size i64:$a), - (TXQ_ARRAY_SIZE_R i64:$a)>; + (TXQ_ARRAY_SIZE_R $a)>; def : Pat<(int_nvvm_txq_num_samples i64:$a), - (TXQ_NUM_SAMPLES_R i64:$a)>; + (TXQ_NUM_SAMPLES_R $a)>; def : Pat<(int_nvvm_txq_num_mipmap_levels i64:$a), - (TXQ_NUM_MIPMAP_LEVELS_R i64:$a)>; + (TXQ_NUM_MIPMAP_LEVELS_R $a)>; //----------------------------------- @@ -5113,17 +5113,17 @@ def SUQ_ARRAY_SIZE_I } def : Pat<(int_nvvm_suq_channel_order i64:$a), - (SUQ_CHANNEL_ORDER_R Int64Regs:$a)>; + (SUQ_CHANNEL_ORDER_R $a)>; def : Pat<(int_nvvm_suq_channel_data_type i64:$a), - (SUQ_CHANNEL_DATA_TYPE_R Int64Regs:$a)>; + (SUQ_CHANNEL_DATA_TYPE_R $a)>; def : Pat<(int_nvvm_suq_width i64:$a), - (SUQ_WIDTH_R Int64Regs:$a)>; + (SUQ_WIDTH_R $a)>; def : Pat<(int_nvvm_suq_height i64:$a), - (SUQ_HEIGHT_R Int64Regs:$a)>; + (SUQ_HEIGHT_R $a)>; def : Pat<(int_nvvm_suq_depth i64:$a), - (SUQ_DEPTH_R Int64Regs:$a)>; + (SUQ_DEPTH_R $a)>; def : Pat<(int_nvvm_suq_array_size i64:$a), - (SUQ_ARRAY_SIZE_R Int64Regs:$a)>; + (SUQ_ARRAY_SIZE_R $a)>; //===- Handle Query -------------------------------------------------------===// diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp b/llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp index 42043ad..74ce6a9 100644 --- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp @@ -34,19 +34,18 @@ void NVPTXSubtarget::anchor() {} NVPTXSubtarget &NVPTXSubtarget::initializeSubtargetDependencies(StringRef CPU, StringRef FS) { - // Provide the default CPU if we don't have one. - TargetName = std::string(CPU.empty() ? "sm_30" : CPU); + TargetName = std::string(CPU); - ParseSubtargetFeatures(TargetName, /*TuneCPU*/ TargetName, FS); + ParseSubtargetFeatures(getTargetName(), /*TuneCPU=*/getTargetName(), FS); - // Re-map SM version numbers, SmVersion carries the regular SMs which do - // have relative order, while FullSmVersion allows distinguishing sm_90 from - // sm_90a, which would *not* be a subset of sm_91. - SmVersion = getSmVersion(); + // Re-map SM version numbers, SmVersion carries the regular SMs which do + // have relative order, while FullSmVersion allows distinguishing sm_90 from + // sm_90a, which would *not* be a subset of sm_91. + SmVersion = getSmVersion(); - // Set default to PTX 6.0 (CUDA 9.0) - if (PTXVersion == 0) { - PTXVersion = 60; + // Set default to PTX 6.0 (CUDA 9.0) + if (PTXVersion == 0) { + PTXVersion = 60; } return *this; diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h index 7555a23..bbc1cca 100644 --- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h +++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h @@ -111,7 +111,12 @@ public: // - 0 represents base GPU model, // - non-zero value identifies particular architecture-accelerated variant. 
bool hasAAFeatures() const { return getFullSmVersion() % 10; } - std::string getTargetName() const { return TargetName; } + + // If the user did not provide a target we default to the `sm_30` target. + std::string getTargetName() const { + return TargetName.empty() ? "sm_30" : TargetName; + } + bool hasTargetName() const { return !TargetName.empty(); } // Get maximum value of required alignments among the supported data types. // From the PTX ISA doc, section 8.2.3: diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp index b3b2880..6d4b82a 100644 --- a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp @@ -255,7 +255,10 @@ void NVPTXTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) { PB.registerPipelineStartEPCallback( [this](ModulePassManager &PM, OptimizationLevel Level) { FunctionPassManager FPM; - FPM.addPass(NVVMReflectPass(Subtarget.getSmVersion())); + // We do not want to fold out calls to nvvm.reflect early if the user + // has not provided a target architecture just yet. + if (Subtarget.hasTargetName()) + FPM.addPass(NVVMReflectPass(Subtarget.getSmVersion())); // Note: NVVMIntrRangePass was causing numerical discrepancies at one // point, if issues crop up, consider disabling. FPM.addPass(NVVMIntrRangePass()); diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp index 98bffd9..0f2bec7 100644 --- a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp @@ -311,11 +311,13 @@ std::optional<unsigned> getMaxNReg(const Function &F) { } bool isKernelFunction(const Function &F) { + if (F.getCallingConv() == CallingConv::PTX_Kernel) + return true; + if (const auto X = findOneNVVMAnnotation(&F, "kernel")) return (*X == 1); - // There is no NVVM metadata, check the calling convention - return F.getCallingConv() == CallingConv::PTX_Kernel; + return false; } MaybeAlign getAlign(const Function &F, unsigned Index) { diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp index 56525a1..0cd584c 100644 --- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp +++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp @@ -21,6 +21,7 @@ #include "NVPTX.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/ConstantFolding.h" +#include "llvm/CodeGen/CommandFlags.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" @@ -219,7 +220,13 @@ bool NVVMReflect::runOnFunction(Function &F) { return runNVVMReflect(F, SmVersion); } -NVVMReflectPass::NVVMReflectPass() : NVVMReflectPass(0) {} +NVVMReflectPass::NVVMReflectPass() { + // Get the CPU string from the command line if not provided. + std::string MCPU = codegen::getMCPU(); + StringRef SM = MCPU; + if (!SM.consume_front("sm_") || SM.consumeInteger(10, SmVersion)) + SmVersion = 0; +} PreservedAnalyses NVVMReflectPass::run(Function &F, FunctionAnalysisManager &AM) { diff --git a/llvm/lib/Target/PowerPC/PPCISelDAGToDAG.cpp b/llvm/lib/Target/PowerPC/PPCISelDAGToDAG.cpp index 2e0ee59..d1daf7c 100644 --- a/llvm/lib/Target/PowerPC/PPCISelDAGToDAG.cpp +++ b/llvm/lib/Target/PowerPC/PPCISelDAGToDAG.cpp @@ -5473,10 +5473,10 @@ void PPCDAGToDAGISel::Select(SDNode *N) { // generate secure plt code for TLS symbols. 
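The new NVVMReflectPass default constructor above derives the SM version from the -mcpu string it gets via codegen::getMCPU(). A standalone sketch of just that parsing step, assuming only llvm::StringRef from LLVM's ADT headers; anything that does not look like "sm_<N>" falls back to 0:

#include "llvm/ADT/StringRef.h"
#include <cstdio>

// "sm_90a" -> 90, "sm_52" -> 52, empty or unrecognised strings -> 0.
// consume_front() strips the prefix; consumeInteger() returns true on
// failure, so both checks must pass for the parsed value to be kept.
static unsigned parseSmVersion(llvm::StringRef CPU) {
  unsigned SmVersion = 0;
  llvm::StringRef SM = CPU;
  if (!SM.consume_front("sm_") || SM.consumeInteger(10, SmVersion))
    SmVersion = 0;
  return SmVersion;
}

int main() {
  std::printf("%u %u %u\n", parseSmVersion("sm_90a"),
              parseSmVersion("sm_52"), parseSmVersion("")); // 90 52 0
  return 0;
}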
getGlobalBaseReg(); } break; - case PPCISD::CALL: { - if (PPCLowering->getPointerTy(CurDAG->getDataLayout()) != MVT::i32 || - !TM.isPositionIndependent() || !Subtarget->isSecurePlt() || - !Subtarget->isTargetELF()) + case PPCISD::CALL: + case PPCISD::CALL_RM: { + if (Subtarget->isPPC64() || !TM.isPositionIndependent() || + !Subtarget->isSecurePlt() || !Subtarget->isTargetELF()) break; SDValue Op = N->getOperand(1); @@ -5489,8 +5489,7 @@ void PPCDAGToDAGISel::Select(SDNode *N) { if (ES->getTargetFlags() == PPCII::MO_PLT) getGlobalBaseReg(); } - } - break; + } break; case PPCISD::GlobalBaseReg: ReplaceNode(N, getGlobalBaseReg()); diff --git a/llvm/lib/Target/PowerPC/PPCInstrInfo.cpp b/llvm/lib/Target/PowerPC/PPCInstrInfo.cpp index 44f6db5..fa45a7f 100644 --- a/llvm/lib/Target/PowerPC/PPCInstrInfo.cpp +++ b/llvm/lib/Target/PowerPC/PPCInstrInfo.cpp @@ -643,8 +643,8 @@ bool PPCInstrInfo::shouldReduceRegisterPressure( }; // For now we only care about float and double type fma. - unsigned VSSRCLimit = TRI->getRegPressureSetLimit( - *MBB->getParent(), PPC::RegisterPressureSets::VSSRC); + unsigned VSSRCLimit = + RegClassInfo->getRegPressureSetLimit(PPC::RegisterPressureSets::VSSRC); // Only reduce register pressure when pressure is high. return GetMBBPressure(MBB)[PPC::RegisterPressureSets::VSSRC] > diff --git a/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp b/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp index 9dcf2e9..2205c67 100644 --- a/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp +++ b/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp @@ -734,6 +734,16 @@ public: VK == RISCVMCExpr::VK_RISCV_None; } + bool isUImm5GT3() const { + if (!isImm()) + return false; + RISCVMCExpr::VariantKind VK = RISCVMCExpr::VK_RISCV_None; + int64_t Imm; + bool IsConstantImm = evaluateConstantImm(getImm(), Imm, VK); + return IsConstantImm && isUInt<5>(Imm) && (Imm > 3) && + VK == RISCVMCExpr::VK_RISCV_None; + } + bool isUImm8GE32() const { int64_t Imm; RISCVMCExpr::VariantKind VK = RISCVMCExpr::VK_RISCV_None; @@ -1520,6 +1530,8 @@ bool RISCVAsmParser::matchAndEmitInstruction(SMLoc IDLoc, unsigned &Opcode, return generateImmOutOfRangeError(Operands, ErrorInfo, 0, (1 << 5) - 1); case Match_InvalidUImm5NonZero: return generateImmOutOfRangeError(Operands, ErrorInfo, 1, (1 << 5) - 1); + case Match_InvalidUImm5GT3: + return generateImmOutOfRangeError(Operands, ErrorInfo, 4, (1 << 5) - 1); case Match_InvalidUImm6: return generateImmOutOfRangeError(Operands, ErrorInfo, 0, (1 << 6) - 1); case Match_InvalidUImm7: @@ -1903,6 +1915,8 @@ ParseStatus RISCVAsmParser::parseCSRSystemRegister(OperandVector &Operands) { // Accept an immediate representing a named Sys Reg if it satisfies the // the required features. for (auto &Reg : Range) { + if (Reg.IsAltName || Reg.IsDeprecatedName) + continue; if (Reg.haveRequiredFeatures(STI->getFeatureBits())) return RISCVOperand::createSysReg(Reg.Name, S, Imm); } @@ -1940,22 +1954,27 @@ ParseStatus RISCVAsmParser::parseCSRSystemRegister(OperandVector &Operands) { return ParseStatus::Failure; const auto *SysReg = RISCVSysReg::lookupSysRegByName(Identifier); - if (!SysReg) - SysReg = RISCVSysReg::lookupSysRegByAltName(Identifier); - if (!SysReg) - if ((SysReg = RISCVSysReg::lookupSysRegByDeprecatedName(Identifier))) - Warning(S, "'" + Identifier + "' is a deprecated alias for '" + - SysReg->Name + "'"); - - // Accept a named Sys Reg if the required features are present. + if (SysReg) { + if (SysReg->IsDeprecatedName) { + // Lookup the undeprecated name. 
+ auto Range = RISCVSysReg::lookupSysRegByEncoding(SysReg->Encoding); + for (auto &Reg : Range) { + if (Reg.IsAltName || Reg.IsDeprecatedName) + continue; + Warning(S, "'" + Identifier + "' is a deprecated alias for '" + + Reg.Name + "'"); + } + } + + // Accept a named Sys Reg if the required features are present. const auto &FeatureBits = getSTI().getFeatureBits(); if (!SysReg->haveRequiredFeatures(FeatureBits)) { const auto *Feature = llvm::find_if(RISCVFeatureKV, [&](auto Feature) { return SysReg->FeaturesRequired[Feature.Value]; }); auto ErrorMsg = std::string("system register '") + SysReg->Name + "' "; - if (SysReg->isRV32Only && FeatureBits[RISCV::Feature64Bit]) { + if (SysReg->IsRV32Only && FeatureBits[RISCV::Feature64Bit]) { ErrorMsg += "is RV32 only"; if (Feature != std::end(RISCVFeatureKV)) ErrorMsg += " and "; diff --git a/llvm/lib/Target/RISCV/CMakeLists.txt b/llvm/lib/Target/RISCV/CMakeLists.txt index 4466164..98d3615 100644 --- a/llvm/lib/Target/RISCV/CMakeLists.txt +++ b/llvm/lib/Target/RISCV/CMakeLists.txt @@ -15,6 +15,7 @@ tablegen(LLVM RISCVGenRegisterBank.inc -gen-register-bank) tablegen(LLVM RISCVGenRegisterInfo.inc -gen-register-info) tablegen(LLVM RISCVGenSearchableTables.inc -gen-searchable-tables) tablegen(LLVM RISCVGenSubtargetInfo.inc -gen-subtarget) +tablegen(LLVM RISCVGenExegesis.inc -gen-exegesis) set(LLVM_TARGET_DEFINITIONS RISCVGISel.td) tablegen(LLVM RISCVGenGlobalISel.inc -gen-global-isel) diff --git a/llvm/lib/Target/RISCV/Disassembler/RISCVDisassembler.cpp b/llvm/lib/Target/RISCV/Disassembler/RISCVDisassembler.cpp index 9901719..a490910 100644 --- a/llvm/lib/Target/RISCV/Disassembler/RISCVDisassembler.cpp +++ b/llvm/lib/Target/RISCV/Disassembler/RISCVDisassembler.cpp @@ -692,6 +692,14 @@ DecodeStatus RISCVDisassembler::getInstruction32(MCInst &MI, uint64_t &Size, "Qualcomm uC Conditional Select custom opcode table"); TRY_TO_DECODE_FEATURE(RISCV::FeatureVendorXqcilsm, DecoderTableXqcilsm32, "Qualcomm uC Load Store Multiple custom opcode table"); + TRY_TO_DECODE_FEATURE( + RISCV::FeatureVendorXqciac, DecoderTableXqciac32, + "Qualcomm uC Load-Store Address Calculation custom opcode table"); + TRY_TO_DECODE_FEATURE( + RISCV::FeatureVendorXqcicli, DecoderTableXqcicli32, + "Qualcomm uC Conditional Load Immediate custom opcode table"); + TRY_TO_DECODE_FEATURE(RISCV::FeatureVendorXqcicm, DecoderTableXqcicm32, + "Qualcomm uC Conditional Move custom opcode table"); TRY_TO_DECODE(true, DecoderTable32, "RISCV32 table"); return MCDisassembler::Fail; @@ -718,6 +726,12 @@ DecodeStatus RISCVDisassembler::getInstruction16(MCInst &MI, uint64_t &Size, TRY_TO_DECODE_FEATURE( RISCV::FeatureStdExtZcmp, DecoderTableRVZcmp16, "Zcmp table (16-bit Push/Pop & Double Move Instructions)"); + TRY_TO_DECODE_FEATURE( + RISCV::FeatureVendorXqciac, DecoderTableXqciac16, + "Qualcomm uC Load-Store Address Calculation custom 16bit opcode table"); + TRY_TO_DECODE_FEATURE( + RISCV::FeatureVendorXqcicm, DecoderTableXqcicm16, + "Qualcomm uC Conditional Move custom 16bit opcode table"); TRY_TO_DECODE_AND_ADD_SP(STI.hasFeature(RISCV::FeatureVendorXwchc), DecoderTableXwchc16, "WCH QingKe XW custom opcode table"); diff --git a/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp b/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp index ef85057..3f1539d 100644 --- a/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp +++ b/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp @@ -80,7 +80,6 @@ private: bool selectFPCompare(MachineInstr &MI, MachineIRBuilder &MIB) const; void 
emitFence(AtomicOrdering FenceOrdering, SyncScope::ID FenceSSID, MachineIRBuilder &MIB) const; - bool selectMergeValues(MachineInstr &MI, MachineIRBuilder &MIB) const; bool selectUnmergeValues(MachineInstr &MI, MachineIRBuilder &MIB) const; ComplexRendererFns selectShiftMask(MachineOperand &Root, @@ -732,8 +731,6 @@ bool RISCVInstructionSelector::select(MachineInstr &MI) { } case TargetOpcode::G_IMPLICIT_DEF: return selectImplicitDef(MI, MIB); - case TargetOpcode::G_MERGE_VALUES: - return selectMergeValues(MI, MIB); case TargetOpcode::G_UNMERGE_VALUES: return selectUnmergeValues(MI, MIB); default: @@ -741,26 +738,13 @@ bool RISCVInstructionSelector::select(MachineInstr &MI) { } } -bool RISCVInstructionSelector::selectMergeValues(MachineInstr &MI, - MachineIRBuilder &MIB) const { - assert(MI.getOpcode() == TargetOpcode::G_MERGE_VALUES); - - // Build a F64 Pair from operands - if (MI.getNumOperands() != 3) - return false; - Register Dst = MI.getOperand(0).getReg(); - Register Lo = MI.getOperand(1).getReg(); - Register Hi = MI.getOperand(2).getReg(); - if (!isRegInFprb(Dst) || !isRegInGprb(Lo) || !isRegInGprb(Hi)) - return false; - MI.setDesc(TII.get(RISCV::BuildPairF64Pseudo)); - return constrainSelectedInstRegOperands(MI, TII, TRI, RBI); -} - bool RISCVInstructionSelector::selectUnmergeValues( MachineInstr &MI, MachineIRBuilder &MIB) const { assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES); + if (!Subtarget->hasStdExtZfa()) + return false; + // Split F64 Src into two s32 parts if (MI.getNumOperands() != 3) return false; @@ -769,8 +753,17 @@ bool RISCVInstructionSelector::selectUnmergeValues( Register Hi = MI.getOperand(1).getReg(); if (!isRegInFprb(Src) || !isRegInGprb(Lo) || !isRegInGprb(Hi)) return false; - MI.setDesc(TII.get(RISCV::SplitF64Pseudo)); - return constrainSelectedInstRegOperands(MI, TII, TRI, RBI); + + MachineInstr *ExtractLo = MIB.buildInstr(RISCV::FMV_X_W_FPR64, {Lo}, {Src}); + if (!constrainSelectedInstRegOperands(*ExtractLo, TII, TRI, RBI)) + return false; + + MachineInstr *ExtractHi = MIB.buildInstr(RISCV::FMVH_X_D, {Hi}, {Src}); + if (!constrainSelectedInstRegOperands(*ExtractHi, TII, TRI, RBI)) + return false; + + MI.eraseFromParent(); + return true; } bool RISCVInstructionSelector::replacePtrWithInt(MachineOperand &Op, diff --git a/llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp b/llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp index 8284737..6f06459 100644 --- a/llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp +++ b/llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp @@ -21,6 +21,7 @@ #include "llvm/CodeGen/MachineConstantPool.h" #include "llvm/CodeGen/MachineJumpTableInfo.h" #include "llvm/CodeGen/MachineMemOperand.h" +#include "llvm/CodeGen/MachineOperand.h" #include "llvm/CodeGen/MachineRegisterInfo.h" #include "llvm/CodeGen/TargetOpcodes.h" #include "llvm/CodeGen/ValueTypes.h" @@ -132,7 +133,14 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST) auto PtrVecTys = {nxv1p0, nxv2p0, nxv4p0, nxv8p0, nxv16p0}; - getActionDefinitionsBuilder({G_ADD, G_SUB, G_AND, G_OR, G_XOR}) + getActionDefinitionsBuilder({G_ADD, G_SUB}) + .legalFor({sXLen}) + .legalIf(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST)) + .customFor(ST.is64Bit(), {s32}) + .widenScalarToNextPow2(0) + .clampScalar(0, sXLen, sXLen); + + getActionDefinitionsBuilder({G_AND, G_OR, G_XOR}) .legalFor({sXLen}) .legalIf(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST)) .widenScalarToNextPow2(0) @@ -1330,6 +1338,24 @@ bool RISCVLegalizerInfo::legalizeCustom( return true; return 
Helper.lowerConstant(MI); } + case TargetOpcode::G_SUB: + case TargetOpcode::G_ADD: { + Helper.Observer.changingInstr(MI); + Helper.widenScalarSrc(MI, sXLen, 1, TargetOpcode::G_ANYEXT); + Helper.widenScalarSrc(MI, sXLen, 2, TargetOpcode::G_ANYEXT); + + Register DstALU = MRI.createGenericVirtualRegister(sXLen); + + MachineOperand &MO = MI.getOperand(0); + MIRBuilder.setInsertPt(MIRBuilder.getMBB(), ++MIRBuilder.getInsertPt()); + auto DstSext = MIRBuilder.buildSExtInReg(sXLen, DstALU, 32); + + MIRBuilder.buildInstr(TargetOpcode::G_TRUNC, {MO}, {DstSext}); + MO.setReg(DstALU); + + Helper.Observer.changedInstr(MI); + return true; + } case TargetOpcode::G_SEXT_INREG: { LLT DstTy = MRI.getType(MI.getOperand(0).getReg()); int64_t SizeInBits = MI.getOperand(2).getImm(); diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVAsmBackend.cpp b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVAsmBackend.cpp index eab4a5e..0cb1ef0 100644 --- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVAsmBackend.cpp +++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVAsmBackend.cpp @@ -38,9 +38,12 @@ std::optional<MCFixupKind> RISCVAsmBackend::getFixupKind(StringRef Name) const { if (STI.getTargetTriple().isOSBinFormatELF()) { unsigned Type; Type = llvm::StringSwitch<unsigned>(Name) -#define ELF_RELOC(X, Y) .Case(#X, Y) +#define ELF_RELOC(NAME, ID) .Case(#NAME, ID) #include "llvm/BinaryFormat/ELFRelocs/RISCV.def" #undef ELF_RELOC +#define ELF_RISCV_NONSTANDARD_RELOC(_VENDOR, NAME, ID) .Case(#NAME, ID) +#include "llvm/BinaryFormat/ELFRelocs/RISCV_nonstandard.def" +#undef ELF_RISCV_NONSTANDARD_RELOC .Case("BFD_RELOC_NONE", ELF::R_RISCV_NONE) .Case("BFD_RELOC_32", ELF::R_RISCV_32) .Case("BFD_RELOC_64", ELF::R_RISCV_64) diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h index b9f4db0..7048e40 100644 --- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h +++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h @@ -302,6 +302,7 @@ enum OperandType : unsigned { OPERAND_UIMM4, OPERAND_UIMM5, OPERAND_UIMM5_NONZERO, + OPERAND_UIMM5_GT3, OPERAND_UIMM5_LSB0, OPERAND_UIMM6, OPERAND_UIMM6_LSB0, @@ -453,8 +454,6 @@ int getLoadFPImm(APFloat FPImm); namespace RISCVSysReg { struct SysReg { const char Name[32]; - const char AltName[32]; - const char DeprecatedName[32]; unsigned Encoding; // FIXME: add these additional fields when needed. // Privilege Access: Read, Write, Read-Only. @@ -466,11 +465,13 @@ struct SysReg { // Register number without the privilege bits. // unsigned Number; FeatureBitset FeaturesRequired; - bool isRV32Only; + bool IsRV32Only; + bool IsAltName; + bool IsDeprecatedName; bool haveRequiredFeatures(const FeatureBitset &ActiveFeatures) const { // Not in 32-bit mode. - if (isRV32Only && ActiveFeatures[RISCV::Feature64Bit]) + if (IsRV32Only && ActiveFeatures[RISCV::Feature64Bit]) return false; // No required feature associated with the system register. 
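The G_ADD/G_SUB custom legalization added above any-extends the 32-bit operands to XLEN, performs the operation at 64 bits, then sign-extends the low 32 bits of the result in place, which is the value an RV64 ADDW/SUBW produces. A host-side model of that value semantics, in plain standard C++ (the function name is illustrative):

#include <cassert>
#include <cstdint>

// What the widened lowering computes: the operands are ANYEXTed (upper bits
// arbitrary), the add happens at 64 bits, and only the low 32 bits of the
// result are kept, sign-extended back to XLEN.
static int64_t addw_model(int64_t A, int64_t B) {
  int64_t Wide = A + B;              // the widened sXLen G_ADD
  return static_cast<int32_t>(Wide); // G_SEXT_INREG 32 of the result
}

int main() {
  // 0x7fffffff + 1 wraps to INT32_MIN when viewed as a 32-bit result.
  assert(addw_model(0x7fffffff, 1) == INT64_C(-2147483648));
  // Garbage in the upper 32 bits of an input does not affect the result.
  assert(addw_model((INT64_C(1) << 40) | 5, 7) == 12);
  return 0;
}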
if (FeaturesRequired.none()) @@ -479,6 +480,7 @@ struct SysReg { } }; +#define GET_SysRegEncodings_DECL #define GET_SysRegsList_DECL #include "RISCVGenSearchableTables.inc" } // end namespace RISCVSysReg diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVInstPrinter.cpp b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVInstPrinter.cpp index d36c0d7..d525471 100644 --- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVInstPrinter.cpp +++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVInstPrinter.cpp @@ -121,6 +121,8 @@ void RISCVInstPrinter::printCSRSystemRegister(const MCInst *MI, unsigned OpNo, unsigned Imm = MI->getOperand(OpNo).getImm(); auto Range = RISCVSysReg::lookupSysRegByEncoding(Imm); for (auto &Reg : Range) { + if (Reg.IsAltName || Reg.IsDeprecatedName) + continue; if (Reg.haveRequiredFeatures(STI.getFeatureBits())) { markup(O, Markup::Register) << Reg.Name; return; diff --git a/llvm/lib/Target/RISCV/RISCV.td b/llvm/lib/Target/RISCV/RISCV.td index 9631241..4e0c64a 100644 --- a/llvm/lib/Target/RISCV/RISCV.td +++ b/llvm/lib/Target/RISCV/RISCV.td @@ -64,6 +64,12 @@ include "RISCVSchedXiangShanNanHu.td" include "RISCVProcessors.td" //===----------------------------------------------------------------------===// +// Pfm Counters +//===----------------------------------------------------------------------===// + +include "RISCVPfmCounters.td" + +//===----------------------------------------------------------------------===// // Define the RISC-V target. //===----------------------------------------------------------------------===// diff --git a/llvm/lib/Target/RISCV/RISCVCombine.td b/llvm/lib/Target/RISCV/RISCVCombine.td index 030613a..995dd0c 100644 --- a/llvm/lib/Target/RISCV/RISCVCombine.td +++ b/llvm/lib/Target/RISCV/RISCVCombine.td @@ -25,5 +25,5 @@ def RISCVPostLegalizerCombiner : GICombiner<"RISCVPostLegalizerCombinerImpl", [sub_to_add, combines_for_extload, redundant_and, identity_combines, shift_immed_chain, - commute_constant_to_rhs]> { + commute_constant_to_rhs, simplify_neg_minmax]> { } diff --git a/llvm/lib/Target/RISCV/RISCVFeatures.td b/llvm/lib/Target/RISCV/RISCVFeatures.td index dfc56588..01bc5387 100644 --- a/llvm/lib/Target/RISCV/RISCVFeatures.td +++ b/llvm/lib/Target/RISCV/RISCVFeatures.td @@ -844,6 +844,10 @@ def HasStdExtH : Predicate<"Subtarget->hasStdExtH()">, // Supervisor extensions +def FeatureStdExtSdext : RISCVExperimentalExtension<1, 0, "External debugger">; + +def FeatureStdExtSdtrig : RISCVExperimentalExtension<1, 0, "Debugger triggers">; + def FeatureStdExtShgatpa : RISCVExtension<1, 0, "SvNNx4 mode supported for all modes supported by satp, as well as Bare">; @@ -1274,6 +1278,30 @@ def HasVendorXqcilsm AssemblerPredicate<(all_of FeatureVendorXqcilsm), "'Xqcilsm' (Qualcomm uC Load Store Multiple Extension)">; +def FeatureVendorXqciac + : RISCVExperimentalExtension<0, 2, "Qualcomm uC Load-Store Address Calculation Extension", + [FeatureStdExtZca]>; +def HasVendorXqciac + : Predicate<"Subtarget->hasVendorXqciac()">, + AssemblerPredicate<(all_of FeatureVendorXqciac), + "'Xqciac' (Qualcomm uC Load-Store Address Calculation Extension)">; + +def FeatureVendorXqcicli + : RISCVExperimentalExtension<0, 2, + "Qualcomm uC Conditional Load Immediate Extension">; +def HasVendorXqcicli + : Predicate<"Subtarget->hasVendorXqcicli()">, + AssemblerPredicate<(all_of FeatureVendorXqcicli), + "'Xqcicli' (Qualcomm uC Conditional Load Immediate Extension)">; + +def FeatureVendorXqcicm + : RISCVExperimentalExtension<0, 2, "Qualcomm uC Conditional Move Extension", + [FeatureStdExtZca]>; +def 
HasVendorXqcicm + : Predicate<"Subtarget->hasVendorXqcicm()">, + AssemblerPredicate<(all_of FeatureVendorXqcicm), + "'Xqcicm' (Qualcomm uC Conditional Move Extension)">; + //===----------------------------------------------------------------------===// // LLVM specific features and extensions //===----------------------------------------------------------------------===// diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index cda64ae..6c58989 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -5104,6 +5104,7 @@ static SDValue lowerShuffleViaVRegSplitting(ShuffleVectorSDNode *SVN, SDValue V1 = SVN->getOperand(0); SDValue V2 = SVN->getOperand(1); ArrayRef<int> Mask = SVN->getMask(); + unsigned NumElts = VT.getVectorNumElements(); // If we don't know exact data layout, not much we can do. If this // is already m1 or smaller, no point in splitting further. @@ -5120,70 +5121,58 @@ static SDValue lowerShuffleViaVRegSplitting(ShuffleVectorSDNode *SVN, MVT ElemVT = VT.getVectorElementType(); unsigned ElemsPerVReg = *VLen / ElemVT.getFixedSizeInBits(); + unsigned VRegsPerSrc = NumElts / ElemsPerVReg; + + SmallVector<std::pair<int, SmallVector<int>>> + OutMasks(VRegsPerSrc, {-1, {}}); + + // Check if our mask can be done as a 1-to-1 mapping from source + // to destination registers in the group without needing to + // write each destination more than once. + for (unsigned DstIdx = 0; DstIdx < Mask.size(); DstIdx++) { + int DstVecIdx = DstIdx / ElemsPerVReg; + int DstSubIdx = DstIdx % ElemsPerVReg; + int SrcIdx = Mask[DstIdx]; + if (SrcIdx < 0 || (unsigned)SrcIdx >= 2 * NumElts) + continue; + int SrcVecIdx = SrcIdx / ElemsPerVReg; + int SrcSubIdx = SrcIdx % ElemsPerVReg; + if (OutMasks[DstVecIdx].first == -1) + OutMasks[DstVecIdx].first = SrcVecIdx; + if (OutMasks[DstVecIdx].first != SrcVecIdx) + // Note: This case could easily be handled by keeping track of a chain + // of source values and generating two element shuffles below. This is + // less an implementation question, and more a profitability one. + return SDValue(); + + OutMasks[DstVecIdx].second.resize(ElemsPerVReg, -1); + OutMasks[DstVecIdx].second[DstSubIdx] = SrcSubIdx; + } EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget); MVT OneRegVT = MVT::getVectorVT(ElemVT, ElemsPerVReg); MVT M1VT = getContainerForFixedLengthVector(DAG, OneRegVT, Subtarget); assert(M1VT == getLMUL1VT(M1VT)); unsigned NumOpElts = M1VT.getVectorMinNumElements(); - unsigned NormalizedVF = ContainerVT.getVectorMinNumElements(); - unsigned NumOfSrcRegs = NormalizedVF / NumOpElts; - unsigned NumOfDestRegs = NormalizedVF / NumOpElts; + SDValue Vec = DAG.getUNDEF(ContainerVT); // The following semantically builds up a fixed length concat_vector // of the component shuffle_vectors. We eagerly lower to scalable here // to avoid DAG combining it back to a large shuffle_vector again. V1 = convertToScalableVector(ContainerVT, V1, DAG, Subtarget); V2 = convertToScalableVector(ContainerVT, V2, DAG, Subtarget); - SmallVector<SDValue> SubRegs(NumOfDestRegs); - unsigned RegCnt = 0; - unsigned PrevCnt = 0; - processShuffleMasks( - Mask, NumOfSrcRegs, NumOfDestRegs, NumOfDestRegs, - [&]() { - PrevCnt = RegCnt; - ++RegCnt; - }, - [&, &DAG = DAG](ArrayRef<int> SrcSubMask, unsigned SrcVecIdx, - unsigned DstVecIdx) { - SDValue SrcVec = SrcVecIdx >= NumOfSrcRegs ? 
V2 : V1; - unsigned ExtractIdx = (SrcVecIdx % NumOfSrcRegs) * NumOpElts; - SDValue SubVec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, M1VT, SrcVec, - DAG.getVectorIdxConstant(ExtractIdx, DL)); - SubVec = convertFromScalableVector(OneRegVT, SubVec, DAG, Subtarget); - SubVec = DAG.getVectorShuffle(OneRegVT, DL, SubVec, SubVec, SrcSubMask); - SubRegs[RegCnt] = convertToScalableVector(M1VT, SubVec, DAG, Subtarget); - PrevCnt = RegCnt; - ++RegCnt; - }, - [&, &DAG = DAG](ArrayRef<int> SrcSubMask, unsigned Idx1, unsigned Idx2) { - if (PrevCnt + 1 == RegCnt) - ++RegCnt; - SDValue SubVec1 = SubRegs[PrevCnt + 1]; - if (!SubVec1) { - SDValue SrcVec = Idx1 >= NumOfSrcRegs ? V2 : V1; - unsigned ExtractIdx = (Idx1 % NumOfSrcRegs) * NumOpElts; - SubVec1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, M1VT, SrcVec, - DAG.getVectorIdxConstant(ExtractIdx, DL)); - } - SubVec1 = convertFromScalableVector(OneRegVT, SubVec1, DAG, Subtarget); - SDValue SrcVec = Idx2 >= NumOfSrcRegs ? V2 : V1; - unsigned ExtractIdx = (Idx2 % NumOfSrcRegs) * NumOpElts; - SDValue SubVec2 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, M1VT, SrcVec, - DAG.getVectorIdxConstant(ExtractIdx, DL)); - SubVec2 = convertFromScalableVector(OneRegVT, SubVec2, DAG, Subtarget); - SubVec1 = - DAG.getVectorShuffle(OneRegVT, DL, SubVec1, SubVec2, SrcSubMask); - SubVec1 = convertToScalableVector(M1VT, SubVec1, DAG, Subtarget); - SubRegs[PrevCnt + 1] = SubVec1; - }); - assert(RegCnt == NumOfDestRegs && "Whole vector must be processed"); - SDValue Vec = DAG.getUNDEF(ContainerVT); - for (auto [I, V] : enumerate(SubRegs)) { - if (!V) + for (unsigned DstVecIdx = 0 ; DstVecIdx < OutMasks.size(); DstVecIdx++) { + auto &[SrcVecIdx, SrcSubMask] = OutMasks[DstVecIdx]; + if (SrcVecIdx == -1) continue; - unsigned InsertIdx = I * NumOpElts; - - Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ContainerVT, Vec, V, + unsigned ExtractIdx = (SrcVecIdx % VRegsPerSrc) * NumOpElts; + SDValue SrcVec = (unsigned)SrcVecIdx >= VRegsPerSrc ? 
V2 : V1; + SDValue SubVec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, M1VT, SrcVec, + DAG.getVectorIdxConstant(ExtractIdx, DL)); + SubVec = convertFromScalableVector(OneRegVT, SubVec, DAG, Subtarget); + SubVec = DAG.getVectorShuffle(OneRegVT, DL, SubVec, SubVec, SrcSubMask); + SubVec = convertToScalableVector(M1VT, SubVec, DAG, Subtarget); + unsigned InsertIdx = DstVecIdx * NumOpElts; + Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ContainerVT, Vec, SubVec, DAG.getVectorIdxConstant(InsertIdx, DL)); } return convertFromScalableVector(VT, Vec, DAG, Subtarget); @@ -10165,7 +10154,10 @@ SDValue RISCVTargetLowering::lowerVectorMaskVecReduction(SDValue Op, case ISD::VP_REDUCE_AND: { // vcpop ~x == 0 SDValue TrueMask = DAG.getNode(RISCVISD::VMSET_VL, DL, ContainerVT, VL); - Vec = DAG.getNode(RISCVISD::VMXOR_VL, DL, ContainerVT, Vec, TrueMask, VL); + if (IsVP || VecVT.isFixedLengthVector()) + Vec = DAG.getNode(RISCVISD::VMXOR_VL, DL, ContainerVT, Vec, TrueMask, VL); + else + Vec = DAG.getNode(ISD::XOR, DL, ContainerVT, Vec, TrueMask); Vec = DAG.getNode(RISCVISD::VCPOP_VL, DL, XLenVT, Vec, Mask, VL); CC = ISD::SETEQ; break; @@ -12674,8 +12666,7 @@ SDValue RISCVTargetLowering::lowerGET_ROUNDING(SDValue Op, const MVT XLenVT = Subtarget.getXLenVT(); SDLoc DL(Op); SDValue Chain = Op->getOperand(0); - SDValue SysRegNo = DAG.getTargetConstant( - RISCVSysReg::lookupSysRegByName("FRM")->Encoding, DL, XLenVT); + SDValue SysRegNo = DAG.getTargetConstant(RISCVSysReg::frm, DL, XLenVT); SDVTList VTs = DAG.getVTList(XLenVT, MVT::Other); SDValue RM = DAG.getNode(RISCVISD::READ_CSR, DL, VTs, Chain, SysRegNo); @@ -12706,8 +12697,7 @@ SDValue RISCVTargetLowering::lowerSET_ROUNDING(SDValue Op, SDLoc DL(Op); SDValue Chain = Op->getOperand(0); SDValue RMValue = Op->getOperand(1); - SDValue SysRegNo = DAG.getTargetConstant( - RISCVSysReg::lookupSysRegByName("FRM")->Encoding, DL, XLenVT); + SDValue SysRegNo = DAG.getTargetConstant(RISCVSysReg::frm, DL, XLenVT); // Encoding used for rounding mode in RISC-V differs from that used in // FLT_ROUNDS. To convert it the C rounding mode is used as an index in @@ -12910,15 +12900,11 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N, SDValue LoCounter, HiCounter; MVT XLenVT = Subtarget.getXLenVT(); if (N->getOpcode() == ISD::READCYCLECOUNTER) { - LoCounter = DAG.getTargetConstant( - RISCVSysReg::lookupSysRegByName("CYCLE")->Encoding, DL, XLenVT); - HiCounter = DAG.getTargetConstant( - RISCVSysReg::lookupSysRegByName("CYCLEH")->Encoding, DL, XLenVT); + LoCounter = DAG.getTargetConstant(RISCVSysReg::cycle, DL, XLenVT); + HiCounter = DAG.getTargetConstant(RISCVSysReg::cycleh, DL, XLenVT); } else { - LoCounter = DAG.getTargetConstant( - RISCVSysReg::lookupSysRegByName("TIME")->Encoding, DL, XLenVT); - HiCounter = DAG.getTargetConstant( - RISCVSysReg::lookupSysRegByName("TIMEH")->Encoding, DL, XLenVT); + LoCounter = DAG.getTargetConstant(RISCVSysReg::time, DL, XLenVT); + HiCounter = DAG.getTargetConstant(RISCVSysReg::timeh, DL, XLenVT); } SDVTList VTs = DAG.getVTList(MVT::i32, MVT::i32, MVT::Other); SDValue RCW = DAG.getNode(RISCVISD::READ_COUNTER_WIDE, DL, VTs, @@ -18397,6 +18383,15 @@ bool RISCVTargetLowering::isDesirableToCommuteWithShift( auto *C1 = dyn_cast<ConstantSDNode>(N0->getOperand(1)); auto *C2 = dyn_cast<ConstantSDNode>(N->getOperand(1)); + + // Bail if we might break a sh{1,2,3}add pattern. 
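The reverted lowerShuffleViaVRegSplitting above goes back to splitting a fixed-length shuffle into at most one m1 sub-shuffle per destination register, giving up when any destination register would need elements from two different source registers. A self-contained sketch of just that mask bookkeeping, with illustrative names rather than the SelectionDAG code itself:

#include <cstdio>
#include <optional>
#include <utility>
#include <vector>

// For each destination vreg, record which source vreg feeds it (-1 if unset)
// and the per-lane sub-mask within that vreg. Returns nullopt when some
// destination register would need elements from two different source vregs.
static std::optional<std::vector<std::pair<int, std::vector<int>>>>
splitShuffleMask(const std::vector<int> &Mask, unsigned ElemsPerVReg) {
  unsigned NumElts = Mask.size();
  unsigned VRegsPerSrc = NumElts / ElemsPerVReg;
  std::vector<std::pair<int, std::vector<int>>> OutMasks(
      VRegsPerSrc, {-1, std::vector<int>(ElemsPerVReg, -1)});

  for (unsigned DstIdx = 0; DstIdx < NumElts; ++DstIdx) {
    int SrcIdx = Mask[DstIdx];
    if (SrcIdx < 0 || (unsigned)SrcIdx >= 2 * NumElts)
      continue; // undef lane
    unsigned DstVecIdx = DstIdx / ElemsPerVReg;
    unsigned DstSubIdx = DstIdx % ElemsPerVReg;
    int SrcVecIdx = SrcIdx / ElemsPerVReg; // which m1-sized register
    int SrcSubIdx = SrcIdx % ElemsPerVReg; // which lane within it
    if (OutMasks[DstVecIdx].first == -1)
      OutMasks[DstVecIdx].first = SrcVecIdx;
    if (OutMasks[DstVecIdx].first != SrcVecIdx)
      return std::nullopt; // would need two source vregs for one dest vreg
    OutMasks[DstVecIdx].second[DstSubIdx] = SrcSubIdx;
  }
  return OutMasks;
}

int main() {
  // 8 elements, 4 per vreg: dest reg 0 reads V1's second vreg, dest reg 1
  // reads V2's first vreg (indices 8..11), each with a lane rotation.
  auto R = splitShuffleMask({5, 6, 7, 4, 9, 10, 11, 8}, 4);
  if (R)
    for (auto &[Src, Sub] : *R)
      std::printf("src vreg %d, sub-mask %d %d %d %d\n", Src, Sub[0], Sub[1],
                  Sub[2], Sub[3]);
  return 0;
}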
+ if (Subtarget.hasStdExtZba() && C2 && C2->getZExtValue() >= 1 && + C2->getZExtValue() <= 3 && N->hasOneUse() && + N->user_begin()->getOpcode() == ISD::ADD && + !isUsedByLdSt(*N->user_begin(), nullptr) && + !isa<ConstantSDNode>(N->user_begin()->getOperand(1))) + return false; + if (C1 && C2) { const APInt &C1Int = C1->getAPIntValue(); APInt ShiftedC1Int = C1Int << C2->getAPIntValue(); @@ -20278,13 +20273,11 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI, for (auto &Reg : RegsToPass) Ops.push_back(DAG.getRegister(Reg.first, Reg.second.getValueType())); - if (!IsTailCall) { - // Add a register mask operand representing the call-preserved registers. - const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo(); - const uint32_t *Mask = TRI->getCallPreservedMask(MF, CallConv); - assert(Mask && "Missing call preserved mask for calling convention"); - Ops.push_back(DAG.getRegisterMask(Mask)); - } + // Add a register mask operand representing the call-preserved registers. + const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo(); + const uint32_t *Mask = TRI->getCallPreservedMask(MF, CallConv); + assert(Mask && "Missing call preserved mask for calling convention"); + Ops.push_back(DAG.getRegisterMask(Mask)); // Glue the call to the argument copies, if any. if (Glue.getNode()) diff --git a/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp b/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp index 7598583..1fd130d 100644 --- a/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp +++ b/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp @@ -627,7 +627,7 @@ public: return MI; } - void setAVL(VSETVLIInfo Info) { + void setAVL(const VSETVLIInfo &Info) { assert(Info.isValid()); if (Info.isUnknown()) setUnknown(); @@ -1223,7 +1223,8 @@ bool RISCVInsertVSETVLI::needVSETVLI(const DemandedFields &Used, // If we don't use LMUL or the SEW/LMUL ratio, then adjust LMUL so that we // maintain the SEW/LMUL ratio. This allows us to eliminate VL toggles in more // places. 
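The new early-exit in isDesirableToCommuteWithShift above declines to commute a constant add through a shift by 1, 2 or 3 when the shift's only user is an add of a register, because that shl+add shape is exactly what Zba's sh1add/sh2add/sh3add cover. A plain-C++ reminder of the shape being protected (illustrative only, standard C++):

#include <cassert>
#include <cstdint>

// Zba's sh2add computes rd = rs2 + (rs1 << 2) in one instruction, so the
// combiner avoids rewrites that would pull other operations between the
// shift and the add that consumes it.
static uint64_t sh2add_shape(uint64_t X, uint64_t Y) {
  return Y + (X << 2); // a single sh2add on RV64 with Zba
}

int main() {
  assert(sh2add_shape(3, 10) == 22);
  return 0;
}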
-static VSETVLIInfo adjustIncoming(VSETVLIInfo PrevInfo, VSETVLIInfo NewInfo, +static VSETVLIInfo adjustIncoming(const VSETVLIInfo &PrevInfo, + const VSETVLIInfo &NewInfo, DemandedFields &Demanded) { VSETVLIInfo Info = NewInfo; diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoD.td b/llvm/lib/Target/RISCV/RISCVInstrInfoD.td index ae969bff8..349bc36 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoD.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoD.td @@ -23,7 +23,9 @@ def SDT_RISCVSplitF64 : SDTypeProfile<2, 1, [SDTCisVT<0, i32>, SDTCisVT<2, f64>]>; def RISCVBuildPairF64 : SDNode<"RISCVISD::BuildPairF64", SDT_RISCVBuildPairF64>; +def : GINodeEquiv<G_MERGE_VALUES, RISCVBuildPairF64>; def RISCVSplitF64 : SDNode<"RISCVISD::SplitF64", SDT_RISCVSplitF64>; +def : GINodeEquiv<G_UNMERGE_VALUES, RISCVSplitF64>; def AddrRegImmINX : ComplexPattern<iPTR, 2, "SelectAddrRegImmRV32Zdinx">; diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td b/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td index 37b29ed..942ced8 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td @@ -536,27 +536,23 @@ multiclass VPatTernaryVMAQA_VV_VX<string intrinsic, string instruction, //===----------------------------------------------------------------------===// let Predicates = [HasVendorXTHeadBa] in { -def : Pat<(add (XLenVT GPR:$rs1), (shl GPR:$rs2, uimm2:$uimm2)), +def : Pat<(add_like_non_imm12 (shl GPR:$rs2, uimm2:$uimm2), (XLenVT GPR:$rs1)), + (TH_ADDSL GPR:$rs1, GPR:$rs2, uimm2:$uimm2)>; +def : Pat<(XLenVT (riscv_shl_add GPR:$rs2, uimm2:$uimm2, GPR:$rs1)), (TH_ADDSL GPR:$rs1, GPR:$rs2, uimm2:$uimm2)>; -def : Pat<(XLenVT (riscv_shl_add GPR:$rs1, uimm2:$uimm2, GPR:$rs2)), - (TH_ADDSL GPR:$rs2, GPR:$rs1, uimm2:$uimm2)>; // Reuse complex patterns from StdExtZba -def : Pat<(add_non_imm12 sh1add_op:$rs1, (XLenVT GPR:$rs2)), - (TH_ADDSL GPR:$rs2, sh1add_op:$rs1, 1)>; -def : Pat<(add_non_imm12 sh2add_op:$rs1, (XLenVT GPR:$rs2)), - (TH_ADDSL GPR:$rs2, sh2add_op:$rs1, 2)>; -def : Pat<(add_non_imm12 sh3add_op:$rs1, (XLenVT GPR:$rs2)), - (TH_ADDSL GPR:$rs2, sh3add_op:$rs1, 3)>; - -def : Pat<(add (XLenVT GPR:$r), CSImm12MulBy4:$i), +def : Pat<(add_like_non_imm12 sh1add_op:$rs2, (XLenVT GPR:$rs1)), + (TH_ADDSL GPR:$rs1, sh1add_op:$rs2, 1)>; +def : Pat<(add_like_non_imm12 sh2add_op:$rs2, (XLenVT GPR:$rs1)), + (TH_ADDSL GPR:$rs1, sh2add_op:$rs2, 2)>; +def : Pat<(add_like_non_imm12 sh3add_op:$rs2, (XLenVT GPR:$rs1)), + (TH_ADDSL GPR:$rs1, sh3add_op:$rs2, 3)>; + +def : Pat<(add_like (XLenVT GPR:$r), CSImm12MulBy4:$i), (TH_ADDSL GPR:$r, (XLenVT (ADDI (XLenVT X0), CSImm12MulBy4:$i)), 2)>; -def : Pat<(add (XLenVT GPR:$r), CSImm12MulBy8:$i), +def : Pat<(add_like (XLenVT GPR:$r), CSImm12MulBy8:$i), (TH_ADDSL GPR:$r, (XLenVT (ADDI (XLenVT X0), CSImm12MulBy8:$i)), 3)>; - -def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 200)), - (SLLI (XLenVT (TH_ADDSL (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 2)), - (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 2)), 2)), 3)>; } // Predicates = [HasVendorXTHeadBa] let Predicates = [HasVendorXTHeadBb] in { diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoXqci.td b/llvm/lib/Target/RISCV/RISCVInstrInfoXqci.td index 05b5591..6f15646 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoXqci.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoXqci.td @@ -21,6 +21,13 @@ def uimm5nonzero : RISCVOp<XLenVT>, let OperandType = "OPERAND_UIMM5_NONZERO"; } +def uimm5gt3 : RISCVOp<XLenVT>, ImmLeaf<XLenVT, + [{return (Imm > 3) && isUInt<5>(Imm);}]> { + let ParserMatchClass = UImmAsmOperand<5, "GT3">; + let 
DecoderMethod = "decodeUImmOperand<5>"; + let OperandType = "OPERAND_UIMM5_GT3"; +} + def uimm11 : RISCVUImmLeafOp<11>; //===----------------------------------------------------------------------===// @@ -132,6 +139,33 @@ class QCIStoreMultiple<bits<2> funct2, DAGOperand InTyRs2, string opcodestr> let Inst{31-25} = {funct2, imm{6-2}}; } +let hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +class QCILICC<bits<3> funct3, bits<2> funct2, DAGOperand InTyRs2, string opcodestr> + : RVInstRBase<funct3, OPC_CUSTOM_2, (outs GPRNoX0:$rd_wb), + (ins GPRNoX0:$rd, GPRNoX0:$rs1, InTyRs2:$rs2, simm5:$simm), + opcodestr, "$rd, $rs1, $rs2, $simm"> { + let Constraints = "$rd = $rd_wb"; + bits<5> simm; + + let Inst{31-25} = {simm, funct2}; +} + +let hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +class QCIMVCC<bits<3> funct3, string opcodestr> + : RVInstR4<0b00, funct3, OPC_CUSTOM_2, (outs GPRNoX0:$rd), + (ins GPRNoX0:$rs1, GPRNoX0:$rs2, GPRNoX0:$rs3), + opcodestr, "$rd, $rs1, $rs2, $rs3">; + +let hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +class QCIMVCCI<bits<3> funct3, string opcodestr, DAGOperand immType> + : RVInstR4<0b10, funct3, OPC_CUSTOM_2, (outs GPRNoX0:$rd), + (ins GPRNoX0:$rs1, immType:$imm, GPRNoX0:$rs3), + opcodestr, "$rd, $rs1, $imm, $rs3"> { + bits<5> imm; + + let rs2 = imm; +} + //===----------------------------------------------------------------------===// // Instructions //===----------------------------------------------------------------------===// @@ -184,6 +218,37 @@ let hasSideEffects = 0, mayLoad = 0, mayStore = 0 in { } // hasSideEffects = 0, mayLoad = 0, mayStore = 0 } // Predicates = [HasVendorXqcia, IsRV32], DecoderNamespace = "Xqcia" +let Predicates = [HasVendorXqciac, IsRV32], DecoderNamespace = "Xqciac" in { +let hasSideEffects = 0, mayLoad = 0, mayStore = 0 in { + def QC_C_MULADDI : RVInst16CL<0b001, 0b10, (outs GPRC:$rd_wb), + (ins GPRC:$rd, GPRC:$rs1, uimm5:$uimm), + "qc.c.muladdi", "$rd, $rs1, $uimm"> { + let Constraints = "$rd = $rd_wb"; + bits<5> uimm; + + let Inst{12-10} = uimm{3-1}; + let Inst{6} = uimm{0}; + let Inst{5} = uimm{4}; + } + + def QC_MULADDI : RVInstI<0b110, OPC_CUSTOM_0, (outs GPRNoX0:$rd_wb), + (ins GPRNoX0:$rd, GPRNoX0:$rs1, simm12:$imm12), + "qc.muladdi", "$rd, $rs1, $imm12"> { + let Constraints = "$rd = $rd_wb"; + } + + def QC_SHLADD : RVInstRBase<0b011, OPC_CUSTOM_0, (outs GPRNoX0:$rd), + (ins GPRNoX0:$rs1, GPRNoX0:$rs2, uimm5gt3:$shamt), + "qc.shladd", "$rd, $rs1, $rs2, $shamt"> { + bits<5> shamt; + + let Inst{31-30} = 0b01; + let Inst{29-25} = shamt; + } + +} // hasSideEffects = 0, mayLoad = 0, mayStore = 0 +} // Predicates = [HasVendorXqciac, IsRV32], DecoderNamespace = "Xqciac" + let Predicates = [HasVendorXqcics, IsRV32], DecoderNamespace = "Xqcics" in { def QC_SELECTIIEQ : QCISELECTIICC <0b010, "qc.selectiieq">; def QC_SELECTIINE : QCISELECTIICC <0b011, "qc.selectiine">; @@ -205,6 +270,48 @@ let Predicates = [HasVendorXqcilsm, IsRV32], DecoderNamespace = "Xqcilsm" in { def QC_LWMI : QCILoadMultiple<0b01, uimm5nonzero, "qc.lwmi">; } // Predicates = [HasVendorXqcilsm, IsRV32], DecoderNamespace = "Xqcilsm" +let Predicates = [HasVendorXqcicli, IsRV32], DecoderNamespace = "Xqcicli" in { + def QC_LIEQ : QCILICC<0b000, 0b01, GPRNoX0, "qc.lieq">; + def QC_LINE : QCILICC<0b001, 0b01, GPRNoX0, "qc.line">; + def QC_LILT : QCILICC<0b100, 0b01, GPRNoX0, "qc.lilt">; + def QC_LIGE : QCILICC<0b101, 0b01, GPRNoX0, "qc.lige">; + def QC_LILTU : QCILICC<0b110, 0b01, GPRNoX0, "qc.liltu">; + def QC_LIGEU : QCILICC<0b111, 0b01, GPRNoX0, "qc.ligeu">; + + def 
QC_LIEQI : QCILICC<0b000, 0b11, simm5, "qc.lieqi">; + def QC_LINEI : QCILICC<0b001, 0b11, simm5, "qc.linei">; + def QC_LILTI : QCILICC<0b100, 0b11, simm5, "qc.lilti">; + def QC_LIGEI : QCILICC<0b101, 0b11, simm5, "qc.ligei">; + def QC_LILTUI : QCILICC<0b110, 0b11, uimm5, "qc.liltui">; + def QC_LIGEUI : QCILICC<0b111, 0b11, uimm5, "qc.ligeui">; +} // Predicates = [HasVendorXqcicli, IsRV32], DecoderNamespace = "Xqcicli" + +let Predicates = [HasVendorXqcicm, IsRV32], DecoderNamespace = "Xqcicm" in { +let hasSideEffects = 0, mayLoad = 0, mayStore = 0 in + def QC_C_MVEQZ : RVInst16CL<0b101, 0b10, (outs GPRC:$rd_wb), + (ins GPRC:$rd, GPRC:$rs1), + "qc.c.mveqz", "$rd, $rs1"> { + let Constraints = "$rd = $rd_wb"; + + let Inst{12-10} = 0b011; + let Inst{6-5} = 0b00; + } + + def QC_MVEQ : QCIMVCC<0b000, "qc.mveq">; + def QC_MVNE : QCIMVCC<0b001, "qc.mvne">; + def QC_MVLT : QCIMVCC<0b100, "qc.mvlt">; + def QC_MVGE : QCIMVCC<0b101, "qc.mvge">; + def QC_MVLTU : QCIMVCC<0b110, "qc.mvltu">; + def QC_MVGEU : QCIMVCC<0b111, "qc.mvgeu">; + + def QC_MVEQI : QCIMVCCI<0b000, "qc.mveqi", simm5>; + def QC_MVNEI : QCIMVCCI<0b001, "qc.mvnei", simm5>; + def QC_MVLTI : QCIMVCCI<0b100, "qc.mvlti", simm5>; + def QC_MVGEI : QCIMVCCI<0b101, "qc.mvgei", simm5>; + def QC_MVLTUI : QCIMVCCI<0b110, "qc.mvltui", uimm5>; + def QC_MVGEUI : QCIMVCCI<0b111, "qc.mvgeui", uimm5>; +} // Predicates = [HasVendorXqcicm, IsRV32], DecoderNamespace = "Xqcicm" + //===----------------------------------------------------------------------===// // Aliases //===----------------------------------------------------------------------===// diff --git a/llvm/lib/Target/RISCV/RISCVPfmCounters.td b/llvm/lib/Target/RISCV/RISCVPfmCounters.td new file mode 100644 index 0000000..013e789 --- /dev/null +++ b/llvm/lib/Target/RISCV/RISCVPfmCounters.td @@ -0,0 +1,18 @@ +//===---- RISCVPfmCounters.td - RISC-V Hardware Counters ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This describes the available hardware counters for RISC-V. 
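A plausible reading of the QCILICC format above, given that $rd is tied to $rd_wb, is a conditional write: the immediate (or rs2) is stored into rd only when the comparison of rs1 against rs2 holds, and rd is left unchanged otherwise. The sketch below models that reading for a qc.lilt-style instruction; it is an assumption for illustration, and the authoritative semantics are those of Qualcomm's Xqcicli specification:

#include <cstdint>
#include <cstdio>

// Hedged model: rd is tied, so "load immediate if less-than" is read as
// rd = (rs1 < rs2) ? simm5 : rd. Assumed semantics, not the spec.
static int32_t qc_lilt_model(int32_t Rd, int32_t Rs1, int32_t Rs2,
                             int32_t Simm5) {
  return (Rs1 < Rs2) ? Simm5 : Rd;
}

int main() {
  std::printf("%d %d\n", qc_lilt_model(7, 1, 2, -3),  // condition true  -> -3
              qc_lilt_model(7, 5, 2, -3));            // condition false ->  7
  return 0;
}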
+// +//===----------------------------------------------------------------------===// + +def CpuCyclesPfmCounter : PfmCounter<"CYCLES">; + +def DefaultPfmCounters : ProcPfmCounters { + let CycleCounter = CpuCyclesPfmCounter; +} +def : PfmCountersDefaultBinding<DefaultPfmCounters>; diff --git a/llvm/lib/Target/RISCV/RISCVProcessors.td b/llvm/lib/Target/RISCV/RISCVProcessors.td index 61c7c21..6dfed7dd 100644 --- a/llvm/lib/Target/RISCV/RISCVProcessors.td +++ b/llvm/lib/Target/RISCV/RISCVProcessors.td @@ -321,6 +321,25 @@ def SIFIVE_P470 : RISCVProcessorModel<"sifive-p470", SiFiveP400Model, [TuneNoSinkSplatOperands, TuneVXRMPipelineFlush])>; +defvar SiFiveP500TuneFeatures = [TuneNoDefaultUnroll, + TuneConditionalCompressedMoveFusion, + TuneLUIADDIFusion, + TuneAUIPCADDIFusion, + TunePostRAScheduler]; + +def SIFIVE_P550 : RISCVProcessorModel<"sifive-p550", NoSchedModel, + [Feature64Bit, + FeatureStdExtI, + FeatureStdExtZifencei, + FeatureStdExtM, + FeatureStdExtA, + FeatureStdExtF, + FeatureStdExtD, + FeatureStdExtC, + FeatureStdExtZba, + FeatureStdExtZbb], + SiFiveP500TuneFeatures>; + def SIFIVE_P670 : RISCVProcessorModel<"sifive-p670", SiFiveP600Model, !listconcat(RVA22U64Features, [FeatureStdExtV, diff --git a/llvm/lib/Target/RISCV/RISCVSchedSiFiveP400.td b/llvm/lib/Target/RISCV/RISCVSchedSiFiveP400.td index a86c255..396cbe2c 100644 --- a/llvm/lib/Target/RISCV/RISCVSchedSiFiveP400.td +++ b/llvm/lib/Target/RISCV/RISCVSchedSiFiveP400.td @@ -182,7 +182,7 @@ def P400WriteCMOV : SchedWriteRes<[SiFiveP400Branch, SiFiveP400IEXQ1]> { } def : InstRW<[P400WriteCMOV], (instrs PseudoCCMOVGPRNoX0)>; -let Latency = 3 in { +let Latency = 2 in { // Integer multiplication def : WriteRes<WriteIMul, [SiFiveP400MulDiv]>; def : WriteRes<WriteIMul32, [SiFiveP400MulDiv]>; diff --git a/llvm/lib/Target/RISCV/RISCVSchedule.td b/llvm/lib/Target/RISCV/RISCVSchedule.td index 7946a74..ceaeb85 100644 --- a/llvm/lib/Target/RISCV/RISCVSchedule.td +++ b/llvm/lib/Target/RISCV/RISCVSchedule.td @@ -237,6 +237,7 @@ def : ReadAdvance<ReadFCvtF16ToI32, 0>; def : ReadAdvance<ReadFDiv16, 0>; def : ReadAdvance<ReadFCmp16, 0>; def : ReadAdvance<ReadFMA16, 0>; +def : ReadAdvance<ReadFMA16Addend, 0>; def : ReadAdvance<ReadFMinMax16, 0>; def : ReadAdvance<ReadFMul16, 0>; def : ReadAdvance<ReadFSGNJ16, 0>; diff --git a/llvm/lib/Target/RISCV/RISCVSystemOperands.td b/llvm/lib/Target/RISCV/RISCVSystemOperands.td index d85b4a9..4c86103 100644 --- a/llvm/lib/Target/RISCV/RISCVSystemOperands.td +++ b/llvm/lib/Target/RISCV/RISCVSystemOperands.td @@ -19,12 +19,6 @@ include "llvm/TableGen/SearchableTable.td" class SysReg<string name, bits<12> op> { string Name = name; - // A maximum of one alias is supported right now. - string AltName = name; - // A maximum of one deprecated name is supported right now. Unlike the - // `AltName` alias, a `DeprecatedName` generates a diagnostic when the name is - // used to encourage software to migrate away from the name. - string DeprecatedName = ""; bits<12> Encoding = op; // FIXME: add these additional fields when needed. // Privilege Access: Read and Write = 0, 1, 2; Read-Only = 3. @@ -37,14 +31,16 @@ class SysReg<string name, bits<12> op> { // bits<6> Number = op{5 - 0}; code FeaturesRequired = [{ {} }]; bit isRV32Only = 0; + bit isAltName = 0; + bit isDeprecatedName = 0; } def SysRegsList : GenericTable { let FilterClass = "SysReg"; // FIXME: add "ReadWrite", "Mode", "Extra", "Number" fields when needed. 
let Fields = [ - "Name", "AltName", "DeprecatedName", "Encoding", "FeaturesRequired", - "isRV32Only", + "Name", "Encoding", "FeaturesRequired", + "isRV32Only", "isAltName", "isDeprecatedName" ]; let PrimaryKey = [ "Encoding" ]; @@ -52,19 +48,15 @@ def SysRegsList : GenericTable { let PrimaryKeyReturnRange = true; } -def lookupSysRegByName : SearchIndex { - let Table = SysRegsList; - let Key = [ "Name" ]; -} - -def lookupSysRegByAltName : SearchIndex { - let Table = SysRegsList; - let Key = [ "AltName" ]; +def SysRegEncodings : GenericEnum { + let FilterClass = "SysReg"; + let NameField = "Name"; + let ValueField = "Encoding"; } -def lookupSysRegByDeprecatedName : SearchIndex { +def lookupSysRegByName : SearchIndex { let Table = SysRegsList; - let Key = [ "DeprecatedName" ]; + let Key = [ "Name" ]; } // The following CSR encodings match those given in Tables 2.2, @@ -123,15 +115,17 @@ def : SysReg<"senvcfg", 0x10A>; def : SysReg<"sscratch", 0x140>; def : SysReg<"sepc", 0x141>; def : SysReg<"scause", 0x142>; -let DeprecatedName = "sbadaddr" in def : SysReg<"stval", 0x143>; +let isDeprecatedName = 1 in +def : SysReg<"sbadaddr", 0x143>; def : SysReg<"sip", 0x144>; //===----------------------------------------------------------------------===// // Supervisor Protection and Translation //===----------------------------------------------------------------------===// -let DeprecatedName = "sptbr" in def : SysReg<"satp", 0x180>; +let isDeprecatedName = 1 in +def : SysReg<"sptbr", 0x180>; //===----------------------------------------------------------------------===// // Quality-of-Service(QoS) Identifiers (Ssqosid) @@ -245,8 +239,9 @@ def : SysReg<"mstatush", 0x310>; def : SysReg<"mscratch", 0x340>; def : SysReg<"mepc", 0x341>; def : SysReg<"mcause", 0x342>; -let DeprecatedName = "mbadaddr" in def : SysReg<"mtval", 0x343>; +let isDeprecatedName = 1 in +def : SysReg<"mbadaddr", 0x343>; def : SysReg<"mip", 0x344>; def : SysReg<"mtinst", 0x34A>; def : SysReg<"mtval2", 0x34B>; @@ -298,8 +293,9 @@ foreach i = 3...31 in //===----------------------------------------------------------------------===// // Machine Counter Setup //===----------------------------------------------------------------------===// -let AltName = "mucounteren" in // Privileged spec v1.9.1 Name def : SysReg<"mcountinhibit", 0x320>; +let isAltName = 1 in +def : SysReg<"mucounteren", 0x320>; // mhpmevent3-mhpmevent31 at 0x323-0x33F. 
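With AltName and DeprecatedName folded into per-row isAltName/isDeprecatedName flags above, several names can share one encoding, and the printer and parser pick the canonical spelling by skipping the flagged rows (as in the RISCVInstPrinter and asm-parser hunks earlier). A standalone sketch of that selection, with a hand-rolled table standing in for the TableGen-generated SysRegsList:

#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

// Stand-in for the generated table: several rows may share one encoding;
// exactly one of them is the canonical (non-alias, non-deprecated) spelling.
struct SysRegRow {
  std::string Name;
  uint16_t Encoding;
  bool IsAltName;
  bool IsDeprecatedName;
};

static const std::vector<SysRegRow> Rows = {
    {"satp", 0x180, false, false},
    {"sptbr", 0x180, false, true},    // deprecated alias of satp
    {"dscratch0", 0x7B2, false, false},
    {"dscratch", 0x7B2, true, false}, // alternative name of dscratch0
};

// Mimics lookupSysRegByEncoding() plus the skip-alias/deprecated loop:
// return the canonical spelling for an encoding, or nullptr if unknown.
static const char *canonicalName(uint16_t Encoding) {
  for (const SysRegRow &R : Rows) {
    if (R.Encoding != Encoding || R.IsAltName || R.IsDeprecatedName)
      continue;
    return R.Name.c_str();
  }
  return nullptr;
}

int main() {
  std::printf("%s %s\n", canonicalName(0x180), canonicalName(0x7B2));
  // prints: satp dscratch0
  return 0;
}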
foreach i = 3...31 in @@ -323,7 +319,10 @@ def : SysReg<"tselect", 0x7A0>; def : SysReg<"tdata1", 0x7A1>; def : SysReg<"tdata2", 0x7A2>; def : SysReg<"tdata3", 0x7A3>; +def : SysReg<"tinfo", 0x7A4>; +def : SysReg<"tcontrol", 0x7A5>; def : SysReg<"mcontext", 0x7A8>; +def : SysReg<"mscontext", 0x7AA>; //===----------------------------------------------------------------------===// // Debug Mode Registers @@ -333,8 +332,9 @@ def : SysReg<"dpc", 0x7B1>; // "dscratch" is an alternative name for "dscratch0" which appeared in earlier // drafts of the RISC-V debug spec -let AltName = "dscratch" in def : SysReg<"dscratch0", 0x7B2>; +let isAltName = 1 in +def : SysReg<"dscratch", 0x7B2>; def : SysReg<"dscratch1", 0x7B3>; //===----------------------------------------------------------------------===// diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp index 49192bd..850d624 100644 --- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp @@ -1663,7 +1663,7 @@ InstructionCost RISCVTTIImpl::getStoreImmCost(Type *Ty, return 0; if (OpInfo.isUniform()) - // vmv.x.i, vmv.v.x, or vfmv.v.f + // vmv.v.i, vmv.v.x, or vfmv.v.f // We ignore the cost of the scalar constant materialization to be consistent // with how we treat scalar constants themselves just above. return 1; @@ -2329,6 +2329,15 @@ unsigned RISCVTTIImpl::getMaximumVF(unsigned ElemWidth, unsigned Opcode) const { return std::max<unsigned>(1U, RegWidth.getFixedValue() / ElemWidth); } +TTI::AddressingModeKind +RISCVTTIImpl::getPreferredAddressingMode(const Loop *L, + ScalarEvolution *SE) const { + if (ST->hasVendorXCVmem() && !ST->is64Bit()) + return TTI::AMK_PostIndexed; + + return BasicTTIImplBase::getPreferredAddressingMode(L, SE); +} + bool RISCVTTIImpl::isLSRCostLess(const TargetTransformInfo::LSRCost &C1, const TargetTransformInfo::LSRCost &C2) { // RISC-V specific here are "instruction number 1st priority". @@ -2549,16 +2558,21 @@ RISCVTTIImpl::enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const { TTI::MemCmpExpansionOptions Options; // TODO: Enable expansion when unaligned access is not supported after we fix // issues in ExpandMemcmp. 
- if (!(ST->enableUnalignedScalarMem() && - (ST->hasStdExtZbb() || ST->hasStdExtZbkb() || IsZeroCmp))) + if (!ST->enableUnalignedScalarMem()) + return Options; + + if (!ST->hasStdExtZbb() && !ST->hasStdExtZbkb() && !IsZeroCmp) return Options; Options.AllowOverlappingLoads = true; Options.MaxNumLoads = TLI->getMaxExpandSizeMemcmp(OptSize); Options.NumLoadsPerBlock = Options.MaxNumLoads; - if (ST->is64Bit()) + if (ST->is64Bit()) { Options.LoadSizes = {8, 4, 2, 1}; - else + Options.AllowedTailExpansions = {3, 5, 6}; + } else { Options.LoadSizes = {4, 2, 1}; + Options.AllowedTailExpansions = {3}; + } return Options; } diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h index bd90bfe..9b36439 100644 --- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h +++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h @@ -388,6 +388,9 @@ public: llvm_unreachable("unknown register class"); } + TTI::AddressingModeKind getPreferredAddressingMode(const Loop *L, + ScalarEvolution *SE) const; + unsigned getRegisterClassForType(bool Vector, Type *Ty = nullptr) const { if (Vector) return RISCVRegisterClass::VRRC; diff --git a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp index 4e3212c..ad61a77 100644 --- a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp +++ b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp @@ -50,7 +50,10 @@ public: StringRef getPassName() const override { return PASS_NAME; } private: - bool checkUsers(const MachineOperand *&CommonVL, MachineInstr &MI); + std::optional<MachineOperand> getMinimumVLForUser(MachineOperand &UserOp); + /// Returns the largest common VL MachineOperand that may be used to optimize + /// MI. Returns std::nullopt if it failed to find a suitable VL. + std::optional<MachineOperand> checkUsers(MachineInstr &MI); bool tryReduceVL(MachineInstr &MI); bool isCandidate(const MachineInstr &MI) const; }; @@ -76,11 +79,6 @@ static bool isVectorRegClass(Register R, const MachineRegisterInfo *MRI) { /// Represents the EMUL and EEW of a MachineOperand. struct OperandInfo { - enum class State { - Unknown, - Known, - } S; - // Represent as 1,2,4,8, ... and fractional indicator. This is because // EMUL can take on values that don't map to RISCVII::VLMUL values exactly. // For example, a mask operand can have an EMUL less than MF8. 
@@ -89,34 +87,32 @@ struct OperandInfo { unsigned Log2EEW; OperandInfo(RISCVII::VLMUL EMUL, unsigned Log2EEW) - : S(State::Known), EMUL(RISCVVType::decodeVLMUL(EMUL)), Log2EEW(Log2EEW) { - } + : EMUL(RISCVVType::decodeVLMUL(EMUL)), Log2EEW(Log2EEW) {} OperandInfo(std::pair<unsigned, bool> EMUL, unsigned Log2EEW) - : S(State::Known), EMUL(EMUL), Log2EEW(Log2EEW) {} + : EMUL(EMUL), Log2EEW(Log2EEW) {} - OperandInfo() : S(State::Unknown) {} + OperandInfo(unsigned Log2EEW) : Log2EEW(Log2EEW) {} - bool isUnknown() const { return S == State::Unknown; } - bool isKnown() const { return S == State::Known; } + OperandInfo() = delete; static bool EMULAndEEWAreEqual(const OperandInfo &A, const OperandInfo &B) { - assert(A.isKnown() && B.isKnown() && "Both operands must be known"); - return A.Log2EEW == B.Log2EEW && A.EMUL->first == B.EMUL->first && A.EMUL->second == B.EMUL->second; } + static bool EEWAreEqual(const OperandInfo &A, const OperandInfo &B) { + return A.Log2EEW == B.Log2EEW; + } + void print(raw_ostream &OS) const { - if (isUnknown()) { - OS << "Unknown"; - return; - } - assert(EMUL && "Expected EMUL to have value"); - OS << "EMUL: m"; - if (EMUL->second) - OS << "f"; - OS << EMUL->first; + if (EMUL) { + OS << "EMUL: m"; + if (EMUL->second) + OS << "f"; + OS << EMUL->first; + } else + OS << "EMUL: unknown\n"; OS << ", EEW: " << (1 << Log2EEW); } }; @@ -127,30 +123,18 @@ static raw_ostream &operator<<(raw_ostream &OS, const OperandInfo &OI) { return OS; } -namespace llvm { -namespace RISCVVType { -/// Return the RISCVII::VLMUL that is two times VLMul. -/// Precondition: VLMul is not LMUL_RESERVED or LMUL_8. -static RISCVII::VLMUL twoTimesVLMUL(RISCVII::VLMUL VLMul) { - switch (VLMul) { - case RISCVII::VLMUL::LMUL_F8: - return RISCVII::VLMUL::LMUL_F4; - case RISCVII::VLMUL::LMUL_F4: - return RISCVII::VLMUL::LMUL_F2; - case RISCVII::VLMUL::LMUL_F2: - return RISCVII::VLMUL::LMUL_1; - case RISCVII::VLMUL::LMUL_1: - return RISCVII::VLMUL::LMUL_2; - case RISCVII::VLMUL::LMUL_2: - return RISCVII::VLMUL::LMUL_4; - case RISCVII::VLMUL::LMUL_4: - return RISCVII::VLMUL::LMUL_8; - case RISCVII::VLMUL::LMUL_8: - default: - llvm_unreachable("Could not multiply VLMul by 2"); - } +LLVM_ATTRIBUTE_UNUSED +static raw_ostream &operator<<(raw_ostream &OS, + const std::optional<OperandInfo> &OI) { + if (OI) + OI->print(OS); + else + OS << "nullopt"; + return OS; } +namespace llvm { +namespace RISCVVType { /// Return EMUL = (EEW / SEW) * LMUL where EEW comes from Log2EEW and LMUL and /// SEW are from the TSFlags of MI. static std::pair<unsigned, bool> @@ -180,24 +164,22 @@ getEMULEqualsEEWDivSEWTimesLMUL(unsigned Log2EEW, const MachineInstr &MI) { } // end namespace RISCVVType } // end namespace llvm -/// Dest has EEW=SEW and EMUL=LMUL. Source EEW=SEW/Factor (i.e. F2 => EEW/2). -/// Source has EMUL=(EEW/SEW)*LMUL. LMUL and SEW comes from TSFlags of MI. -static OperandInfo getIntegerExtensionOperandInfo(unsigned Factor, - const MachineInstr &MI, - const MachineOperand &MO) { - RISCVII::VLMUL MIVLMul = RISCVII::getLMul(MI.getDesc().TSFlags); +/// Dest has EEW=SEW. Source EEW=SEW/Factor (i.e. F2 => EEW/2). +/// SEW comes from TSFlags of MI. 
+static unsigned getIntegerExtensionOperandEEW(unsigned Factor, + const MachineInstr &MI, + const MachineOperand &MO) { unsigned MILog2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm(); if (MO.getOperandNo() == 0) - return OperandInfo(MIVLMul, MILog2SEW); + return MILog2SEW; unsigned MISEW = 1 << MILog2SEW; unsigned EEW = MISEW / Factor; unsigned Log2EEW = Log2_32(EEW); - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(Log2EEW, MI), - Log2EEW); + return Log2EEW; } /// Check whether MO is a mask operand of MI. @@ -211,18 +193,15 @@ static bool isMaskOperand(const MachineInstr &MI, const MachineOperand &MO, return Desc.operands()[MO.getOperandNo()].RegClass == RISCV::VMV0RegClassID; } -/// Return the OperandInfo for MO. -static OperandInfo getOperandInfo(const MachineOperand &MO, - const MachineRegisterInfo *MRI) { +static std::optional<unsigned> +getOperandLog2EEW(const MachineOperand &MO, const MachineRegisterInfo *MRI) { const MachineInstr &MI = *MO.getParent(); const RISCVVPseudosTable::PseudoInfo *RVV = RISCVVPseudosTable::getPseudoInfo(MI.getOpcode()); assert(RVV && "Could not find MI in PseudoTable"); - // MI has a VLMUL and SEW associated with it. The RVV specification defines - // the LMUL and SEW of each operand and definition in relation to MI.VLMUL and - // MI.SEW. - RISCVII::VLMUL MIVLMul = RISCVII::getLMul(MI.getDesc().TSFlags); + // MI has a SEW associated with it. The RVV specification defines + // the EEW of each operand and definition in relation to MI.SEW. unsigned MILog2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm(); @@ -233,13 +212,13 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, // since they must preserve the entire register content. if (HasPassthru && MO.getOperandNo() == MI.getNumExplicitDefs() && (MO.getReg() != RISCV::NoRegister)) - return {}; + return std::nullopt; bool IsMODef = MO.getOperandNo() == 0; - // All mask operands have EEW=1, EMUL=(EEW/SEW)*LMUL + // All mask operands have EEW=1 if (isMaskOperand(MI, MO, MRI)) - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(0, MI), 0); + return 0; // switch against BaseInstr to reduce number of cases that need to be // considered. @@ -256,55 +235,65 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, // Vector Loads and Stores // Vector Unit-Stride Instructions // Vector Strided Instructions - /// Dest EEW encoded in the instruction and EMUL=(EEW/SEW)*LMUL + /// Dest EEW encoded in the instruction + case RISCV::VLM_V: + case RISCV::VSM_V: + return 0; + case RISCV::VLE8_V: case RISCV::VSE8_V: + case RISCV::VLSE8_V: case RISCV::VSSE8_V: - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(3, MI), 3); + return 3; + case RISCV::VLE16_V: case RISCV::VSE16_V: + case RISCV::VLSE16_V: case RISCV::VSSE16_V: - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(4, MI), 4); + return 4; + case RISCV::VLE32_V: case RISCV::VSE32_V: + case RISCV::VLSE32_V: case RISCV::VSSE32_V: - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(5, MI), 5); + return 5; + case RISCV::VLE64_V: case RISCV::VSE64_V: + case RISCV::VLSE64_V: case RISCV::VSSE64_V: - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(6, MI), 6); + return 6; // Vector Indexed Instructions // vs(o|u)xei<eew>.v - // Dest/Data (operand 0) EEW=SEW, EMUL=LMUL. Source EEW=<eew> and - // EMUL=(EEW/SEW)*LMUL. + // Dest/Data (operand 0) EEW=SEW. Source EEW=<eew>. 
case RISCV::VLUXEI8_V: case RISCV::VLOXEI8_V: case RISCV::VSUXEI8_V: case RISCV::VSOXEI8_V: { if (MO.getOperandNo() == 0) - return OperandInfo(MIVLMul, MILog2SEW); - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(3, MI), 3); + return MILog2SEW; + return 3; } case RISCV::VLUXEI16_V: case RISCV::VLOXEI16_V: case RISCV::VSUXEI16_V: case RISCV::VSOXEI16_V: { if (MO.getOperandNo() == 0) - return OperandInfo(MIVLMul, MILog2SEW); - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(4, MI), 4); + return MILog2SEW; + return 4; } case RISCV::VLUXEI32_V: case RISCV::VLOXEI32_V: case RISCV::VSUXEI32_V: case RISCV::VSOXEI32_V: { if (MO.getOperandNo() == 0) - return OperandInfo(MIVLMul, MILog2SEW); - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(5, MI), 5); + return MILog2SEW; + return 5; } case RISCV::VLUXEI64_V: case RISCV::VLOXEI64_V: case RISCV::VSUXEI64_V: case RISCV::VSOXEI64_V: { if (MO.getOperandNo() == 0) - return OperandInfo(MIVLMul, MILog2SEW); - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(6, MI), 6); + return MILog2SEW; + return 6; } // Vector Integer Arithmetic Instructions @@ -318,7 +307,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VRSUB_VX: // Vector Bitwise Logical Instructions // Vector Single-Width Shift Instructions - // EEW=SEW. EMUL=LMUL. + // EEW=SEW. case RISCV::VAND_VI: case RISCV::VAND_VV: case RISCV::VAND_VX: @@ -338,7 +327,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VSRA_VV: case RISCV::VSRA_VX: // Vector Integer Min/Max Instructions - // EEW=SEW. EMUL=LMUL. + // EEW=SEW. case RISCV::VMINU_VV: case RISCV::VMINU_VX: case RISCV::VMIN_VV: @@ -348,7 +337,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VMAX_VV: case RISCV::VMAX_VX: // Vector Single-Width Integer Multiply Instructions - // Source and Dest EEW=SEW and EMUL=LMUL. + // Source and Dest EEW=SEW. case RISCV::VMUL_VV: case RISCV::VMUL_VX: case RISCV::VMULH_VV: @@ -358,7 +347,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VMULHSU_VV: case RISCV::VMULHSU_VX: // Vector Integer Divide Instructions - // EEW=SEW. EMUL=LMUL. + // EEW=SEW. case RISCV::VDIVU_VV: case RISCV::VDIVU_VX: case RISCV::VDIV_VV: @@ -368,7 +357,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VREM_VV: case RISCV::VREM_VX: // Vector Single-Width Integer Multiply-Add Instructions - // EEW=SEW. EMUL=LMUL. + // EEW=SEW. case RISCV::VMACC_VV: case RISCV::VMACC_VX: case RISCV::VNMSAC_VV: @@ -379,8 +368,8 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VNMSUB_VX: // Vector Integer Merge Instructions // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions - // EEW=SEW and EMUL=LMUL, except the mask operand has EEW=1 and EMUL= - // (EEW/SEW)*LMUL. Mask operand is handled before this switch. + // EEW=SEW, except the mask operand has EEW=1. Mask operand is handled + // before this switch. case RISCV::VMERGE_VIM: case RISCV::VMERGE_VVM: case RISCV::VMERGE_VXM: @@ -393,7 +382,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, // Vector Fixed-Point Arithmetic Instructions // Vector Single-Width Saturating Add and Subtract // Vector Single-Width Averaging Add and Subtract - // EEW=SEW. EMUL=LMUL. + // EEW=SEW. 
case RISCV::VMV_V_I: case RISCV::VMV_V_V: case RISCV::VMV_V_X: @@ -415,8 +404,13 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VASUBU_VX: case RISCV::VASUB_VV: case RISCV::VASUB_VX: + // Vector Single-Width Fractional Multiply with Rounding and Saturation + // EEW=SEW. The instruction produces 2*SEW product internally but + // saturates to fit into SEW bits. + case RISCV::VSMUL_VV: + case RISCV::VSMUL_VX: // Vector Single-Width Scaling Shift Instructions - // EEW=SEW. EMUL=LMUL. + // EEW=SEW. case RISCV::VSSRL_VI: case RISCV::VSSRL_VV: case RISCV::VSSRL_VX: @@ -426,13 +420,13 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, // Vector Permutation Instructions // Integer Scalar Move Instructions // Floating-Point Scalar Move Instructions - // EMUL=LMUL. EEW=SEW. + // EEW=SEW. case RISCV::VMV_X_S: case RISCV::VMV_S_X: case RISCV::VFMV_F_S: case RISCV::VFMV_S_F: // Vector Slide Instructions - // EMUL=LMUL. EEW=SEW. + // EEW=SEW. case RISCV::VSLIDEUP_VI: case RISCV::VSLIDEUP_VX: case RISCV::VSLIDEDOWN_VI: @@ -442,19 +436,62 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VSLIDE1DOWN_VX: case RISCV::VFSLIDE1DOWN_VF: // Vector Register Gather Instructions - // EMUL=LMUL. EEW=SEW. For mask operand, EMUL=1 and EEW=1. + // EEW=SEW. For mask operand, EEW=1. case RISCV::VRGATHER_VI: case RISCV::VRGATHER_VV: case RISCV::VRGATHER_VX: // Vector Compress Instruction - // EMUL=LMUL. EEW=SEW. + // EEW=SEW. case RISCV::VCOMPRESS_VM: // Vector Element Index Instruction case RISCV::VID_V: - return OperandInfo(MIVLMul, MILog2SEW); + // Vector Single-Width Floating-Point Add/Subtract Instructions + case RISCV::VFADD_VF: + case RISCV::VFADD_VV: + case RISCV::VFSUB_VF: + case RISCV::VFSUB_VV: + case RISCV::VFRSUB_VF: + // Vector Single-Width Floating-Point Multiply/Divide Instructions + case RISCV::VFMUL_VF: + case RISCV::VFMUL_VV: + case RISCV::VFDIV_VF: + case RISCV::VFDIV_VV: + case RISCV::VFRDIV_VF: + // Vector Floating-Point Square-Root Instruction + case RISCV::VFSQRT_V: + // Vector Floating-Point Reciprocal Square-Root Estimate Instruction + case RISCV::VFRSQRT7_V: + // Vector Floating-Point Reciprocal Estimate Instruction + case RISCV::VFREC7_V: + // Vector Floating-Point MIN/MAX Instructions + case RISCV::VFMIN_VF: + case RISCV::VFMIN_VV: + case RISCV::VFMAX_VF: + case RISCV::VFMAX_VV: + // Vector Floating-Point Sign-Injection Instructions + case RISCV::VFSGNJ_VF: + case RISCV::VFSGNJ_VV: + case RISCV::VFSGNJN_VV: + case RISCV::VFSGNJN_VF: + case RISCV::VFSGNJX_VF: + case RISCV::VFSGNJX_VV: + // Vector Floating-Point Classify Instruction + case RISCV::VFCLASS_V: + // Vector Floating-Point Move Instruction + case RISCV::VFMV_V_F: + // Single-Width Floating-Point/Integer Type-Convert Instructions + case RISCV::VFCVT_XU_F_V: + case RISCV::VFCVT_X_F_V: + case RISCV::VFCVT_RTZ_XU_F_V: + case RISCV::VFCVT_RTZ_X_F_V: + case RISCV::VFCVT_F_XU_V: + case RISCV::VFCVT_F_X_V: + // Vector Floating-Point Merge Instruction + case RISCV::VFMERGE_VFM: + return MILog2SEW; // Vector Widening Integer Add/Subtract - // Def uses EEW=2*SEW and EMUL=2*LMUL. Operands use EEW=SEW and EMUL=LMUL. + // Def uses EEW=2*SEW . Operands use EEW=SEW. case RISCV::VWADDU_VV: case RISCV::VWADDU_VX: case RISCV::VWSUBU_VV: @@ -465,7 +502,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VWSUB_VX: case RISCV::VWSLL_VI: // Vector Widening Integer Multiply Instructions - // Source and Destination EMUL=LMUL. Destination EEW=2*SEW. Source EEW=SEW. 
+ // Destination EEW=2*SEW. Source EEW=SEW. case RISCV::VWMUL_VV: case RISCV::VWMUL_VX: case RISCV::VWMULSU_VV: @@ -473,7 +510,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VWMULU_VV: case RISCV::VWMULU_VX: // Vector Widening Integer Multiply-Add Instructions - // Destination EEW=2*SEW and EMUL=2*LMUL. Source EEW=SEW and EMUL=LMUL. + // Destination EEW=2*SEW. Source EEW=SEW. // A SEW-bit*SEW-bit multiply of the sources forms a 2*SEW-bit value, which // is then added to the 2*SEW-bit Dest. These instructions never have a // passthru operand. @@ -483,14 +520,38 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VWMACC_VX: case RISCV::VWMACCSU_VV: case RISCV::VWMACCSU_VX: - case RISCV::VWMACCUS_VX: { + case RISCV::VWMACCUS_VX: + // Vector Widening Floating-Point Fused Multiply-Add Instructions + case RISCV::VFWMACC_VF: + case RISCV::VFWMACC_VV: + case RISCV::VFWNMACC_VF: + case RISCV::VFWNMACC_VV: + case RISCV::VFWMSAC_VF: + case RISCV::VFWMSAC_VV: + case RISCV::VFWNMSAC_VF: + case RISCV::VFWNMSAC_VV: + // Vector Widening Floating-Point Add/Subtract Instructions + // Dest EEW=2*SEW. Source EEW=SEW. + case RISCV::VFWADD_VV: + case RISCV::VFWADD_VF: + case RISCV::VFWSUB_VV: + case RISCV::VFWSUB_VF: + // Vector Widening Floating-Point Multiply + case RISCV::VFWMUL_VF: + case RISCV::VFWMUL_VV: + // Widening Floating-Point/Integer Type-Convert Instructions + case RISCV::VFWCVT_XU_F_V: + case RISCV::VFWCVT_X_F_V: + case RISCV::VFWCVT_RTZ_XU_F_V: + case RISCV::VFWCVT_RTZ_X_F_V: + case RISCV::VFWCVT_F_XU_V: + case RISCV::VFWCVT_F_X_V: + case RISCV::VFWCVT_F_F_V: { unsigned Log2EEW = IsMODef ? MILog2SEW + 1 : MILog2SEW; - RISCVII::VLMUL EMUL = - IsMODef ? RISCVVType::twoTimesVLMUL(MIVLMul) : MIVLMul; - return OperandInfo(EMUL, Log2EEW); + return Log2EEW; } - // Def and Op1 uses EEW=2*SEW and EMUL=2*LMUL. Op2 uses EEW=SEW and EMUL=LMUL + // Def and Op1 uses EEW=2*SEW. Op2 uses EEW=SEW. case RISCV::VWADDU_WV: case RISCV::VWADDU_WX: case RISCV::VWSUBU_WV: @@ -498,29 +559,31 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VWADD_WV: case RISCV::VWADD_WX: case RISCV::VWSUB_WV: - case RISCV::VWSUB_WX: { + case RISCV::VWSUB_WX: + // Vector Widening Floating-Point Add/Subtract Instructions + case RISCV::VFWADD_WF: + case RISCV::VFWADD_WV: + case RISCV::VFWSUB_WF: + case RISCV::VFWSUB_WV: { bool IsOp1 = HasPassthru ? MO.getOperandNo() == 2 : MO.getOperandNo() == 1; bool TwoTimes = IsMODef || IsOp1; unsigned Log2EEW = TwoTimes ? MILog2SEW + 1 : MILog2SEW; - RISCVII::VLMUL EMUL = - TwoTimes ? RISCVVType::twoTimesVLMUL(MIVLMul) : MIVLMul; - return OperandInfo(EMUL, Log2EEW); + return Log2EEW; } // Vector Integer Extension case RISCV::VZEXT_VF2: case RISCV::VSEXT_VF2: - return getIntegerExtensionOperandInfo(2, MI, MO); + return getIntegerExtensionOperandEEW(2, MI, MO); case RISCV::VZEXT_VF4: case RISCV::VSEXT_VF4: - return getIntegerExtensionOperandInfo(4, MI, MO); + return getIntegerExtensionOperandEEW(4, MI, MO); case RISCV::VZEXT_VF8: case RISCV::VSEXT_VF8: - return getIntegerExtensionOperandInfo(8, MI, MO); + return getIntegerExtensionOperandEEW(8, MI, MO); // Vector Narrowing Integer Right Shift Instructions - // Destination EEW=SEW and EMUL=LMUL, Op 1 has EEW=2*SEW EMUL=2*LMUL. Op2 has - // EEW=SEW EMUL=LMUL. + // Destination EEW=SEW, Op 1 has EEW=2*SEW. 
Op2 has EEW=SEW case RISCV::VNSRL_WX: case RISCV::VNSRL_WI: case RISCV::VNSRL_WV: @@ -528,19 +591,26 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VNSRA_WV: case RISCV::VNSRA_WX: // Vector Narrowing Fixed-Point Clip Instructions - // Destination and Op1 EEW=SEW and EMUL=LMUL. Op2 EEW=2*SEW and EMUL=2*LMUL + // Destination and Op1 EEW=SEW. Op2 EEW=2*SEW. case RISCV::VNCLIPU_WI: case RISCV::VNCLIPU_WV: case RISCV::VNCLIPU_WX: case RISCV::VNCLIP_WI: case RISCV::VNCLIP_WV: - case RISCV::VNCLIP_WX: { + case RISCV::VNCLIP_WX: + // Narrowing Floating-Point/Integer Type-Convert Instructions + case RISCV::VFNCVT_XU_F_W: + case RISCV::VFNCVT_X_F_W: + case RISCV::VFNCVT_RTZ_XU_F_W: + case RISCV::VFNCVT_RTZ_X_F_W: + case RISCV::VFNCVT_F_XU_W: + case RISCV::VFNCVT_F_X_W: + case RISCV::VFNCVT_F_F_W: + case RISCV::VFNCVT_ROD_F_F_W: { bool IsOp1 = HasPassthru ? MO.getOperandNo() == 2 : MO.getOperandNo() == 1; bool TwoTimes = IsOp1; unsigned Log2EEW = TwoTimes ? MILog2SEW + 1 : MILog2SEW; - RISCVII::VLMUL EMUL = - TwoTimes ? RISCVVType::twoTimesVLMUL(MIVLMul) : MIVLMul; - return OperandInfo(EMUL, Log2EEW); + return Log2EEW; } // Vector Mask Instructions @@ -548,7 +618,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, // vmsbf.m set-before-first mask bit // vmsif.m set-including-first mask bit // vmsof.m set-only-first mask bit - // EEW=1 and EMUL=(EEW/SEW)*LMUL + // EEW=1 // We handle the cases when operand is a v0 mask operand above the switch, // but these instructions may use non-v0 mask operands and need to be handled // specifically. @@ -563,20 +633,20 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VMSBF_M: case RISCV::VMSIF_M: case RISCV::VMSOF_M: { - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(0, MI), 0); + return 0; } // Vector Iota Instruction - // EEW=SEW and EMUL=LMUL, except the mask operand has EEW=1 and EMUL= - // (EEW/SEW)*LMUL. Mask operand is not handled before this switch. + // EEW=SEW, except the mask operand has EEW=1. Mask operand is not handled + // before this switch. case RISCV::VIOTA_M: { if (IsMODef || MO.getOperandNo() == 1) - return OperandInfo(MIVLMul, MILog2SEW); - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(0, MI), 0); + return MILog2SEW; + return 0; } // Vector Integer Compare Instructions - // Dest EEW=1 and EMUL=(EEW/SEW)*LMUL. Source EEW=SEW and EMUL=LMUL. + // Dest EEW=1. Source EEW=SEW. case RISCV::VMSEQ_VI: case RISCV::VMSEQ_VV: case RISCV::VMSEQ_VX: @@ -598,29 +668,87 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VMSGT_VI: case RISCV::VMSGT_VX: // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions - // Dest EEW=1 and EMUL=(EEW/SEW)*LMUL. Source EEW=SEW and EMUL=LMUL. Mask - // source operand handled above this switch. + // Dest EEW=1. Source EEW=SEW. Mask source operand handled above this switch. case RISCV::VMADC_VIM: case RISCV::VMADC_VVM: case RISCV::VMADC_VXM: case RISCV::VMSBC_VVM: case RISCV::VMSBC_VXM: - // Dest EEW=1 and EMUL=(EEW/SEW)*LMUL. Source EEW=SEW and EMUL=LMUL. + // Dest EEW=1. Source EEW=SEW. case RISCV::VMADC_VV: case RISCV::VMADC_VI: case RISCV::VMADC_VX: case RISCV::VMSBC_VV: - case RISCV::VMSBC_VX: { + case RISCV::VMSBC_VX: + // 13.13. Vector Floating-Point Compare Instructions + // Dest EEW=1. 
Source EEW=SEW + case RISCV::VMFEQ_VF: + case RISCV::VMFEQ_VV: + case RISCV::VMFNE_VF: + case RISCV::VMFNE_VV: + case RISCV::VMFLT_VF: + case RISCV::VMFLT_VV: + case RISCV::VMFLE_VF: + case RISCV::VMFLE_VV: + case RISCV::VMFGT_VF: + case RISCV::VMFGE_VF: { if (IsMODef) - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(0, MI), 0); - return OperandInfo(MIVLMul, MILog2SEW); + return 0; + return MILog2SEW; + } + + // Vector Reduction Operations + // Vector Single-Width Integer Reduction Instructions + case RISCV::VREDAND_VS: + case RISCV::VREDMAX_VS: + case RISCV::VREDMAXU_VS: + case RISCV::VREDMIN_VS: + case RISCV::VREDMINU_VS: + case RISCV::VREDOR_VS: + case RISCV::VREDSUM_VS: + case RISCV::VREDXOR_VS: { + return MILog2SEW; } default: - return {}; + return std::nullopt; } } +static std::optional<OperandInfo> +getOperandInfo(const MachineOperand &MO, const MachineRegisterInfo *MRI) { + const MachineInstr &MI = *MO.getParent(); + const RISCVVPseudosTable::PseudoInfo *RVV = + RISCVVPseudosTable::getPseudoInfo(MI.getOpcode()); + assert(RVV && "Could not find MI in PseudoTable"); + + std::optional<unsigned> Log2EEW = getOperandLog2EEW(MO, MRI); + if (!Log2EEW) + return std::nullopt; + + switch (RVV->BaseInstr) { + // Vector Reduction Operations + // Vector Single-Width Integer Reduction Instructions + // The Dest and VS1 only read element 0 of the vector register. Return just + // the EEW for these. + case RISCV::VREDAND_VS: + case RISCV::VREDMAX_VS: + case RISCV::VREDMAXU_VS: + case RISCV::VREDMIN_VS: + case RISCV::VREDMINU_VS: + case RISCV::VREDOR_VS: + case RISCV::VREDSUM_VS: + case RISCV::VREDXOR_VS: + if (MO.getOperandNo() != 2) + return OperandInfo(*Log2EEW); + break; + }; + + // All others have EMUL=EEW/SEW*LMUL + return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(*Log2EEW, MI), + *Log2EEW); +} + /// Return true if this optimization should consider MI for VL reduction. This /// white-list approach simplifies this optimization for instructions that may /// have more complex semantics with relation to how it uses VL. 
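To make the EMUL = (EEW / SEW) * LMUL relation still used by getOperandInfo above concrete, here is a small standalone illustration in the same {value, is-fractional} encoding that the OperandInfo::EMUL comment describes. emulFromEEW and MulPair are names assumed for this sketch only, not the LLVM helper, and power-of-two inputs are assumed.

#include <cassert>
#include <utility>

// LMUL/EMUL encoding: {2, false} is m2, {4, true} is mf4, {1, false} is m1.
using MulPair = std::pair<unsigned, bool>;

MulPair emulFromEEW(unsigned EEW, unsigned SEW, MulPair LMUL) {
  assert(EEW && SEW && LMUL.first && "non-zero power-of-two inputs assumed");
  // (EEW / SEW) * LMUL written as a single ratio Num/Den.
  unsigned Num = EEW * (LMUL.second ? 1 : LMUL.first);
  unsigned Den = SEW * (LMUL.second ? LMUL.first : 1);
  return Num >= Den ? MulPair(Num / Den, false)  // integral EMUL (m1, m2, ...)
                    : MulPair(Den / Num, true);  // fractional EMUL (mf2, ...)
}

// Example: a mask operand (EEW=1) under SEW=32, LMUL=m2 gives
// emulFromEEW(1, 32, {2, false}) == {16, true}, i.e. an EMUL below mf8,
// which is why EMUL is kept as a {value, fractional} pair rather than a
// RISCVII::VLMUL enum value.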
@@ -632,6 +760,32 @@ static bool isSupportedInstr(const MachineInstr &MI) { return false; switch (RVV->BaseInstr) { + // Vector Unit-Stride Instructions + // Vector Strided Instructions + case RISCV::VLM_V: + case RISCV::VLE8_V: + case RISCV::VLSE8_V: + case RISCV::VLE16_V: + case RISCV::VLSE16_V: + case RISCV::VLE32_V: + case RISCV::VLSE32_V: + case RISCV::VLE64_V: + case RISCV::VLSE64_V: + // Vector Indexed Instructions + case RISCV::VLUXEI8_V: + case RISCV::VLOXEI8_V: + case RISCV::VLUXEI16_V: + case RISCV::VLOXEI16_V: + case RISCV::VLUXEI32_V: + case RISCV::VLOXEI32_V: + case RISCV::VLUXEI64_V: + case RISCV::VLOXEI64_V: { + for (const MachineMemOperand *MMO : MI.memoperands()) + if (MMO->isVolatile()) + return false; + return true; + } + // Vector Single-Width Integer Add and Subtract case RISCV::VADD_VI: case RISCV::VADD_VV: @@ -801,6 +955,30 @@ static bool isSupportedInstr(const MachineInstr &MI) { case RISCV::VMSOF_M: case RISCV::VIOTA_M: case RISCV::VID_V: + // Single-Width Floating-Point/Integer Type-Convert Instructions + case RISCV::VFCVT_XU_F_V: + case RISCV::VFCVT_X_F_V: + case RISCV::VFCVT_RTZ_XU_F_V: + case RISCV::VFCVT_RTZ_X_F_V: + case RISCV::VFCVT_F_XU_V: + case RISCV::VFCVT_F_X_V: + // Widening Floating-Point/Integer Type-Convert Instructions + case RISCV::VFWCVT_XU_F_V: + case RISCV::VFWCVT_X_F_V: + case RISCV::VFWCVT_RTZ_XU_F_V: + case RISCV::VFWCVT_RTZ_X_F_V: + case RISCV::VFWCVT_F_XU_V: + case RISCV::VFWCVT_F_X_V: + case RISCV::VFWCVT_F_F_V: + // Narrowing Floating-Point/Integer Type-Convert Instructions + case RISCV::VFNCVT_XU_F_W: + case RISCV::VFNCVT_X_F_W: + case RISCV::VFNCVT_RTZ_XU_F_W: + case RISCV::VFNCVT_RTZ_X_F_W: + case RISCV::VFNCVT_F_XU_W: + case RISCV::VFNCVT_F_X_W: + case RISCV::VFNCVT_F_F_W: + case RISCV::VFNCVT_ROD_F_F_W: return true; } @@ -835,6 +1013,9 @@ static bool isVectorOpUsedAsScalarOp(MachineOperand &MO) { case RISCV::VFWREDOSUM_VS: case RISCV::VFWREDUSUM_VS: return MO.getOperandNo() == 3; + case RISCV::VMV_X_S: + case RISCV::VFMV_F_S: + return MO.getOperandNo() == 1; default: return false; } @@ -904,6 +1085,11 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const { return false; } + if (MI.mayRaiseFPException()) { + LLVM_DEBUG(dbgs() << "Not a candidate because may raise FP exception\n"); + return false; + } + // Some instructions that produce vectors have semantics that make it more // difficult to determine whether the VL can be reduced. For example, some // instructions, such as reductions, may write lanes past VL to a scalar @@ -925,79 +1111,103 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const { return true; } -bool RISCVVLOptimizer::checkUsers(const MachineOperand *&CommonVL, - MachineInstr &MI) { +std::optional<MachineOperand> +RISCVVLOptimizer::getMinimumVLForUser(MachineOperand &UserOp) { + const MachineInstr &UserMI = *UserOp.getParent(); + const MCInstrDesc &Desc = UserMI.getDesc(); + + if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags)) { + LLVM_DEBUG(dbgs() << " Abort due to lack of VL, assume that" + " use VLMAX\n"); + return std::nullopt; + } + + // Instructions like reductions may use a vector register as a scalar + // register. In this case, we should treat it as only reading the first lane. 
+ if (isVectorOpUsedAsScalarOp(UserOp)) { + [[maybe_unused]] Register R = UserOp.getReg(); + [[maybe_unused]] const TargetRegisterClass *RC = MRI->getRegClass(R); + assert(RISCV::VRRegClass.hasSubClassEq(RC) && + "Expect LMUL 1 register class for vector as scalar operands!"); + LLVM_DEBUG(dbgs() << " Used this operand as a scalar operand\n"); + + return MachineOperand::CreateImm(1); + } + + unsigned VLOpNum = RISCVII::getVLOpNum(Desc); + const MachineOperand &VLOp = UserMI.getOperand(VLOpNum); + // Looking for an immediate or a register VL that isn't X0. + assert((!VLOp.isReg() || VLOp.getReg() != RISCV::X0) && + "Did not expect X0 VL"); + return VLOp; +} + +std::optional<MachineOperand> RISCVVLOptimizer::checkUsers(MachineInstr &MI) { // FIXME: Avoid visiting each user for each time we visit something on the // worklist, combined with an extra visit from the outer loop. Restructure // along lines of an instcombine style worklist which integrates the outer // pass. - bool CanReduceVL = true; + std::optional<MachineOperand> CommonVL; for (auto &UserOp : MRI->use_operands(MI.getOperand(0).getReg())) { const MachineInstr &UserMI = *UserOp.getParent(); LLVM_DEBUG(dbgs() << " Checking user: " << UserMI << "\n"); - - // Instructions like reductions may use a vector register as a scalar - // register. In this case, we should treat it like a scalar register which - // does not impact the decision on whether to optimize VL. - // TODO: Treat it like a scalar register instead of bailing out. - if (isVectorOpUsedAsScalarOp(UserOp)) { - CanReduceVL = false; - break; - } - if (mayReadPastVL(UserMI)) { LLVM_DEBUG(dbgs() << " Abort because used by unsafe instruction\n"); - CanReduceVL = false; - break; + return std::nullopt; } // Tied operands might pass through. if (UserOp.isTied()) { LLVM_DEBUG(dbgs() << " Abort because user used as tied operand\n"); - CanReduceVL = false; - break; - } - - const MCInstrDesc &Desc = UserMI.getDesc(); - if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags)) { - LLVM_DEBUG(dbgs() << " Abort due to lack of VL or SEW, assume that" - " use VLMAX\n"); - CanReduceVL = false; - break; + return std::nullopt; } - unsigned VLOpNum = RISCVII::getVLOpNum(Desc); - const MachineOperand &VLOp = UserMI.getOperand(VLOpNum); - - // Looking for an immediate or a register VL that isn't X0. - assert((!VLOp.isReg() || VLOp.getReg() != RISCV::X0) && - "Did not expect X0 VL"); + auto VLOp = getMinimumVLForUser(UserOp); + if (!VLOp) + return std::nullopt; // Use the largest VL among all the users. If we cannot determine this // statically, then we cannot optimize the VL. 
- if (!CommonVL || RISCV::isVLKnownLE(*CommonVL, VLOp)) { - CommonVL = &VLOp; + if (!CommonVL || RISCV::isVLKnownLE(*CommonVL, *VLOp)) { + CommonVL = *VLOp; LLVM_DEBUG(dbgs() << " User VL is: " << VLOp << "\n"); - } else if (!RISCV::isVLKnownLE(VLOp, *CommonVL)) { + } else if (!RISCV::isVLKnownLE(*VLOp, *CommonVL)) { LLVM_DEBUG(dbgs() << " Abort because cannot determine a common VL\n"); - CanReduceVL = false; - break; + return std::nullopt; + } + + if (!RISCVII::hasSEWOp(UserMI.getDesc().TSFlags)) { + LLVM_DEBUG(dbgs() << " Abort due to lack of SEW operand\n"); + return std::nullopt; + } + + std::optional<OperandInfo> ConsumerInfo = getOperandInfo(UserOp, MRI); + std::optional<OperandInfo> ProducerInfo = + getOperandInfo(MI.getOperand(0), MRI); + if (!ConsumerInfo || !ProducerInfo) { + LLVM_DEBUG(dbgs() << " Abort due to unknown operand information.\n"); + LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n"); + LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n"); + return std::nullopt; } - // The SEW and LMUL of destination and source registers need to match. - OperandInfo ConsumerInfo = getOperandInfo(UserOp, MRI); - OperandInfo ProducerInfo = getOperandInfo(MI.getOperand(0), MRI); - if (ConsumerInfo.isUnknown() || ProducerInfo.isUnknown() || - !OperandInfo::EMULAndEEWAreEqual(ConsumerInfo, ProducerInfo)) { - LLVM_DEBUG(dbgs() << " Abort due to incompatible or unknown " - "information for EMUL or EEW.\n"); + // If the operand is used as a scalar operand, then the EEW must be + // compatible. Otherwise, the EMUL *and* EEW must be compatible. + bool IsVectorOpUsedAsScalarOp = isVectorOpUsedAsScalarOp(UserOp); + if ((IsVectorOpUsedAsScalarOp && + !OperandInfo::EEWAreEqual(*ConsumerInfo, *ProducerInfo)) || + (!IsVectorOpUsedAsScalarOp && + !OperandInfo::EMULAndEEWAreEqual(*ConsumerInfo, *ProducerInfo))) { + LLVM_DEBUG( + dbgs() + << " Abort due to incompatible information for EMUL or EEW.\n"); LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n"); LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n"); - CanReduceVL = false; - break; + return std::nullopt; } } - return CanReduceVL; + + return CommonVL; } bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) { @@ -1009,12 +1219,11 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) { MachineInstr &MI = *Worklist.pop_back_val(); LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n"); - const MachineOperand *CommonVL = nullptr; - bool CanReduceVL = true; - if (isVectorRegClass(MI.getOperand(0).getReg(), MRI)) - CanReduceVL = checkUsers(CommonVL, MI); + if (!isVectorRegClass(MI.getOperand(0).getReg(), MRI)) + continue; - if (!CanReduceVL || !CommonVL) + auto CommonVL = checkUsers(MI); + if (!CommonVL) continue; assert((CommonVL->isImm() || CommonVL->getReg().isVirtual()) && diff --git a/llvm/lib/Target/SPIRV/CMakeLists.txt b/llvm/lib/Target/SPIRV/CMakeLists.txt index aa83d99..a79e19f 100644 --- a/llvm/lib/Target/SPIRV/CMakeLists.txt +++ b/llvm/lib/Target/SPIRV/CMakeLists.txt @@ -20,7 +20,6 @@ add_llvm_target(SPIRVCodeGen SPIRVCallLowering.cpp SPIRVInlineAsmLowering.cpp SPIRVCommandLine.cpp - SPIRVDuplicatesTracker.cpp SPIRVEmitIntrinsics.cpp SPIRVGlobalRegistry.cpp SPIRVInstrInfo.cpp diff --git a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp index 4012bd7..78add92 100644 --- a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp @@ -274,7 +274,7 @@ void SPIRVAsmPrinter::emitInstruction(const 
MachineInstr *MI) { } void SPIRVAsmPrinter::outputModuleSection(SPIRV::ModuleSectionType MSType) { - for (MachineInstr *MI : MAI->getMSInstrs(MSType)) + for (const MachineInstr *MI : MAI->getMSInstrs(MSType)) outputInstruction(MI); } @@ -326,7 +326,7 @@ void SPIRVAsmPrinter::outputOpMemoryModel() { void SPIRVAsmPrinter::outputEntryPoints() { // Find all OpVariable IDs with required StorageClass. DenseSet<Register> InterfaceIDs; - for (MachineInstr *MI : MAI->GlobalVarList) { + for (const MachineInstr *MI : MAI->GlobalVarList) { assert(MI->getOpcode() == SPIRV::OpVariable); auto SC = static_cast<SPIRV::StorageClass::StorageClass>( MI->getOperand(2).getImm()); @@ -336,14 +336,14 @@ void SPIRVAsmPrinter::outputEntryPoints() { // declaring all global variables referenced by the entry point call tree. if (ST->isAtLeastSPIRVVer(VersionTuple(1, 4)) || SC == SPIRV::StorageClass::Input || SC == SPIRV::StorageClass::Output) { - MachineFunction *MF = MI->getMF(); + const MachineFunction *MF = MI->getMF(); Register Reg = MAI->getRegisterAlias(MF, MI->getOperand(0).getReg()); InterfaceIDs.insert(Reg); } } // Output OpEntryPoints adding interface args to all of them. - for (MachineInstr *MI : MAI->getMSInstrs(SPIRV::MB_EntryPoints)) { + for (const MachineInstr *MI : MAI->getMSInstrs(SPIRV::MB_EntryPoints)) { SPIRVMCInstLower MCInstLowering; MCInst TmpInst; MCInstLowering.lower(MI, TmpInst, MAI); @@ -381,9 +381,8 @@ void SPIRVAsmPrinter::outputGlobalRequirements() { void SPIRVAsmPrinter::outputExtFuncDecls() { // Insert OpFunctionEnd after each declaration. - SmallVectorImpl<MachineInstr *>::iterator - I = MAI->getMSInstrs(SPIRV::MB_ExtFuncDecls).begin(), - E = MAI->getMSInstrs(SPIRV::MB_ExtFuncDecls).end(); + auto I = MAI->getMSInstrs(SPIRV::MB_ExtFuncDecls).begin(), + E = MAI->getMSInstrs(SPIRV::MB_ExtFuncDecls).end(); for (; I != E; ++I) { outputInstruction(*I); if ((I + 1) == E || (*(I + 1))->getOpcode() == SPIRV::OpFunction) diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp index fa37313f..44b6f5f8 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp @@ -418,6 +418,7 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, .addImm(FuncControl) .addUse(GR->getSPIRVTypeID(FuncTy)); GR->recordFunctionDefinition(&F, &MB.getInstr()->getOperand(0)); + GR->addGlobalObject(&F, &MIRBuilder.getMF(), FuncVReg); // Add OpFunctionParameter instructions int i = 0; @@ -431,6 +432,7 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, .addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i])); if (F.isDeclaration()) GR->add(&Arg, &MIRBuilder.getMF(), ArgReg); + GR->addGlobalObject(&Arg, &MIRBuilder.getMF(), ArgReg); i++; } // Name the function. diff --git a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp deleted file mode 100644 index 48df845..0000000 --- a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp +++ /dev/null @@ -1,136 +0,0 @@ -//===-- SPIRVDuplicatesTracker.cpp - SPIR-V Duplicates Tracker --*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// General infrastructure for keeping track of the values that according to -// the SPIR-V binary layout should be global to the whole module. -// -//===----------------------------------------------------------------------===// - -#include "SPIRVDuplicatesTracker.h" -#include "SPIRVInstrInfo.h" - -#define DEBUG_TYPE "build-dep-graph" - -using namespace llvm; - -template <typename T> -void SPIRVGeneralDuplicatesTracker::prebuildReg2Entry( - SPIRVDuplicatesTracker<T> &DT, SPIRVReg2EntryTy &Reg2Entry, - const SPIRVInstrInfo *TII) { - for (auto &TPair : DT.getAllUses()) { - for (auto &RegPair : TPair.second) { - const MachineFunction *MF = RegPair.first; - Register R = RegPair.second; - MachineInstr *MI = MF->getRegInfo().getUniqueVRegDef(R); - if (!MI || (TPair.second.getIsConst() && !TII->isConstantInstr(*MI))) - continue; - Reg2Entry[&MI->getOperand(0)] = &TPair.second; - } - } -} - -void SPIRVGeneralDuplicatesTracker::buildDepsGraph( - std::vector<SPIRV::DTSortableEntry *> &Graph, const SPIRVInstrInfo *TII, - MachineModuleInfo *MMI = nullptr) { - SPIRVReg2EntryTy Reg2Entry; - prebuildReg2Entry(TT, Reg2Entry, TII); - prebuildReg2Entry(CT, Reg2Entry, TII); - prebuildReg2Entry(GT, Reg2Entry, TII); - prebuildReg2Entry(FT, Reg2Entry, TII); - prebuildReg2Entry(AT, Reg2Entry, TII); - prebuildReg2Entry(MT, Reg2Entry, TII); - prebuildReg2Entry(ST, Reg2Entry, TII); - - for (auto &Op2E : Reg2Entry) { - SPIRV::DTSortableEntry *E = Op2E.second; - Graph.push_back(E); - for (auto &U : *E) { - const MachineRegisterInfo &MRI = U.first->getRegInfo(); - MachineInstr *MI = MRI.getUniqueVRegDef(U.second); - if (!MI) - continue; - assert(MI && MI->getParent() && "No MachineInstr created yet"); - for (auto i = MI->getNumDefs(); i < MI->getNumOperands(); i++) { - MachineOperand &Op = MI->getOperand(i); - if (!Op.isReg()) - continue; - MachineInstr *VRegDef = MRI.getVRegDef(Op.getReg()); - // References to a function via function pointers generate virtual - // registers without a definition. We are able to resolve this - // reference using Globar Register info into an OpFunction instruction - // but do not expect to find it in Reg2Entry. 
- if (MI->getOpcode() == SPIRV::OpConstantFunctionPointerINTEL && i == 2) - continue; - MachineOperand *RegOp = &VRegDef->getOperand(0); - if (Reg2Entry.count(RegOp) == 0 && - (MI->getOpcode() != SPIRV::OpVariable || i != 3)) { - // try to repair the unexpected code pattern - bool IsFixed = false; - if (VRegDef->getOpcode() == TargetOpcode::G_CONSTANT && - RegOp->isReg() && MRI.getType(RegOp->getReg()).isScalar()) { - const Constant *C = VRegDef->getOperand(1).getCImm(); - add(C, MI->getParent()->getParent(), RegOp->getReg()); - auto Iter = CT.Storage.find(C); - if (Iter != CT.Storage.end()) { - SPIRV::DTSortableEntry &MissedEntry = Iter->second; - Reg2Entry[RegOp] = &MissedEntry; - IsFixed = true; - } - } - if (!IsFixed) { - std::string DiagMsg; - raw_string_ostream OS(DiagMsg); - OS << "Unexpected pattern while building a dependency " - "graph.\nInstruction: "; - MI->print(OS); - OS << "Operand: "; - Op.print(OS); - OS << "\nOperand definition: "; - VRegDef->print(OS); - report_fatal_error(DiagMsg.c_str()); - } - } - if (Reg2Entry.count(RegOp)) - E->addDep(Reg2Entry[RegOp]); - } - - if (E->getIsFunc()) { - MachineInstr *Next = MI->getNextNode(); - if (Next && (Next->getOpcode() == SPIRV::OpFunction || - Next->getOpcode() == SPIRV::OpFunctionParameter)) { - E->addDep(Reg2Entry[&Next->getOperand(0)]); - } - } - } - } - -#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) - if (MMI) { - const Module *M = MMI->getModule(); - for (auto F = M->begin(), E = M->end(); F != E; ++F) { - const MachineFunction *MF = MMI->getMachineFunction(*F); - if (!MF) - continue; - for (const MachineBasicBlock &MBB : *MF) { - for (const MachineInstr &CMI : MBB) { - MachineInstr &MI = const_cast<MachineInstr &>(CMI); - MI.dump(); - if (MI.getNumExplicitDefs() > 0 && - Reg2Entry.count(&MI.getOperand(0))) { - dbgs() << "\t["; - for (SPIRV::DTSortableEntry *D : - Reg2Entry.lookup(&MI.getOperand(0))->getDeps()) - dbgs() << Register::virtReg2Index(D->lookup(MF)) << ", "; - dbgs() << "]\n"; - } - } - } - } - } -#endif -} diff --git a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h index 6847da0..e574892 100644 --- a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h +++ b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h @@ -211,23 +211,7 @@ class SPIRVGeneralDuplicatesTracker { SPIRVDuplicatesTracker<MachineInstr> MT; SPIRVDuplicatesTracker<SPIRV::SpecialTypeDescriptor> ST; - // NOTE: using MOs instead of regs to get rid of MF dependency to be able - // to use flat data structure. - // NOTE: replacing DenseMap with MapVector doesn't affect overall correctness - // but makes LITs more stable, should prefer DenseMap still due to - // significant perf difference. 
- using SPIRVReg2EntryTy = - MapVector<MachineOperand *, SPIRV::DTSortableEntry *>; - - template <typename T> - void prebuildReg2Entry(SPIRVDuplicatesTracker<T> &DT, - SPIRVReg2EntryTy &Reg2Entry, - const SPIRVInstrInfo *TII); - public: - void buildDepsGraph(std::vector<SPIRV::DTSortableEntry *> &Graph, - const SPIRVInstrInfo *TII, MachineModuleInfo *MMI); - void add(const Type *Ty, const MachineFunction *MF, Register R) { TT.add(unifyPtrType(Ty), MF, R); } diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp index 77b5421..d2b14d6 100644 --- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp @@ -1841,20 +1841,20 @@ void SPIRVEmitIntrinsics::processGlobalValue(GlobalVariable &GV, // Skip special artifical variable llvm.global.annotations. if (GV.getName() == "llvm.global.annotations") return; - if (GV.hasInitializer() && !isa<UndefValue>(GV.getInitializer())) { + Constant *Init = nullptr; + if (hasInitializer(&GV)) { // Deduce element type and store results in Global Registry. // Result is ignored, because TypedPointerType is not supported // by llvm IR general logic. deduceElementTypeHelper(&GV, false); - Constant *Init = GV.getInitializer(); + Init = GV.getInitializer(); Type *Ty = isAggrConstForceInt32(Init) ? B.getInt32Ty() : Init->getType(); Constant *Const = isAggrConstForceInt32(Init) ? B.getInt32(1) : Init; auto *InitInst = B.CreateIntrinsic(Intrinsic::spv_init_global, {GV.getType(), Ty}, {&GV, Const}); InitInst->setArgOperand(1, Init); } - if ((!GV.hasInitializer() || isa<UndefValue>(GV.getInitializer())) && - GV.getNumUses() == 0) + if (!Init && GV.getNumUses() == 0) B.CreateIntrinsic(Intrinsic::spv_unref_global, GV.getType(), &GV); } diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp index 0c424477..a06c62e 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -721,6 +721,7 @@ Register SPIRVGlobalRegistry::buildGlobalVariable( } Reg = MIB->getOperand(0).getReg(); DT.add(GVar, &MIRBuilder.getMF(), Reg); + addGlobalObject(GVar, &MIRBuilder.getMF(), Reg); // Set to Reg the same type as ResVReg has. auto MRI = MIRBuilder.getMRI(); diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h index ec2386fa..528baf5 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h @@ -89,6 +89,9 @@ class SPIRVGlobalRegistry { // Intrinsic::spv_assign_ptr_type instructions. DenseMap<Value *, CallInst *> AssignPtrTypeInstr; + // Maps OpVariable and OpFunction-related v-regs to its LLVM IR definition. + DenseMap<std::pair<const MachineFunction *, Register>, const Value *> Reg2GO; + // Add a new OpTypeXXX instruction without checking for duplicates. 
SPIRVType *createSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder, SPIRV::AccessQualifier::AccessQualifier AQ = @@ -151,15 +154,17 @@ public: return DT.find(F, MF); } - void buildDepsGraph(std::vector<SPIRV::DTSortableEntry *> &Graph, - const SPIRVInstrInfo *TII, - MachineModuleInfo *MMI = nullptr) { - DT.buildDepsGraph(Graph, TII, MMI); - } - void setBound(unsigned V) { Bound = V; } unsigned getBound() { return Bound; } + void addGlobalObject(const Value *V, const MachineFunction *MF, Register R) { + Reg2GO[std::make_pair(MF, R)] = V; + } + const Value *getGlobalObject(const MachineFunction *MF, Register R) { + auto It = Reg2GO.find(std::make_pair(MF, R)); + return It == Reg2GO.end() ? nullptr : It->second; + } + // Add a record to the map of function return pointer types. void addReturnType(const Function *ArgF, TypedPointerType *DerivedTy) { FunResPointerTypes[ArgF] = DerivedTy; diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp index bd9e77e..9a140e7 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp @@ -47,6 +47,19 @@ bool SPIRVInstrInfo::isConstantInstr(const MachineInstr &MI) const { } } +bool SPIRVInstrInfo::isSpecConstantInstr(const MachineInstr &MI) const { + switch (MI.getOpcode()) { + case SPIRV::OpSpecConstantTrue: + case SPIRV::OpSpecConstantFalse: + case SPIRV::OpSpecConstant: + case SPIRV::OpSpecConstantComposite: + case SPIRV::OpSpecConstantOp: + return true; + default: + return false; + } +} + bool SPIRVInstrInfo::isInlineAsmDefInstr(const MachineInstr &MI) const { switch (MI.getOpcode()) { case SPIRV::OpAsmTargetINTEL: diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.h b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.h index 67d2d97..4e5059b 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.h +++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.h @@ -30,6 +30,7 @@ public: const SPIRVRegisterInfo &getRegisterInfo() const { return RI; } bool isHeaderInstr(const MachineInstr &MI) const; bool isConstantInstr(const MachineInstr &MI) const; + bool isSpecConstantInstr(const MachineInstr &MI) const; bool isInlineAsmDefInstr(const MachineInstr &MI) const; bool isTypeDeclInstr(const MachineInstr &MI) const; bool isDecorationInstr(const MachineInstr &MI) const; diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 289d5f3..28c9b81 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -1105,6 +1105,7 @@ bool SPIRVInstructionSelector::selectMemOperation(Register ResVReg, Constant::getNullValue(LLVMArrTy)); Register VarReg = MRI->createGenericVirtualRegister(LLT::scalar(64)); GR.add(GV, GR.CurMF, VarReg); + GR.addGlobalObject(GV, GR.CurMF, VarReg); Result &= BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpVariable)) @@ -2881,6 +2882,14 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, // translated to a `LocalInvocationId` builtin variable return loadVec3BuiltinInputID(SPIRV::BuiltIn::LocalInvocationId, ResVReg, ResType, I); + case Intrinsic::spv_group_id: + // The HLSL SV_GroupId semantic is lowered to + // llvm.spv.group.id intrinsic in LLVM IR for SPIR-V backend. 
+ // + // In SPIR-V backend, llvm.spv.group.id is now translated to a `WorkgroupId` + // builtin variable + return loadVec3BuiltinInputID(SPIRV::BuiltIn::WorkgroupId, ResVReg, ResType, + I); case Intrinsic::spv_fdot: return selectFloatDot(ResVReg, ResType, I); case Intrinsic::spv_udot: @@ -2906,6 +2915,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, return selectAny(ResVReg, ResType, I); case Intrinsic::spv_cross: return selectExtInst(ResVReg, ResType, I, CL::cross, GL::Cross); + case Intrinsic::spv_distance: + return selectExtInst(ResVReg, ResType, I, CL::distance, GL::Distance); case Intrinsic::spv_lerp: return selectExtInst(ResVReg, ResType, I, CL::mix, GL::FMix); case Intrinsic::spv_length: @@ -3450,7 +3461,7 @@ bool SPIRVInstructionSelector::selectGlobalValue( ID = UnnamedGlobalIDs.size(); GlobalIdent = "__unnamed_" + Twine(ID).str(); } else { - GlobalIdent = GV->getGlobalIdentifier(); + GlobalIdent = GV->getName(); } // Behaviour of functions as operands depends on availability of the @@ -3482,18 +3493,25 @@ bool SPIRVInstructionSelector::selectGlobalValue( // References to a function via function pointers generate virtual // registers without a definition. We will resolve it later, during // module analysis stage. + Register ResTypeReg = GR.getSPIRVTypeID(ResType); MachineRegisterInfo *MRI = MIRBuilder.getMRI(); - Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(64)); - MRI->setRegClass(FuncVReg, &SPIRV::iIDRegClass); - MachineInstrBuilder MB = + Register FuncVReg = + MRI->createGenericVirtualRegister(GR.getRegType(ResType)); + MRI->setRegClass(FuncVReg, &SPIRV::pIDRegClass); + MachineInstrBuilder MIB1 = + BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpUndef)) + .addDef(FuncVReg) + .addUse(ResTypeReg); + MachineInstrBuilder MIB2 = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantFunctionPointerINTEL)) .addDef(NewReg) - .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(ResTypeReg) .addUse(FuncVReg); // mapping the function pointer to the used Function - GR.recordFunctionPointer(&MB.getInstr()->getOperand(2), GVFun); - return MB.constrainAllUses(TII, TRI, RBI); + GR.recordFunctionPointer(&MIB2.getInstr()->getOperand(2), GVFun); + return MIB1.constrainAllUses(TII, TRI, RBI) && + MIB2.constrainAllUses(TII, TRI, RBI); } return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull)) .addDef(NewReg) @@ -3506,18 +3524,16 @@ bool SPIRVInstructionSelector::selectGlobalValue( auto GlobalVar = cast<GlobalVariable>(GV); assert(GlobalVar->getName() != "llvm.global.annotations"); - bool HasInit = GlobalVar->hasInitializer() && - !isa<UndefValue>(GlobalVar->getInitializer()); - // Skip empty declaration for GVs with initilaizers till we get the decl with + // Skip empty declaration for GVs with initializers till we get the decl with // passed initializer. - if (HasInit && !Init) + if (hasInitializer(GlobalVar) && !Init) return true; - bool HasLnkTy = GV->getLinkage() != GlobalValue::InternalLinkage; + bool HasLnkTy = !GV->hasInternalLinkage() && !GV->hasPrivateLinkage(); SPIRV::LinkageType::LinkageType LnkType = - (GV->isDeclaration() || GV->hasAvailableExternallyLinkage()) + GV->isDeclarationForLinker() ? SPIRV::LinkageType::Import - : (GV->getLinkage() == GlobalValue::LinkOnceODRLinkage && + : (GV->hasLinkOnceODRLinkage() && STI.canUseExtension(SPIRV::Extension::SPV_KHR_linkonce_odr) ? 
SPIRV::LinkageType::LinkOnceODR : SPIRV::LinkageType::Export); diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp index 6371c67..63adf54 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -216,102 +216,262 @@ void SPIRVModuleAnalysis::setBaseInfo(const Module &M) { } } -// Collect MI which defines the register in the given machine function. -static void collectDefInstr(Register Reg, const MachineFunction *MF, - SPIRV::ModuleAnalysisInfo *MAI, - SPIRV::ModuleSectionType MSType, - bool DoInsert = true) { - assert(MAI->hasRegisterAlias(MF, Reg) && "Cannot find register alias"); - MachineInstr *MI = MF->getRegInfo().getUniqueVRegDef(Reg); - assert(MI && "There should be an instruction that defines the register"); - MAI->setSkipEmission(MI); - if (DoInsert) - MAI->MS[MSType].push_back(MI); +// Returns a representation of an instruction as a vector of MachineOperand +// hash values, see llvm::hash_value(const MachineOperand &MO) for details. +// This creates a signature of the instruction with the same content +// that MachineOperand::isIdenticalTo uses for comparison. +static InstrSignature instrToSignature(const MachineInstr &MI, + SPIRV::ModuleAnalysisInfo &MAI, + bool UseDefReg) { + InstrSignature Signature{MI.getOpcode()}; + for (unsigned i = 0; i < MI.getNumOperands(); ++i) { + const MachineOperand &MO = MI.getOperand(i); + size_t h; + if (MO.isReg()) { + if (!UseDefReg && MO.isDef()) + continue; + Register RegAlias = MAI.getRegisterAlias(MI.getMF(), MO.getReg()); + if (!RegAlias.isValid()) { + LLVM_DEBUG({ + dbgs() << "Unexpectedly, no global id found for the operand "; + MO.print(dbgs()); + dbgs() << "\nInstruction: "; + MI.print(dbgs()); + dbgs() << "\n"; + }); + report_fatal_error("All v-regs must have been mapped to global id's"); + } + // mimic llvm::hash_value(const MachineOperand &MO) + h = hash_combine(MO.getType(), (unsigned)RegAlias, MO.getSubReg(), + MO.isDef()); + } else { + h = hash_value(MO); + } + Signature.push_back(h); + } + return Signature; } -void SPIRVModuleAnalysis::collectGlobalEntities( - const std::vector<SPIRV::DTSortableEntry *> &DepsGraph, - SPIRV::ModuleSectionType MSType, - std::function<bool(const SPIRV::DTSortableEntry *)> Pred, - bool UsePreOrder = false) { - DenseSet<const SPIRV::DTSortableEntry *> Visited; - for (const auto *E : DepsGraph) { - std::function<void(const SPIRV::DTSortableEntry *)> RecHoistUtil; - // NOTE: here we prefer recursive approach over iterative because - // we don't expect depchains long enough to cause SO. - RecHoistUtil = [MSType, UsePreOrder, &Visited, &Pred, - &RecHoistUtil](const SPIRV::DTSortableEntry *E) { - if (Visited.count(E) || !Pred(E)) - return; - Visited.insert(E); - - // Traversing deps graph in post-order allows us to get rid of - // register aliases preprocessing. - // But pre-order is required for correct processing of function - // declaration and arguments processing. 
- if (!UsePreOrder) - for (auto *S : E->getDeps()) - RecHoistUtil(S); - - Register GlobalReg = Register::index2VirtReg(MAI.getNextID()); - bool IsFirst = true; - for (auto &U : *E) { - const MachineFunction *MF = U.first; - Register Reg = U.second; - MAI.setRegisterAlias(MF, Reg, GlobalReg); - if (!MF->getRegInfo().getUniqueVRegDef(Reg)) - continue; - collectDefInstr(Reg, MF, &MAI, MSType, IsFirst); - IsFirst = false; - if (E->getIsGV()) - MAI.GlobalVarList.push_back(MF->getRegInfo().getUniqueVRegDef(Reg)); - } +bool SPIRVModuleAnalysis::isDeclSection(const MachineRegisterInfo &MRI, + const MachineInstr &MI) { + unsigned Opcode = MI.getOpcode(); + switch (Opcode) { + case SPIRV::OpTypeForwardPointer: + // omit now, collect later + return false; + case SPIRV::OpVariable: + return static_cast<SPIRV::StorageClass::StorageClass>( + MI.getOperand(2).getImm()) != SPIRV::StorageClass::Function; + case SPIRV::OpFunction: + case SPIRV::OpFunctionParameter: + return true; + } + if (GR->hasConstFunPtr() && Opcode == SPIRV::OpUndef) { + Register DefReg = MI.getOperand(0).getReg(); + for (MachineInstr &UseMI : MRI.use_instructions(DefReg)) { + if (UseMI.getOpcode() != SPIRV::OpConstantFunctionPointerINTEL) + continue; + // it's a dummy definition, FP constant refers to a function, + // and this is resolved in another way; let's skip this definition + assert(UseMI.getOperand(2).isReg() && + UseMI.getOperand(2).getReg() == DefReg); + MAI.setSkipEmission(&MI); + return false; + } + } + return TII->isTypeDeclInstr(MI) || TII->isConstantInstr(MI) || + TII->isInlineAsmDefInstr(MI); +} - if (UsePreOrder) - for (auto *S : E->getDeps()) - RecHoistUtil(S); - }; - RecHoistUtil(E); +// This is a special case of a function pointer refering to a possibly +// forward function declaration. The operand is a dummy OpUndef that +// requires a special treatment. +void SPIRVModuleAnalysis::visitFunPtrUse( + Register OpReg, InstrGRegsMap &SignatureToGReg, + std::map<const Value *, unsigned> &GlobalToGReg, const MachineFunction *MF, + const MachineInstr &MI) { + const MachineOperand *OpFunDef = + GR->getFunctionDefinitionByUse(&MI.getOperand(2)); + assert(OpFunDef && OpFunDef->isReg()); + // find the actual function definition and number it globally in advance + const MachineInstr *OpDefMI = OpFunDef->getParent(); + assert(OpDefMI && OpDefMI->getOpcode() == SPIRV::OpFunction); + const MachineFunction *FunDefMF = OpDefMI->getParent()->getParent(); + const MachineRegisterInfo &FunDefMRI = FunDefMF->getRegInfo(); + do { + visitDecl(FunDefMRI, SignatureToGReg, GlobalToGReg, FunDefMF, *OpDefMI); + OpDefMI = OpDefMI->getNextNode(); + } while (OpDefMI && (OpDefMI->getOpcode() == SPIRV::OpFunction || + OpDefMI->getOpcode() == SPIRV::OpFunctionParameter)); + // associate the function pointer with the newly assigned global number + Register GlobalFunDefReg = MAI.getRegisterAlias(FunDefMF, OpFunDef->getReg()); + assert(GlobalFunDefReg.isValid() && + "Function definition must refer to a global register"); + MAI.setRegisterAlias(MF, OpReg, GlobalFunDefReg); +} + +// Depth first recursive traversal of dependencies. Repeated visits are guarded +// by MAI.hasRegisterAlias(). 
+void SPIRVModuleAnalysis::visitDecl( + const MachineRegisterInfo &MRI, InstrGRegsMap &SignatureToGReg, + std::map<const Value *, unsigned> &GlobalToGReg, const MachineFunction *MF, + const MachineInstr &MI) { + unsigned Opcode = MI.getOpcode(); + DenseSet<Register> Deps; + + // Process each operand of the instruction to resolve dependencies + for (const MachineOperand &MO : MI.operands()) { + if (!MO.isReg() || MO.isDef()) + continue; + Register OpReg = MO.getReg(); + // Handle function pointers special case + if (Opcode == SPIRV::OpConstantFunctionPointerINTEL && + MRI.getRegClass(OpReg) == &SPIRV::pIDRegClass) { + visitFunPtrUse(OpReg, SignatureToGReg, GlobalToGReg, MF, MI); + continue; + } + // Skip already processed instructions + if (MAI.hasRegisterAlias(MF, MO.getReg())) + continue; + // Recursively visit dependencies + if (const MachineInstr *OpDefMI = MRI.getUniqueVRegDef(OpReg)) { + if (isDeclSection(MRI, *OpDefMI)) + visitDecl(MRI, SignatureToGReg, GlobalToGReg, MF, *OpDefMI); + continue; + } + // Handle the unexpected case of no unique definition for the SPIR-V + // instruction + LLVM_DEBUG({ + dbgs() << "Unexpectedly, no unique definition for the operand "; + MO.print(dbgs()); + dbgs() << "\nInstruction: "; + MI.print(dbgs()); + dbgs() << "\n"; + }); + report_fatal_error( + "No unique definition is found for the virtual register"); } + + Register GReg; + bool IsFunDef = false; + if (TII->isSpecConstantInstr(MI)) { + GReg = Register::index2VirtReg(MAI.getNextID()); + MAI.MS[SPIRV::MB_TypeConstVars].push_back(&MI); + } else if (Opcode == SPIRV::OpFunction || + Opcode == SPIRV::OpFunctionParameter) { + GReg = handleFunctionOrParameter(MF, MI, GlobalToGReg, IsFunDef); + } else if (TII->isTypeDeclInstr(MI) || TII->isConstantInstr(MI) || + TII->isInlineAsmDefInstr(MI)) { + GReg = handleTypeDeclOrConstant(MI, SignatureToGReg); + } else if (Opcode == SPIRV::OpVariable) { + GReg = handleVariable(MF, MI, GlobalToGReg); + } else { + LLVM_DEBUG({ + dbgs() << "\nInstruction: "; + MI.print(dbgs()); + dbgs() << "\n"; + }); + llvm_unreachable("Unexpected instruction is visited"); + } + MAI.setRegisterAlias(MF, MI.getOperand(0).getReg(), GReg); + if (!IsFunDef) + MAI.setSkipEmission(&MI); } -// The function initializes global register alias table for types, consts, -// global vars and func decls and collects these instruction for output -// at module level. Also it collects explicit OpExtension/OpCapability -// instructions. -void SPIRVModuleAnalysis::processDefInstrs(const Module &M) { - std::vector<SPIRV::DTSortableEntry *> DepsGraph; +Register SPIRVModuleAnalysis::handleFunctionOrParameter( + const MachineFunction *MF, const MachineInstr &MI, + std::map<const Value *, unsigned> &GlobalToGReg, bool &IsFunDef) { + const Value *GObj = GR->getGlobalObject(MF, MI.getOperand(0).getReg()); + assert(GObj && "Unregistered global definition"); + const Function *F = dyn_cast<Function>(GObj); + if (!F) + F = dyn_cast<Argument>(GObj)->getParent(); + assert(F && "Expected a reference to a function or an argument"); + IsFunDef = !F->isDeclaration(); + auto It = GlobalToGReg.find(GObj); + if (It != GlobalToGReg.end()) + return It->second; + Register GReg = Register::index2VirtReg(MAI.getNextID()); + GlobalToGReg[GObj] = GReg; + if (!IsFunDef) + MAI.MS[SPIRV::MB_ExtFuncDecls].push_back(&MI); + return GReg; +} - GR->buildDepsGraph(DepsGraph, TII, SPVDumpDeps ? 
MMI : nullptr); +Register +SPIRVModuleAnalysis::handleTypeDeclOrConstant(const MachineInstr &MI, + InstrGRegsMap &SignatureToGReg) { + InstrSignature MISign = instrToSignature(MI, MAI, false); + auto It = SignatureToGReg.find(MISign); + if (It != SignatureToGReg.end()) + return It->second; + Register GReg = Register::index2VirtReg(MAI.getNextID()); + SignatureToGReg[MISign] = GReg; + MAI.MS[SPIRV::MB_TypeConstVars].push_back(&MI); + return GReg; +} - collectGlobalEntities( - DepsGraph, SPIRV::MB_TypeConstVars, - [](const SPIRV::DTSortableEntry *E) { return !E->getIsFunc(); }); +Register SPIRVModuleAnalysis::handleVariable( + const MachineFunction *MF, const MachineInstr &MI, + std::map<const Value *, unsigned> &GlobalToGReg) { + MAI.GlobalVarList.push_back(&MI); + const Value *GObj = GR->getGlobalObject(MF, MI.getOperand(0).getReg()); + assert(GObj && "Unregistered global definition"); + auto It = GlobalToGReg.find(GObj); + if (It != GlobalToGReg.end()) + return It->second; + Register GReg = Register::index2VirtReg(MAI.getNextID()); + GlobalToGReg[GObj] = GReg; + MAI.MS[SPIRV::MB_TypeConstVars].push_back(&MI); + return GReg; +} +void SPIRVModuleAnalysis::collectDeclarations(const Module &M) { + InstrGRegsMap SignatureToGReg; + std::map<const Value *, unsigned> GlobalToGReg; for (auto F = M.begin(), E = M.end(); F != E; ++F) { MachineFunction *MF = MMI->getMachineFunction(*F); if (!MF) continue; - // Iterate through and collect OpExtension/OpCapability instructions. + const MachineRegisterInfo &MRI = MF->getRegInfo(); + unsigned PastHeader = 0; for (MachineBasicBlock &MBB : *MF) { for (MachineInstr &MI : MBB) { - if (MI.getOpcode() == SPIRV::OpExtension) { - // Here, OpExtension just has a single enum operand, not a string. - auto Ext = SPIRV::Extension::Extension(MI.getOperand(0).getImm()); - MAI.Reqs.addExtension(Ext); + if (MI.getNumOperands() == 0) + continue; + unsigned Opcode = MI.getOpcode(); + if (Opcode == SPIRV::OpFunction) { + if (PastHeader == 0) { + PastHeader = 1; + continue; + } + } else if (Opcode == SPIRV::OpFunctionParameter) { + if (PastHeader < 2) + continue; + } else if (PastHeader > 0) { + PastHeader = 2; + } + + const MachineOperand &DefMO = MI.getOperand(0); + switch (Opcode) { + case SPIRV::OpExtension: + MAI.Reqs.addExtension(SPIRV::Extension::Extension(DefMO.getImm())); MAI.setSkipEmission(&MI); - } else if (MI.getOpcode() == SPIRV::OpCapability) { - auto Cap = SPIRV::Capability::Capability(MI.getOperand(0).getImm()); - MAI.Reqs.addCapability(Cap); + break; + case SPIRV::OpCapability: + MAI.Reqs.addCapability(SPIRV::Capability::Capability(DefMO.getImm())); MAI.setSkipEmission(&MI); + if (PastHeader > 0) + PastHeader = 2; + break; + default: + if (DefMO.isReg() && isDeclSection(MRI, MI) && + !MAI.hasRegisterAlias(MF, DefMO.getReg())) + visitDecl(MRI, SignatureToGReg, GlobalToGReg, MF, MI); } } } } - - collectGlobalEntities( - DepsGraph, SPIRV::MB_ExtFuncDecls, - [](const SPIRV::DTSortableEntry *E) { return E->getIsFunc(); }, true); } // Look for IDs declared with Import linkage, and map the corresponding function @@ -342,58 +502,6 @@ void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI, } } -// References to a function via function pointers generate virtual -// registers without a definition. We are able to resolve this -// reference using Globar Register info into an OpFunction instruction -// and replace dummy operands by the corresponding global register references. 
-void SPIRVModuleAnalysis::collectFuncPtrs() { - for (auto &MI : MAI.MS[SPIRV::MB_TypeConstVars]) - if (MI->getOpcode() == SPIRV::OpConstantFunctionPointerINTEL) - collectFuncPtrs(MI); -} - -void SPIRVModuleAnalysis::collectFuncPtrs(MachineInstr *MI) { - const MachineOperand *FunUse = &MI->getOperand(2); - if (const MachineOperand *FunDef = GR->getFunctionDefinitionByUse(FunUse)) { - const MachineInstr *FunDefMI = FunDef->getParent(); - assert(FunDefMI->getOpcode() == SPIRV::OpFunction && - "Constant function pointer must refer to function definition"); - Register FunDefReg = FunDef->getReg(); - Register GlobalFunDefReg = - MAI.getRegisterAlias(FunDefMI->getMF(), FunDefReg); - assert(GlobalFunDefReg.isValid() && - "Function definition must refer to a global register"); - Register FunPtrReg = FunUse->getReg(); - MAI.setRegisterAlias(MI->getMF(), FunPtrReg, GlobalFunDefReg); - } -} - -using InstrSignature = SmallVector<size_t>; -using InstrTraces = std::set<InstrSignature>; - -// Returns a representation of an instruction as a vector of MachineOperand -// hash values, see llvm::hash_value(const MachineOperand &MO) for details. -// This creates a signature of the instruction with the same content -// that MachineOperand::isIdenticalTo uses for comparison. -static InstrSignature instrToSignature(MachineInstr &MI, - SPIRV::ModuleAnalysisInfo &MAI) { - InstrSignature Signature; - for (unsigned i = 0; i < MI.getNumOperands(); ++i) { - const MachineOperand &MO = MI.getOperand(i); - size_t h; - if (MO.isReg()) { - Register RegAlias = MAI.getRegisterAlias(MI.getMF(), MO.getReg()); - // mimic llvm::hash_value(const MachineOperand &MO) - h = hash_combine(MO.getType(), (unsigned)RegAlias, MO.getSubReg(), - MO.isDef()); - } else { - h = hash_value(MO); - } - Signature.push_back(h); - } - return Signature; -} - // Collect the given instruction in the specified MS. We assume global register // numbering has already occurred by this point. We can directly compare reg // arguments when detecting duplicates. @@ -401,7 +509,7 @@ static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI, SPIRV::ModuleSectionType MSType, InstrTraces &IS, bool Append = true) { MAI.setSkipEmission(&MI); - InstrSignature MISign = instrToSignature(MI, MAI); + InstrSignature MISign = instrToSignature(MI, MAI, true); auto FoundMI = IS.insert(MISign); if (!FoundMI.second) return; // insert failed, so we found a duplicate; don't add it to MAI.MS @@ -465,7 +573,7 @@ void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) { // Number registers in all functions globally from 0 onwards and store // the result in global register alias table. Some registers are already -// numbered in collectGlobalEntities. +// numbered. void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) { for (auto F = M.begin(), E = M.end(); F != E; ++F) { if ((*F).isDeclaration()) @@ -1835,15 +1943,11 @@ bool SPIRVModuleAnalysis::runOnModule(Module &M) { // Process type/const/global var/func decl instructions, number their // destination registers from 0 to N, collect Extensions and Capabilities. - processDefInstrs(M); + collectDeclarations(M); // Number rest of registers from N+1 onwards. numberRegistersGlobally(M); - // Update references to OpFunction instructions to use Global Registers - if (GR->hasConstFunPtr()) - collectFuncPtrs(); - // Collect OpName, OpEntryPoint, OpDecorate etc, process other instructions. 
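Aside on the deduplication scheme used above: the patch hashes every operand of a MachineInstr into a vector (instrToSignature) and treats two instructions with equal vectors as the same module-level declaration, mapping each distinct signature to one global register (SignatureToGReg / InstrGRegsMap). The sketch below models that idea with plain standard-library types only; Operand, signatureOf and assignGlobalId are illustrative names and are not part of the patch or of LLVM.

    #include <cstddef>
    #include <cstdio>
    #include <functional>
    #include <map>
    #include <vector>

    // Toy stand-in for a MachineOperand: either a register id or an immediate.
    struct Operand {
      bool IsReg;
      unsigned RegOrImm;
    };

    // Hash every operand into one vector; two "instructions" with equal
    // vectors are considered the same declaration.
    using Signature = std::vector<std::size_t>;

    static Signature signatureOf(const std::vector<Operand> &Ops) {
      Signature Sig;
      for (const Operand &Op : Ops)
        Sig.push_back(std::hash<unsigned>{}(Op.RegOrImm) ^
                      (Op.IsReg ? 0x9e3779b9u : 0u));
      return Sig;
    }

    // Assign one global id per distinct signature, reusing the id for
    // duplicates (the role played by SignatureToGReg in the patch).
    static unsigned assignGlobalId(std::map<Signature, unsigned> &Seen,
                                   const std::vector<Operand> &Ops,
                                   unsigned &NextId) {
      Signature Sig = signatureOf(Ops);
      auto It = Seen.find(Sig);
      if (It != Seen.end())
        return It->second;          // duplicate: reuse the existing id
      return Seen[Sig] = NextId++;  // first occurrence: number and remember it
    }

    int main() {
      std::map<Signature, unsigned> Seen;
      unsigned NextId = 0;
      std::vector<Operand> A = {{true, 5}, {false, 32}};
      std::vector<Operand> B = {{true, 5}, {false, 32}}; // same operands as A
      unsigned IdA = assignGlobalId(Seen, A, NextId);
      unsigned IdB = assignGlobalId(Seen, B, NextId);    // reuses IdA
      unsigned IdC = assignGlobalId(Seen, {{false, 7}}, NextId);
      std::printf("%u %u %u\n", IdA, IdB, IdC);          // prints: 0 0 1
    }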
processOtherInstrs(M); diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h index ee2aaf1..79b5444 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h @@ -124,7 +124,7 @@ public: const Capability::Capability IfPresent); }; -using InstrList = SmallVector<MachineInstr *>; +using InstrList = SmallVector<const MachineInstr *>; // Maps a local register to the corresponding global alias. using LocalToGlobalRegTable = std::map<Register, Register>; using RegisterAliasMapTy = @@ -142,12 +142,12 @@ struct ModuleAnalysisInfo { // Maps ExtInstSet to corresponding ID register. DenseMap<unsigned, Register> ExtInstSetMap; // Contains the list of all global OpVariables in the module. - SmallVector<MachineInstr *, 4> GlobalVarList; + SmallVector<const MachineInstr *, 4> GlobalVarList; // Maps functions to corresponding function ID registers. DenseMap<const Function *, Register> FuncMap; // The set contains machine instructions which are necessary // for correct MIR but will not be emitted in function bodies. - DenseSet<MachineInstr *> InstrsToDelete; + DenseSet<const MachineInstr *> InstrsToDelete; // The table contains global aliases of local registers for each machine // function. The aliases are used to substitute local registers during // code emission. @@ -167,7 +167,7 @@ struct ModuleAnalysisInfo { } Register getExtInstSetReg(unsigned SetNum) { return ExtInstSetMap[SetNum]; } InstrList &getMSInstrs(unsigned MSType) { return MS[MSType]; } - void setSkipEmission(MachineInstr *MI) { InstrsToDelete.insert(MI); } + void setSkipEmission(const MachineInstr *MI) { InstrsToDelete.insert(MI); } bool getSkipEmission(const MachineInstr *MI) { return InstrsToDelete.contains(MI); } @@ -204,6 +204,10 @@ struct ModuleAnalysisInfo { }; } // namespace SPIRV +using InstrSignature = SmallVector<size_t>; +using InstrTraces = std::set<InstrSignature>; +using InstrGRegsMap = std::map<SmallVector<size_t>, unsigned>; + struct SPIRVModuleAnalysis : public ModulePass { static char ID; @@ -216,17 +220,27 @@ public: private: void setBaseInfo(const Module &M); - void collectGlobalEntities( - const std::vector<SPIRV::DTSortableEntry *> &DepsGraph, - SPIRV::ModuleSectionType MSType, - std::function<bool(const SPIRV::DTSortableEntry *)> Pred, - bool UsePreOrder); - void processDefInstrs(const Module &M); void collectFuncNames(MachineInstr &MI, const Function *F); void processOtherInstrs(const Module &M); void numberRegistersGlobally(const Module &M); - void collectFuncPtrs(); - void collectFuncPtrs(MachineInstr *MI); + + // analyze dependencies to collect module scope definitions + void collectDeclarations(const Module &M); + void visitDecl(const MachineRegisterInfo &MRI, InstrGRegsMap &SignatureToGReg, + std::map<const Value *, unsigned> &GlobalToGReg, + const MachineFunction *MF, const MachineInstr &MI); + Register handleVariable(const MachineFunction *MF, const MachineInstr &MI, + std::map<const Value *, unsigned> &GlobalToGReg); + Register handleTypeDeclOrConstant(const MachineInstr &MI, + InstrGRegsMap &SignatureToGReg); + Register + handleFunctionOrParameter(const MachineFunction *MF, const MachineInstr &MI, + std::map<const Value *, unsigned> &GlobalToGReg, + bool &IsFunDef); + void visitFunPtrUse(Register OpReg, InstrGRegsMap &SignatureToGReg, + std::map<const Value *, unsigned> &GlobalToGReg, + const MachineFunction *MF, const MachineInstr &MI); + bool isDeclSection(const MachineRegisterInfo &MRI, const MachineInstr 
&MI); const SPIRVSubtarget *ST; SPIRVGlobalRegistry *GR; diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp index 8357c30..5b4c849 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp @@ -58,9 +58,10 @@ addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR, ->getValue()); if (auto *GV = dyn_cast<GlobalValue>(Const)) { Register Reg = GR->find(GV, &MF); - if (!Reg.isValid()) + if (!Reg.isValid()) { GR->add(GV, &MF, SrcReg); - else + GR->addGlobalObject(GV, &MF, SrcReg); + } else RegsAlreadyAddedToDT[&MI] = Reg; } else { Register Reg = GR->find(Const, &MF); diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h index da2e24c..60649ea 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.h +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h @@ -17,6 +17,7 @@ #include "llvm/Analysis/LoopInfo.h" #include "llvm/CodeGen/MachineBasicBlock.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/TypedPointerType.h" #include <queue> @@ -236,6 +237,10 @@ Type *parseBasicTypeName(StringRef &TypeName, LLVMContext &Ctx); // Returns true if the function was changed. bool sortBlocks(Function &F); +inline bool hasInitializer(const GlobalVariable *GV) { + return GV->hasInitializer() && !isa<UndefValue>(GV->getInitializer()); +} + // True if this is an instance of TypedPointerType. inline bool isTypedPointerTy(const Type *T) { return T && T->getTypeID() == Type::TypedPointerTyID; diff --git a/llvm/lib/Target/TargetMachine.cpp b/llvm/lib/Target/TargetMachine.cpp index c0985f3..d5365f3 100644 --- a/llvm/lib/Target/TargetMachine.cpp +++ b/llvm/lib/Target/TargetMachine.cpp @@ -204,7 +204,7 @@ bool TargetMachine::shouldAssumeDSOLocal(const GlobalValue *GV) const { // don't assume the variables to be DSO local unless we actually know // that for sure. This only has to be done for variables; for functions // the linker can insert thunks for calling functions from another DLL. 
- if (TT.isWindowsGNUEnvironment() && GV->isDeclarationForLinker() && + if (TT.isOSCygMing() && GV->isDeclarationForLinker() && isa<GlobalVariable>(GV)) return false; diff --git a/llvm/lib/Target/X86/MCTargetDesc/X86ATTInstPrinter.cpp b/llvm/lib/Target/X86/MCTargetDesc/X86ATTInstPrinter.cpp index b67c573..abe0cc6 100644 --- a/llvm/lib/Target/X86/MCTargetDesc/X86ATTInstPrinter.cpp +++ b/llvm/lib/Target/X86/MCTargetDesc/X86ATTInstPrinter.cpp @@ -140,8 +140,8 @@ bool X86ATTInstPrinter::printVecCompareInstr(const MCInst *MI, case X86::VCMPPSZ128rmik: case X86::VCMPPSZ128rrik: case X86::VCMPPSZ256rmik: case X86::VCMPPSZ256rrik: case X86::VCMPPSZrmik: case X86::VCMPPSZrrik: - case X86::VCMPSDZrmi_Intk: case X86::VCMPSDZrri_Intk: - case X86::VCMPSSZrmi_Intk: case X86::VCMPSSZrri_Intk: + case X86::VCMPSDZrmik_Int: case X86::VCMPSDZrrik_Int: + case X86::VCMPSSZrmik_Int: case X86::VCMPSSZrrik_Int: case X86::VCMPPDZ128rmbi: case X86::VCMPPDZ128rmbik: case X86::VCMPPDZ256rmbi: case X86::VCMPPDZ256rmbik: case X86::VCMPPDZrmbi: case X86::VCMPPDZrmbik: @@ -150,8 +150,8 @@ bool X86ATTInstPrinter::printVecCompareInstr(const MCInst *MI, case X86::VCMPPSZrmbi: case X86::VCMPPSZrmbik: case X86::VCMPPDZrrib: case X86::VCMPPDZrribk: case X86::VCMPPSZrrib: case X86::VCMPPSZrribk: - case X86::VCMPSDZrrib_Int: case X86::VCMPSDZrrib_Intk: - case X86::VCMPSSZrrib_Int: case X86::VCMPSSZrrib_Intk: + case X86::VCMPSDZrrib_Int: case X86::VCMPSDZrribk_Int: + case X86::VCMPSSZrrib_Int: case X86::VCMPSSZrribk_Int: case X86::VCMPPHZ128rmi: case X86::VCMPPHZ128rri: case X86::VCMPPHZ256rmi: case X86::VCMPPHZ256rri: case X86::VCMPPHZrmi: case X86::VCMPPHZrri: @@ -160,12 +160,12 @@ bool X86ATTInstPrinter::printVecCompareInstr(const MCInst *MI, case X86::VCMPPHZ128rmik: case X86::VCMPPHZ128rrik: case X86::VCMPPHZ256rmik: case X86::VCMPPHZ256rrik: case X86::VCMPPHZrmik: case X86::VCMPPHZrrik: - case X86::VCMPSHZrmi_Intk: case X86::VCMPSHZrri_Intk: + case X86::VCMPSHZrmik_Int: case X86::VCMPSHZrrik_Int: case X86::VCMPPHZ128rmbi: case X86::VCMPPHZ128rmbik: case X86::VCMPPHZ256rmbi: case X86::VCMPPHZ256rmbik: case X86::VCMPPHZrmbi: case X86::VCMPPHZrmbik: case X86::VCMPPHZrrib: case X86::VCMPPHZrribk: - case X86::VCMPSHZrrib_Int: case X86::VCMPSHZrrib_Intk: + case X86::VCMPSHZrrib_Int: case X86::VCMPSHZrribk_Int: case X86::VCMPPBF16Z128rmi: case X86::VCMPPBF16Z128rri: case X86::VCMPPBF16Z256rmi: case X86::VCMPPBF16Z256rri: case X86::VCMPPBF16Zrmi: case X86::VCMPPBF16Zrri: diff --git a/llvm/lib/Target/X86/MCTargetDesc/X86InstComments.cpp b/llvm/lib/Target/X86/MCTargetDesc/X86InstComments.cpp index 9f8bc57..681d0da 100644 --- a/llvm/lib/Target/X86/MCTargetDesc/X86InstComments.cpp +++ b/llvm/lib/Target/X86/MCTargetDesc/X86InstComments.cpp @@ -40,6 +40,17 @@ using namespace llvm; CASE_MASK_INS_COMMON(Inst, Suffix, src) \ CASE_MASKZ_INS_COMMON(Inst, Suffix, src) +#define CASE_MASK_INS_COMMON_INT(Inst, Suffix, src) \ + case X86::V##Inst##Suffix##src##k_Int: + +#define CASE_MASKZ_INS_COMMON_INT(Inst, Suffix, src) \ + case X86::V##Inst##Suffix##src##kz_Int: + +#define CASE_AVX512_INS_COMMON_INT(Inst, Suffix, src) \ + CASE_AVX_INS_COMMON(Inst, Suffix, src##_Int) \ + CASE_MASK_INS_COMMON_INT(Inst, Suffix, src) \ + CASE_MASKZ_INS_COMMON_INT(Inst, Suffix, src) + #define CASE_FPCLASS_PACKED(Inst, src) \ CASE_AVX_INS_COMMON(Inst, Z, src##i) \ CASE_AVX_INS_COMMON(Inst, Z256, src##i) \ @@ -196,8 +207,8 @@ using namespace llvm; CASE_AVX_INS_COMMON(Inst##SS, , r_Int) \ CASE_AVX_INS_COMMON(Inst##SD, Z, r) \ CASE_AVX_INS_COMMON(Inst##SS, Z, r) \ - 
CASE_AVX512_INS_COMMON(Inst##SD, Z, r_Int) \ - CASE_AVX512_INS_COMMON(Inst##SS, Z, r_Int) + CASE_AVX512_INS_COMMON_INT(Inst##SD, Z, r) \ + CASE_AVX512_INS_COMMON_INT(Inst##SS, Z, r) #define CASE_FMA_SCALAR_MEM(Inst) \ CASE_AVX_INS_COMMON(Inst##SD, , m) \ @@ -206,8 +217,8 @@ using namespace llvm; CASE_AVX_INS_COMMON(Inst##SS, , m_Int) \ CASE_AVX_INS_COMMON(Inst##SD, Z, m) \ CASE_AVX_INS_COMMON(Inst##SS, Z, m) \ - CASE_AVX512_INS_COMMON(Inst##SD, Z, m_Int) \ - CASE_AVX512_INS_COMMON(Inst##SS, Z, m_Int) + CASE_AVX512_INS_COMMON_INT(Inst##SD, Z, m) \ + CASE_AVX512_INS_COMMON_INT(Inst##SS, Z, m) #define CASE_FMA4(Inst, suf) \ CASE_AVX_INS_COMMON(Inst, 4, suf) \ diff --git a/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.cpp b/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.cpp index fafcc73..01e2d4ac 100644 --- a/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.cpp +++ b/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.cpp @@ -277,8 +277,8 @@ void X86InstPrinterCommon::printCMPMnemonic(const MCInst *MI, bool IsVCmp, case X86::VCMPSDrmi_Int: case X86::VCMPSDrri_Int: case X86::VCMPSDZrmi: case X86::VCMPSDZrri: case X86::VCMPSDZrmi_Int: case X86::VCMPSDZrri_Int: - case X86::VCMPSDZrmi_Intk: case X86::VCMPSDZrri_Intk: - case X86::VCMPSDZrrib_Int: case X86::VCMPSDZrrib_Intk: + case X86::VCMPSDZrmik_Int: case X86::VCMPSDZrrik_Int: + case X86::VCMPSDZrrib_Int: case X86::VCMPSDZrribk_Int: OS << "sd\t"; break; case X86::CMPSSrmi: case X86::CMPSSrri: @@ -287,8 +287,8 @@ void X86InstPrinterCommon::printCMPMnemonic(const MCInst *MI, bool IsVCmp, case X86::VCMPSSrmi_Int: case X86::VCMPSSrri_Int: case X86::VCMPSSZrmi: case X86::VCMPSSZrri: case X86::VCMPSSZrmi_Int: case X86::VCMPSSZrri_Int: - case X86::VCMPSSZrmi_Intk: case X86::VCMPSSZrri_Intk: - case X86::VCMPSSZrrib_Int: case X86::VCMPSSZrrib_Intk: + case X86::VCMPSSZrmik_Int: case X86::VCMPSSZrrik_Int: + case X86::VCMPSSZrrib_Int: case X86::VCMPSSZrribk_Int: OS << "ss\t"; break; case X86::VCMPPHZ128rmi: case X86::VCMPPHZ128rri: @@ -305,8 +305,8 @@ void X86InstPrinterCommon::printCMPMnemonic(const MCInst *MI, bool IsVCmp, break; case X86::VCMPSHZrmi: case X86::VCMPSHZrri: case X86::VCMPSHZrmi_Int: case X86::VCMPSHZrri_Int: - case X86::VCMPSHZrrib_Int: case X86::VCMPSHZrrib_Intk: - case X86::VCMPSHZrmi_Intk: case X86::VCMPSHZrri_Intk: + case X86::VCMPSHZrrib_Int: case X86::VCMPSHZrribk_Int: + case X86::VCMPSHZrmik_Int: case X86::VCMPSHZrrik_Int: OS << "sh\t"; break; case X86::VCMPPBF16Z128rmi: case X86::VCMPPBF16Z128rri: diff --git a/llvm/lib/Target/X86/MCTargetDesc/X86IntelInstPrinter.cpp b/llvm/lib/Target/X86/MCTargetDesc/X86IntelInstPrinter.cpp index 6800926..c26dc2c 100644 --- a/llvm/lib/Target/X86/MCTargetDesc/X86IntelInstPrinter.cpp +++ b/llvm/lib/Target/X86/MCTargetDesc/X86IntelInstPrinter.cpp @@ -119,8 +119,8 @@ bool X86IntelInstPrinter::printVecCompareInstr(const MCInst *MI, raw_ostream &OS case X86::VCMPPSZ128rmik: case X86::VCMPPSZ128rrik: case X86::VCMPPSZ256rmik: case X86::VCMPPSZ256rrik: case X86::VCMPPSZrmik: case X86::VCMPPSZrrik: - case X86::VCMPSDZrmi_Intk: case X86::VCMPSDZrri_Intk: - case X86::VCMPSSZrmi_Intk: case X86::VCMPSSZrri_Intk: + case X86::VCMPSDZrmik_Int: case X86::VCMPSDZrrik_Int: + case X86::VCMPSSZrmik_Int: case X86::VCMPSSZrrik_Int: case X86::VCMPPDZ128rmbi: case X86::VCMPPDZ128rmbik: case X86::VCMPPDZ256rmbi: case X86::VCMPPDZ256rmbik: case X86::VCMPPDZrmbi: case X86::VCMPPDZrmbik: @@ -129,8 +129,8 @@ bool X86IntelInstPrinter::printVecCompareInstr(const MCInst *MI, raw_ostream &OS case 
X86::VCMPPSZrmbi: case X86::VCMPPSZrmbik: case X86::VCMPPDZrrib: case X86::VCMPPDZrribk: case X86::VCMPPSZrrib: case X86::VCMPPSZrribk: - case X86::VCMPSDZrrib_Int: case X86::VCMPSDZrrib_Intk: - case X86::VCMPSSZrrib_Int: case X86::VCMPSSZrrib_Intk: + case X86::VCMPSDZrrib_Int: case X86::VCMPSDZrribk_Int: + case X86::VCMPSSZrrib_Int: case X86::VCMPSSZrribk_Int: case X86::VCMPPHZ128rmi: case X86::VCMPPHZ128rri: case X86::VCMPPHZ256rmi: case X86::VCMPPHZ256rri: case X86::VCMPPHZrmi: case X86::VCMPPHZrri: @@ -139,12 +139,12 @@ bool X86IntelInstPrinter::printVecCompareInstr(const MCInst *MI, raw_ostream &OS case X86::VCMPPHZ128rmik: case X86::VCMPPHZ128rrik: case X86::VCMPPHZ256rmik: case X86::VCMPPHZ256rrik: case X86::VCMPPHZrmik: case X86::VCMPPHZrrik: - case X86::VCMPSHZrmi_Intk: case X86::VCMPSHZrri_Intk: + case X86::VCMPSHZrmik_Int: case X86::VCMPSHZrrik_Int: case X86::VCMPPHZ128rmbi: case X86::VCMPPHZ128rmbik: case X86::VCMPPHZ256rmbi: case X86::VCMPPHZ256rmbik: case X86::VCMPPHZrmbi: case X86::VCMPPHZrmbik: case X86::VCMPPHZrrib: case X86::VCMPPHZrribk: - case X86::VCMPSHZrrib_Int: case X86::VCMPSHZrrib_Intk: + case X86::VCMPSHZrrib_Int: case X86::VCMPSHZrribk_Int: case X86::VCMPPBF16Z128rmi: case X86::VCMPPBF16Z128rri: case X86::VCMPPBF16Z256rmi: case X86::VCMPPBF16Z256rri: case X86::VCMPPBF16Zrmi: case X86::VCMPPBF16Zrri: diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index e7f6032e..6b0eb38 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -94,7 +94,7 @@ static cl::opt<int> BrMergingCcmpBias( static cl::opt<bool> WidenShift("x86-widen-shift", cl::init(true), - cl::desc("Replacte narrow shifts with wider shifts."), + cl::desc("Replace narrow shifts with wider shifts."), cl::Hidden); static cl::opt<int> BrMergingLikelyBias( @@ -341,8 +341,17 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, } } if (Subtarget.hasAVX10_2()) { - setOperationAction(ISD::FP_TO_UINT_SAT, MVT::i32, Legal); - setOperationAction(ISD::FP_TO_SINT_SAT, MVT::i32, Legal); + setOperationAction(ISD::FP_TO_UINT_SAT, MVT::v2i32, Custom); + setOperationAction(ISD::FP_TO_SINT_SAT, MVT::v2i32, Custom); + for (MVT VT : {MVT::i32, MVT::v4i32, MVT::v8i32, MVT::v16i32, MVT::v2i64, + MVT::v4i64}) { + setOperationAction(ISD::FP_TO_UINT_SAT, VT, Legal); + setOperationAction(ISD::FP_TO_SINT_SAT, VT, Legal); + } + if (Subtarget.hasAVX10_2_512()) { + setOperationAction(ISD::FP_TO_UINT_SAT, MVT::v8i64, Legal); + setOperationAction(ISD::FP_TO_SINT_SAT, MVT::v8i64, Legal); + } if (Subtarget.is64Bit()) { setOperationAction(ISD::FP_TO_UINT_SAT, MVT::i64, Legal); setOperationAction(ISD::FP_TO_SINT_SAT, MVT::i64, Legal); @@ -623,6 +632,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::FMAXNUM, VT, Action); setOperationAction(ISD::FMINIMUM, VT, Action); setOperationAction(ISD::FMAXIMUM, VT, Action); + setOperationAction(ISD::FMINIMUMNUM, VT, Action); + setOperationAction(ISD::FMAXIMUMNUM, VT, Action); setOperationAction(ISD::FSIN, VT, Action); setOperationAction(ISD::FCOS, VT, Action); setOperationAction(ISD::FSINCOS, VT, Action); @@ -1066,6 +1077,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::FMAXIMUM, MVT::f32, Custom); setOperationAction(ISD::FMINIMUM, MVT::f32, Custom); + setOperationAction(ISD::FMAXIMUMNUM, MVT::f32, Custom); + setOperationAction(ISD::FMINIMUMNUM, MVT::f32, Custom); setOperationAction(ISD::FNEG, MVT::v4f32, 
Custom); setOperationAction(ISD::FABS, MVT::v4f32, Custom); @@ -1108,6 +1121,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, for (auto VT : { MVT::f64, MVT::v4f32, MVT::v2f64 }) { setOperationAction(ISD::FMAXIMUM, VT, Custom); setOperationAction(ISD::FMINIMUM, VT, Custom); + setOperationAction(ISD::FMAXIMUMNUM, VT, Custom); + setOperationAction(ISD::FMINIMUMNUM, VT, Custom); } for (auto VT : { MVT::v2i8, MVT::v4i8, MVT::v8i8, @@ -1473,6 +1488,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::FMAXIMUM, VT, Custom); setOperationAction(ISD::FMINIMUM, VT, Custom); + setOperationAction(ISD::FMAXIMUMNUM, VT, Custom); + setOperationAction(ISD::FMINIMUMNUM, VT, Custom); setOperationAction(ISD::FCANONICALIZE, VT, Custom); } @@ -1818,6 +1835,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, for (MVT VT : { MVT::v16f32, MVT::v8f64 }) { setOperationAction(ISD::FMAXIMUM, VT, Custom); setOperationAction(ISD::FMINIMUM, VT, Custom); + setOperationAction(ISD::FMAXIMUMNUM, VT, Custom); + setOperationAction(ISD::FMINIMUMNUM, VT, Custom); setOperationAction(ISD::FNEG, VT, Custom); setOperationAction(ISD::FABS, VT, Custom); setOperationAction(ISD::FMA, VT, Legal); @@ -2289,6 +2308,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::STRICT_FP_ROUND, MVT::f16, Custom); setOperationAction(ISD::FMAXIMUM, MVT::f16, Custom); setOperationAction(ISD::FMINIMUM, MVT::f16, Custom); + setOperationAction(ISD::FMAXIMUMNUM, MVT::f16, Custom); + setOperationAction(ISD::FMINIMUMNUM, MVT::f16, Custom); setOperationAction(ISD::FP_EXTEND, MVT::f32, Legal); setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f32, Legal); @@ -2336,6 +2357,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::FMINIMUM, MVT::v32f16, Custom); setOperationAction(ISD::FMAXIMUM, MVT::v32f16, Custom); + setOperationAction(ISD::FMINIMUMNUM, MVT::v32f16, Custom); + setOperationAction(ISD::FMAXIMUMNUM, MVT::v32f16, Custom); } if (Subtarget.hasVLX()) { @@ -2383,9 +2406,13 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::FMINIMUM, MVT::v8f16, Custom); setOperationAction(ISD::FMAXIMUM, MVT::v8f16, Custom); + setOperationAction(ISD::FMINIMUMNUM, MVT::v8f16, Custom); + setOperationAction(ISD::FMAXIMUMNUM, MVT::v8f16, Custom); setOperationAction(ISD::FMINIMUM, MVT::v16f16, Custom); setOperationAction(ISD::FMAXIMUM, MVT::v16f16, Custom); + setOperationAction(ISD::FMINIMUMNUM, MVT::v16f16, Custom); + setOperationAction(ISD::FMAXIMUMNUM, MVT::v16f16, Custom); } } @@ -2442,6 +2469,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::FSQRT, VT, Legal); setOperationAction(ISD::FMA, VT, Legal); setOperationAction(ISD::SETCC, VT, Custom); + setOperationAction(ISD::FMINIMUM, VT, Custom); + setOperationAction(ISD::FMAXIMUM, VT, Custom); + setOperationAction(ISD::FMINIMUMNUM, VT, Custom); + setOperationAction(ISD::FMAXIMUMNUM, VT, Custom); } if (Subtarget.hasAVX10_2_512()) { setOperationAction(ISD::FADD, MVT::v32bf16, Legal); @@ -2451,6 +2482,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::FSQRT, MVT::v32bf16, Legal); setOperationAction(ISD::FMA, MVT::v32bf16, Legal); setOperationAction(ISD::SETCC, MVT::v32bf16, Custom); + setOperationAction(ISD::FMINIMUM, MVT::v32bf16, Custom); + setOperationAction(ISD::FMAXIMUM, MVT::v32bf16, Custom); + setOperationAction(ISD::FMINIMUMNUM, 
MVT::v32bf16, Custom); + setOperationAction(ISD::FMAXIMUMNUM, MVT::v32bf16, Custom); } for (auto VT : {MVT::f16, MVT::f32, MVT::f64}) { setCondCodeAction(ISD::SETOEQ, VT, Custom); @@ -2652,6 +2687,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, ISD::UINT_TO_FP, ISD::STRICT_SINT_TO_FP, ISD::STRICT_UINT_TO_FP, + ISD::FP_TO_SINT_SAT, + ISD::FP_TO_UINT_SAT, ISD::SETCC, ISD::MUL, ISD::XOR, @@ -28835,19 +28872,35 @@ static SDValue LowerMINMAX(SDValue Op, const X86Subtarget &Subtarget, static SDValue LowerFMINIMUM_FMAXIMUM(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG) { - assert((Op.getOpcode() == ISD::FMAXIMUM || Op.getOpcode() == ISD::FMINIMUM) && - "Expected FMAXIMUM or FMINIMUM opcode"); const TargetLowering &TLI = DAG.getTargetLoweringInfo(); EVT VT = Op.getValueType(); SDValue X = Op.getOperand(0); SDValue Y = Op.getOperand(1); SDLoc DL(Op); + bool IsMaxOp = + Op.getOpcode() == ISD::FMAXIMUM || Op.getOpcode() == ISD::FMAXIMUMNUM; + bool IsNum = + Op.getOpcode() == ISD::FMINIMUMNUM || Op.getOpcode() == ISD::FMAXIMUMNUM; + if (Subtarget.hasAVX10_2() && TLI.isTypeLegal(VT)) { + unsigned Opc = 0; + if (VT.isVector()) + Opc = X86ISD::VMINMAX; + else if (VT == MVT::f16 || VT == MVT::f32 || VT == MVT::f64) + Opc = X86ISD::VMINMAXS; + + if (Opc) { + SDValue Imm = + DAG.getTargetConstant(IsMaxOp + (IsNum ? 16 : 0), DL, MVT::i32); + return DAG.getNode(Opc, DL, VT, X, Y, Imm, Op->getFlags()); + } + } + uint64_t SizeInBits = VT.getScalarSizeInBits(); APInt PreferredZero = APInt::getZero(SizeInBits); APInt OppositeZero = PreferredZero; EVT IVT = VT.changeTypeToInteger(); X86ISD::NodeType MinMaxOp; - if (Op.getOpcode() == ISD::FMAXIMUM) { + if (IsMaxOp) { MinMaxOp = X86ISD::FMAX; OppositeZero.setSignBit(); } else { @@ -28977,7 +29030,9 @@ static SDValue LowerFMINIMUM_FMAXIMUM(SDValue Op, const X86Subtarget &Subtarget, if (IgnoreNaN || DAG.isKnownNeverNaN(NewX)) return MinMax; - SDValue IsNaN = DAG.getSetCC(DL, SetCCType, NewX, NewX, ISD::SETUO); + SDValue IsNaN = + DAG.getSetCC(DL, SetCCType, NewX, NewX, IsNum ? 
ISD::SETO : ISD::SETUO); + return DAG.getSelect(DL, VT, IsNaN, NewX, MinMax); } @@ -33235,6 +33290,8 @@ SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { case ISD::UMIN: return LowerMINMAX(Op, Subtarget, DAG); case ISD::FMINIMUM: case ISD::FMAXIMUM: + case ISD::FMINIMUMNUM: + case ISD::FMAXIMUMNUM: return LowerFMINIMUM_FMAXIMUM(Op, Subtarget, DAG); case ISD::ABS: return LowerABS(Op, Subtarget, DAG); case ISD::ABDS: @@ -33647,6 +33704,26 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N, } return; } + case ISD::FP_TO_SINT_SAT: + case ISD::FP_TO_UINT_SAT: { + if (!Subtarget.hasAVX10_2()) + return; + + bool IsSigned = Opc == ISD::FP_TO_SINT_SAT; + EVT VT = N->getValueType(0); + SDValue Op = N->getOperand(0); + EVT OpVT = Op.getValueType(); + SDValue Res; + + if (VT == MVT::v2i32 && OpVT == MVT::v2f64) { + if (IsSigned) + Res = DAG.getNode(X86ISD::FP_TO_SINT_SAT, dl, MVT::v4i32, Op); + else + Res = DAG.getNode(X86ISD::FP_TO_UINT_SAT, dl, MVT::v4i32, Op); + Results.push_back(Res); + } + return; + } case ISD::FP_TO_SINT: case ISD::STRICT_FP_TO_SINT: case ISD::FP_TO_UINT: @@ -34627,6 +34704,8 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { NODE_NAME_CASE(VPERMV3) NODE_NAME_CASE(VPERMI) NODE_NAME_CASE(VPTERNLOG) + NODE_NAME_CASE(FP_TO_SINT_SAT) + NODE_NAME_CASE(FP_TO_UINT_SAT) NODE_NAME_CASE(VFIXUPIMM) NODE_NAME_CASE(VFIXUPIMM_SAE) NODE_NAME_CASE(VFIXUPIMMS) @@ -41615,6 +41694,8 @@ static SDValue combineTargetShuffle(SDValue N, const SDLoc &DL, TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) { MVT VT = N.getSimpleValueType(); + unsigned NumElts = VT.getVectorNumElements(); + SmallVector<int, 4> Mask; unsigned Opcode = N.getOpcode(); const TargetLowering &TLI = DAG.getTargetLoweringInfo(); @@ -41900,7 +41981,7 @@ static SDValue combineTargetShuffle(SDValue N, const SDLoc &DL, APInt Mask = APInt::getHighBitsSet(64, 32); if (DAG.MaskedValueIsZero(In, Mask)) { SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, In); - MVT VecVT = MVT::getVectorVT(MVT::i32, VT.getVectorNumElements() * 2); + MVT VecVT = MVT::getVectorVT(MVT::i32, NumElts * 2); SDValue SclVec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VecVT, Trunc); SDValue Movl = DAG.getNode(X86ISD::VZEXT_MOVL, DL, VecVT, SclVec); return DAG.getBitcast(VT, Movl); @@ -41915,7 +41996,6 @@ static SDValue combineTargetShuffle(SDValue N, const SDLoc &DL, // Create a vector constant - scalar constant followed by zeros. 
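Aside on the new *NUM opcodes handled above: the practical difference between FMINIMUM/FMAXIMUM and FMINIMUMNUM/FMAXIMUMNUM is the NaN behaviour, which is what the IsNum flag threaded through the lowering distinguishes (an ordered rather than unordered compare in the final NaN check). Per IEEE 754-2019, minimum/maximum propagate a NaN operand, while minimumNumber/maximumNumber return the numeric operand when only one input is NaN. A scalar reference model using only the standard library (minimumPropagate and minimumNumber are illustrative names, not LLVM APIs):

    #include <cmath>
    #include <cstdio>

    // IEEE 754-2019 "minimum": a NaN operand poisons the result; -0.0 < +0.0.
    static double minimumPropagate(double A, double B) {
      if (std::isnan(A) || std::isnan(B))
        return std::nan("");
      if (A == 0.0 && B == 0.0)
        return std::signbit(A) ? A : B;   // prefer -0.0 over +0.0
      return A < B ? A : B;
    }

    // IEEE 754-2019 "minimumNumber": a NaN operand is ignored when the other
    // operand is a number (the FMINIMUMNUM behaviour).
    static double minimumNumber(double A, double B) {
      if (std::isnan(A))
        return B;
      if (std::isnan(B))
        return A;
      if (A == 0.0 && B == 0.0)
        return std::signbit(A) ? A : B;
      return A < B ? A : B;
    }

    int main() {
      std::printf("%f\n", minimumPropagate(1.0, std::nan(""))); // nan
      std::printf("%f\n", minimumNumber(1.0, std::nan("")));    // 1.000000
    }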
EVT ScalarVT = N0.getOperand(0).getValueType(); Type *ScalarTy = ScalarVT.getTypeForEVT(*DAG.getContext()); - unsigned NumElts = VT.getVectorNumElements(); Constant *Zero = ConstantInt::getNullValue(ScalarTy); SmallVector<Constant *, 32> ConstantVec(NumElts, Zero); ConstantVec[0] = const_cast<ConstantInt *>(C->getConstantIntValue()); @@ -41966,9 +42046,8 @@ static SDValue combineTargetShuffle(SDValue N, const SDLoc &DL, MVT SrcVT = N0.getOperand(0).getSimpleValueType(); unsigned SrcBits = SrcVT.getScalarSizeInBits(); if ((EltBits % SrcBits) == 0 && SrcBits >= 32) { - unsigned Size = VT.getVectorNumElements(); unsigned NewSize = SrcVT.getVectorNumElements(); - APInt BlendMask = N.getConstantOperandAPInt(2).zextOrTrunc(Size); + APInt BlendMask = N.getConstantOperandAPInt(2).zextOrTrunc(NumElts); APInt NewBlendMask = APIntOps::ScaleBitMask(BlendMask, NewSize); return DAG.getBitcast( VT, DAG.getNode(X86ISD::BLENDI, DL, SrcVT, N0.getOperand(0), @@ -42381,7 +42460,7 @@ static SDValue combineTargetShuffle(SDValue N, const SDLoc &DL, int DOffset = N.getOpcode() == X86ISD::PSHUFLW ? 0 : 2; DMask[DOffset + 0] = DOffset + 1; DMask[DOffset + 1] = DOffset + 0; - MVT DVT = MVT::getVectorVT(MVT::i32, VT.getVectorNumElements() / 2); + MVT DVT = MVT::getVectorVT(MVT::i32, NumElts / 2); V = DAG.getBitcast(DVT, V); V = DAG.getNode(X86ISD::PSHUFD, DL, DVT, V, getV4X86ShuffleImm8ForMask(DMask, DL, DAG)); @@ -45976,6 +46055,8 @@ static SDValue scalarizeExtEltFP(SDNode *ExtElt, SelectionDAG &DAG, case ISD::FMAXNUM_IEEE: case ISD::FMAXIMUM: case ISD::FMINIMUM: + case ISD::FMAXIMUMNUM: + case ISD::FMINIMUMNUM: case X86ISD::FMAX: case X86ISD::FMIN: case ISD::FABS: // Begin 1 operand @@ -56184,6 +56265,33 @@ static SDValue combineSIntToFP(SDNode *N, SelectionDAG &DAG, return SDValue(); } +// Custom handling for VCVTTPS2QQS/VCVTTPS2UQQS +static SDValue combineFP_TO_xINT_SAT(SDNode *N, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + if (!Subtarget.hasAVX10_2()) + return SDValue(); + + bool IsSigned = N->getOpcode() == ISD::FP_TO_SINT_SAT; + EVT SrcVT = N->getOperand(0).getValueType(); + EVT DstVT = N->getValueType(0); + SDLoc dl(N); + + if (SrcVT == MVT::v2f32 && DstVT == MVT::v2i64) { + SDValue V2F32Value = DAG.getUNDEF(SrcVT); + + // Concatenate the original v2f32 input and V2F32Value to create v4f32 + SDValue NewSrc = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f32, + N->getOperand(0), V2F32Value); + + // Select the FP_TO_SINT_SAT/FP_TO_UINT_SAT node + if (IsSigned) + return DAG.getNode(X86ISD::FP_TO_SINT_SAT, dl, MVT::v2i64, NewSrc); + + return DAG.getNode(X86ISD::FP_TO_UINT_SAT, dl, MVT::v2i64, NewSrc); + } + return SDValue(); +} + static bool needCarryOrOverflowFlag(SDValue Flags) { assert(Flags.getValueType() == MVT::i32 && "Unexpected VT!"); @@ -59297,6 +59405,8 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N, case ISD::INTRINSIC_WO_CHAIN: return combineINTRINSIC_WO_CHAIN(N, DAG, DCI); case ISD::INTRINSIC_W_CHAIN: return combineINTRINSIC_W_CHAIN(N, DAG, DCI); case ISD::INTRINSIC_VOID: return combineINTRINSIC_VOID(N, DAG, DCI); + case ISD::FP_TO_SINT_SAT: + case ISD::FP_TO_UINT_SAT: return combineFP_TO_xINT_SAT(N, DAG, Subtarget); // clang-format on } diff --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h index 2b7a8ea..eaedaa0 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.h +++ b/llvm/lib/Target/X86/X86ISelLowering.h @@ -908,6 +908,10 @@ namespace llvm { // Load x87 FPU environment from memory. 
FLDENVm, + // Custom handling for FP_TO_xINT_SAT + FP_TO_SINT_SAT, + FP_TO_UINT_SAT, + /// This instruction implements FP_TO_SINT with the /// integer destination in memory and a FP reg source. This corresponds /// to the X86::FIST*m instructions and the rounding mode change stuff. It diff --git a/llvm/lib/Target/X86/X86InstrAVX10.td b/llvm/lib/Target/X86/X86InstrAVX10.td index 0301c07..1270161 100644 --- a/llvm/lib/Target/X86/X86InstrAVX10.td +++ b/llvm/lib/Target/X86/X86InstrAVX10.td @@ -403,28 +403,45 @@ multiclass avx10_minmax_scalar<string OpStr, X86VectorVTInfo _, SDNode OpNode, SDNode OpNodeSAE> { let ExeDomain = _.ExeDomain, Predicates = [HasAVX10_2] in { let mayRaiseFPException = 1 in { + let isCodeGenOnly = 1 in { + def rri : AVX512Ii8<0x53, MRMSrcReg, (outs _.FRC:$dst), + (ins _.FRC:$src1, _.FRC:$src2, i32u8imm:$src3), + !strconcat(OpStr, "\t{$src3, $src2, $src1|$src1, $src2, $src3}"), + [(set _.FRC:$dst, (OpNode _.FRC:$src1, _.FRC:$src2, (i32 timm:$src3)))]>, + Sched<[WriteFMAX]>; + + def rmi : AVX512Ii8<0x53, MRMSrcMem, (outs _.FRC:$dst), + (ins _.FRC:$src1, _.ScalarMemOp:$src2, i32u8imm:$src3), + !strconcat(OpStr, "\t{$src3, $src2, $src1|$src1, $src2, $src3}"), + [(set _.FRC:$dst, (OpNode _.FRC:$src1, (_.ScalarLdFrag addr:$src2), + (i32 timm:$src3)))]>, + Sched<[WriteFMAX.Folded, WriteFMAX.ReadAfterFold]>; + } defm rri : AVX512_maskable<0x53, MRMSrcReg, _, (outs VR128X:$dst), - (ins VR128X:$src1, VR128X:$src2, i32u8imm:$src3), - OpStr, "$src3, $src2, $src1", "$src1, $src2, $src3", - (_.VT (OpNode (_.VT _.RC:$src1), (_.VT _.RC:$src2), - (i32 timm:$src3)))>, - Sched<[WriteFMAX]>; + (ins VR128X:$src1, VR128X:$src2, i32u8imm:$src3), + OpStr, "$src3, $src2, $src1", "$src1, $src2, $src3", + (_.VT (OpNode (_.VT _.RC:$src1), (_.VT _.RC:$src2), + (i32 timm:$src3))), + 0, 0, 0, vselect_mask, "", "_Int">, + Sched<[WriteFMAX]>; defm rmi : AVX512_maskable<0x53, MRMSrcMem, _, (outs VR128X:$dst), - (ins VR128X:$src1, _.ScalarMemOp:$src2, i32u8imm:$src3), - OpStr, "$src3, $src2, $src1", "$src1, $src2, $src3", - (_.VT (OpNode (_.VT _.RC:$src1), (_.ScalarIntMemFrags addr:$src2), - (i32 timm:$src3)))>, + (ins VR128X:$src1, _.ScalarMemOp:$src2, i32u8imm:$src3), + OpStr, "$src3, $src2, $src1", "$src1, $src2, $src3", + (_.VT (OpNode (_.VT _.RC:$src1), (_.ScalarIntMemFrags addr:$src2), + (i32 timm:$src3))), + 0, 0, 0, vselect_mask, "", "_Int">, Sched<[WriteFMAX.Folded, WriteFMAX.ReadAfterFold]>; } let Uses = []<Register>, mayRaiseFPException = 0 in defm rrib : AVX512_maskable<0x53, MRMSrcReg, _, (outs VR128X:$dst), - (ins VR128X:$src1, VR128X:$src2, i32u8imm:$src3), - OpStr, "$src3, {sae}, $src2, $src1", - "$src1, $src2, {sae}, $src3", - (_.VT (OpNodeSAE (_.VT _.RC:$src1), (_.VT _.RC:$src2), - (i32 timm:$src3)))>, - Sched<[WriteFMAX]>, EVEX_B; + (ins VR128X:$src1, VR128X:$src2, i32u8imm:$src3), + OpStr, "$src3, {sae}, $src2, $src1", + "$src1, $src2, {sae}, $src3", + (_.VT (OpNodeSAE (_.VT _.RC:$src1), (_.VT _.RC:$src2), + (i32 timm:$src3))), + 0, 0, 0, vselect_mask, "", "_Int">, + Sched<[WriteFMAX]>, EVEX_B; } } @@ -817,6 +834,70 @@ let Predicates = [HasAVX10_2] in { // patterns have been disabled with null_frag. 
// Patterns VCVTTPD2DQSZ128 +// VCVTTPD2DQS +def : Pat<(v4i32(X86fp2sisat(v2f64 VR128X:$src))), + (VCVTTPD2DQSZ128rr VR128X:$src)>; +def : Pat<(v4i32(fp_to_sint_sat(v4f64 VR256X:$src), i32)), + (VCVTTPD2DQSZ256rr VR256X:$src)>; +def : Pat<(v8i32(fp_to_sint_sat(v8f64 VR512:$src), i32)), + (VCVTTPD2DQSZrr VR512:$src)>; + +// VCVTTPD2QQS +def : Pat<(v2i64(fp_to_sint_sat(v2f64 VR128X:$src), i64)), + (VCVTTPD2QQSZ128rr VR128X:$src)>; +def : Pat<(v4i64(fp_to_sint_sat(v4f64 VR256X:$src), i64)), + (VCVTTPD2QQSZ256rr VR256X:$src)>; +def : Pat<(v8i64(fp_to_sint_sat(v8f64 VR512:$src), i64)), + (VCVTTPD2QQSZrr VR512:$src)>; + +// VCVTTPD2UDQS +def : Pat<(v4i32(X86fp2uisat(v2f64 VR128X:$src))), + (VCVTTPD2UDQSZ128rr VR128X:$src)>; +def : Pat<(v4i32(fp_to_uint_sat(v4f64 VR256X:$src), i32)), + (VCVTTPD2UDQSZ256rr VR256X:$src)>; +def : Pat<(v8i32(fp_to_uint_sat(v8f64 VR512:$src), i32)), + (VCVTTPD2UDQSZrr VR512:$src)>; + +// VCVTTPD2UQQS +def : Pat<(v2i64(fp_to_uint_sat(v2f64 VR128X:$src), i64)), + (VCVTTPD2UQQSZ128rr VR128X:$src)>; +def : Pat<(v4i64(fp_to_uint_sat(v4f64 VR256X:$src), i64)), + (VCVTTPD2UQQSZ256rr VR256X:$src)>; +def : Pat<(v8i64(fp_to_uint_sat(v8f64 VR512:$src), i64)), + (VCVTTPD2UQQSZrr VR512:$src)>; + +// VCVTTPS2DQS +def : Pat<(v4i32(fp_to_sint_sat(v4f32 VR128X:$src), i32)), + (VCVTTPS2DQSZ128rr VR128X:$src)>; +def : Pat<(v8i32(fp_to_sint_sat(v8f32 VR256X:$src), i32)), + (VCVTTPS2DQSZ256rr VR256X:$src)>; +def : Pat<(v16i32(fp_to_sint_sat(v16f32 VR512:$src), i32)), + (VCVTTPS2DQSZrr VR512:$src)>; + +// VCVTTPS2QQS +def : Pat<(v2i64(X86fp2sisat(v4f32 VR128X:$src))), + (VCVTTPS2QQSZ128rr VR128X:$src)>; +def : Pat<(v4i64(fp_to_sint_sat(v4f32 VR128X:$src), i64)), + (VCVTTPS2QQSZ256rr VR128X:$src)>; +def : Pat<(v8i64(fp_to_sint_sat(v8f32 VR256X:$src), i64)), + (VCVTTPS2QQSZrr VR256X:$src)>; + +// VCVTTPS2UDQS +def : Pat<(v4i32(fp_to_uint_sat(v4f32 VR128X:$src), i32)), + (VCVTTPS2UDQSZ128rr VR128X:$src)>; +def : Pat<(v8i32(fp_to_uint_sat(v8f32 VR256X:$src), i32)), + (VCVTTPS2UDQSZ256rr VR256X:$src)>; +def : Pat<(v16i32(fp_to_uint_sat(v16f32 VR512:$src), i32)), + (VCVTTPS2UDQSZrr VR512:$src)>; + +// VCVTTPS2UQQS +def : Pat<(v2i64(X86fp2uisat(v4f32 VR128X:$src))), + (VCVTTPS2UQQSZ128rr VR128X:$src)>; +def : Pat<(v4i64(fp_to_uint_sat(v4f32 VR128X:$src), i64)), + (VCVTTPS2UQQSZ256rr VR128X:$src)>; +def : Pat<(v8i64(fp_to_uint_sat(v8f32 VR256X:$src), i64)), + (VCVTTPS2UQQSZrr VR256X:$src)>; + def : Pat<(v4i32 (X86cvttp2sis (v2f64 VR128X:$src))), (VCVTTPD2DQSZ128rr VR128X:$src)>; def : Pat<(v4i32 (X86cvttp2sis (loadv2f64 addr:$src))), diff --git a/llvm/lib/Target/X86/X86InstrAVX512.td b/llvm/lib/Target/X86/X86InstrAVX512.td index e899807..d6ca4b1 100644 --- a/llvm/lib/Target/X86/X86InstrAVX512.td +++ b/llvm/lib/Target/X86/X86InstrAVX512.td @@ -28,19 +28,20 @@ multiclass AVX512_maskable_custom<bits<8> O, Format F, bit IsCommutable = 0, bit IsKCommutable = 0, bit IsKZCommutable = IsCommutable, - string ClobberConstraint = ""> { + string ClobberConstraint = "", + string Suffix = ""> { let isCommutable = IsCommutable, Constraints = ClobberConstraint in - def NAME: AVX512<O, F, Outs, Ins, - OpcodeStr#"\t{"#AttSrcAsm#", $dst|"# - "$dst, "#IntelSrcAsm#"}", - Pattern>; + def Suffix: AVX512<O, F, Outs, Ins, + OpcodeStr#"\t{"#AttSrcAsm#", $dst|"# + "$dst, "#IntelSrcAsm#"}", + Pattern>; // Prefer over VMOV*rrk Pat<> let isCommutable = IsKCommutable in - def NAME#k: AVX512<O, F, Outs, MaskingIns, - OpcodeStr#"\t{"#AttSrcAsm#", $dst {${mask}}|"# - "$dst {${mask}}, "#IntelSrcAsm#"}", - MaskingPattern>, + def 
k#Suffix: AVX512<O, F, Outs, MaskingIns, + OpcodeStr#"\t{"#AttSrcAsm#", $dst {${mask}}|"# + "$dst {${mask}}, "#IntelSrcAsm#"}", + MaskingPattern>, EVEX_K { // In case of the 3src subclass this is overridden with a let. string Constraints = !if(!eq(ClobberConstraint, ""), MaskingConstraint, @@ -52,10 +53,10 @@ multiclass AVX512_maskable_custom<bits<8> O, Format F, // So, it is Ok to use IsCommutable instead of IsKCommutable. let isCommutable = IsKZCommutable, // Prefer over VMOV*rrkz Pat<> Constraints = ClobberConstraint in - def NAME#kz: AVX512<O, F, Outs, ZeroMaskingIns, - OpcodeStr#"\t{"#AttSrcAsm#", $dst {${mask}} {z}|"# - "$dst {${mask}} {z}, "#IntelSrcAsm#"}", - ZeroMaskingPattern>, + def kz#Suffix: AVX512<O, F, Outs, ZeroMaskingIns, + OpcodeStr#"\t{"#AttSrcAsm#", $dst {${mask}} {z}|"# + "$dst {${mask}} {z}, "#IntelSrcAsm#"}", + ZeroMaskingPattern>, EVEX_KZ; } @@ -72,7 +73,8 @@ multiclass AVX512_maskable_common<bits<8> O, Format F, X86VectorVTInfo _, bit IsCommutable = 0, bit IsKCommutable = 0, bit IsKZCommutable = IsCommutable, - string ClobberConstraint = ""> : + string ClobberConstraint = "", + string Suffix = ""> : AVX512_maskable_custom<O, F, Outs, Ins, MaskingIns, ZeroMaskingIns, OpcodeStr, AttSrcAsm, IntelSrcAsm, [(set _.RC:$dst, RHS)], @@ -80,7 +82,8 @@ multiclass AVX512_maskable_common<bits<8> O, Format F, X86VectorVTInfo _, [(set _.RC:$dst, (Select _.KRCWM:$mask, RHS, _.ImmAllZerosV))], MaskingConstraint, IsCommutable, - IsKCommutable, IsKZCommutable, ClobberConstraint>; + IsKCommutable, IsKZCommutable, ClobberConstraint, + Suffix>; // This multiclass generates the unconditional/non-masking, the masking and // the zero-masking variant of the vector instruction. In the masking case, the @@ -115,23 +118,24 @@ multiclass AVX512_maskable<bits<8> O, Format F, X86VectorVTInfo _, bit IsCommutable = 0, bit IsKCommutable = 0, bit IsKZCommutable = IsCommutable, SDPatternOperator Select = vselect_mask, - string ClobberConstraint = ""> : + string ClobberConstraint = "", + string Suffix = ""> : AVX512_maskable_common<O, F, _, Outs, Ins, !con((ins _.RC:$src0, _.KRCWM:$mask), Ins), !con((ins _.KRCWM:$mask), Ins), OpcodeStr, AttSrcAsm, IntelSrcAsm, RHS, (Select _.KRCWM:$mask, RHS, _.RC:$src0), Select, "$src0 = $dst", IsCommutable, IsKCommutable, - IsKZCommutable, ClobberConstraint>; + IsKZCommutable, ClobberConstraint, Suffix>; // This multiclass generates the unconditional/non-masking, the masking and // the zero-masking variant of the scalar instruction. 
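Aside on the three variants these multiclasses stamp out, which the renaming above reorders into base, k...(_Int) and kz...(_Int) forms: the unmasked instruction writes every destination element, the merge-masked form (EVEX_K) keeps the old destination element wherever the mask bit is clear, and the zero-masked form (EVEX_KZ) writes zero there instead. A scalar model of that lane-update rule, with illustrative names only:

    #include <array>
    #include <cstddef>
    #include <cstdio>

    // Model of an AVX-512 masked lane update:
    //   mask bit set   -> take the newly computed value
    //   mask bit clear -> keep the old destination (merge) or write 0 (zero-mask)
    template <std::size_t N>
    std::array<float, N> maskedOp(const std::array<float, N> &Dst,
                                  const std::array<float, N> &Computed,
                                  unsigned Mask, bool Zeroing) {
      std::array<float, N> Out{};
      for (std::size_t I = 0; I < N; ++I) {
        if (Mask & (1u << I))
          Out[I] = Computed[I];
        else
          Out[I] = Zeroing ? 0.0f : Dst[I];
      }
      return Out;
    }

    int main() {
      std::array<float, 4> Dst{1, 2, 3, 4}, New{10, 20, 30, 40};
      auto Merge = maskedOp(Dst, New, 0b0101, /*Zeroing=*/false); // {10, 2, 30, 4}
      auto Zero = maskedOp(Dst, New, 0b0101, /*Zeroing=*/true);   // {10, 0, 30, 0}
      std::printf("%g %g %g %g\n", Merge[0], Merge[1], Merge[2], Merge[3]);
      std::printf("%g %g %g %g\n", Zero[0], Zero[1], Zero[2], Zero[3]);
    }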
multiclass AVX512_maskable_scalar<bits<8> O, Format F, X86VectorVTInfo _, dag Outs, dag Ins, string OpcodeStr, string AttSrcAsm, string IntelSrcAsm, - dag RHS> : + dag RHS, string Suffix = ""> : AVX512_maskable<O, F, _, Outs, Ins, OpcodeStr, AttSrcAsm, IntelSrcAsm, - RHS, 0, 0, 0, X86selects_mask>; + RHS, 0, 0, 0, X86selects_mask, "", Suffix>; // Similar to AVX512_maskable but in this case one of the source operands // ($src1) is already tied to $dst so we just use that for the preserved @@ -144,7 +148,7 @@ multiclass AVX512_maskable_3src<bits<8> O, Format F, X86VectorVTInfo _, bit IsCommutable = 0, bit IsKCommutable = 0, SDPatternOperator Select = vselect_mask, - bit MaskOnly = 0> : + bit MaskOnly = 0, string Suffix = ""> : AVX512_maskable_common<O, F, _, Outs, !con((ins _.RC:$src1), NonTiedIns), !con((ins _.RC:$src1, _.KRCWM:$mask), NonTiedIns), @@ -152,7 +156,8 @@ multiclass AVX512_maskable_3src<bits<8> O, Format F, X86VectorVTInfo _, OpcodeStr, AttSrcAsm, IntelSrcAsm, !if(MaskOnly, (null_frag), RHS), (Select _.KRCWM:$mask, RHS, _.RC:$src1), - Select, "", IsCommutable, IsKCommutable>; + Select, "", IsCommutable, IsKCommutable, + IsCommutable, "", Suffix>; // Similar to AVX512_maskable_3src but in this case the input VT for the tied // operand differs from the output VT. This requires a bitconvert on @@ -178,10 +183,10 @@ multiclass AVX512_maskable_3src_scalar<bits<8> O, Format F, X86VectorVTInfo _, dag RHS, bit IsCommutable = 0, bit IsKCommutable = 0, - bit MaskOnly = 0> : + bit MaskOnly = 0, string Suffix = ""> : AVX512_maskable_3src<O, F, _, Outs, NonTiedIns, OpcodeStr, AttSrcAsm, IntelSrcAsm, RHS, IsCommutable, IsKCommutable, - X86selects_mask, MaskOnly>; + X86selects_mask, MaskOnly, Suffix>; multiclass AVX512_maskable_in_asm<bits<8> O, Format F, X86VectorVTInfo _, dag Outs, dag Ins, @@ -215,17 +220,18 @@ multiclass AVX512_maskable_custom_cmp<bits<8> O, Format F, string AttSrcAsm, string IntelSrcAsm, list<dag> Pattern, list<dag> MaskingPattern, - bit IsCommutable = 0> { + bit IsCommutable = 0, + string Suffix = ""> { let isCommutable = IsCommutable in { - def NAME: AVX512<O, F, Outs, Ins, + def Suffix: AVX512<O, F, Outs, Ins, OpcodeStr#"\t{"#AttSrcAsm#", $dst|"# "$dst, "#IntelSrcAsm#"}", Pattern>; - def NAME#k: AVX512<O, F, Outs, MaskingIns, - OpcodeStr#"\t{"#AttSrcAsm#", $dst {${mask}}|"# - "$dst {${mask}}, "#IntelSrcAsm#"}", - MaskingPattern>, EVEX_K; + def k#Suffix: AVX512<O, F, Outs, MaskingIns, + OpcodeStr#"\t{"#AttSrcAsm#", $dst {${mask}}|"# + "$dst {${mask}}, "#IntelSrcAsm#"}", + MaskingPattern>, EVEX_K; } } @@ -235,20 +241,22 @@ multiclass AVX512_maskable_common_cmp<bits<8> O, Format F, X86VectorVTInfo _, string OpcodeStr, string AttSrcAsm, string IntelSrcAsm, dag RHS, dag MaskingRHS, - bit IsCommutable = 0> : + bit IsCommutable = 0, + string Suffix = ""> : AVX512_maskable_custom_cmp<O, F, Outs, Ins, MaskingIns, OpcodeStr, AttSrcAsm, IntelSrcAsm, [(set _.KRC:$dst, RHS)], - [(set _.KRC:$dst, MaskingRHS)], IsCommutable>; + [(set _.KRC:$dst, MaskingRHS)], IsCommutable, Suffix>; multiclass AVX512_maskable_cmp<bits<8> O, Format F, X86VectorVTInfo _, dag Outs, dag Ins, string OpcodeStr, string AttSrcAsm, string IntelSrcAsm, - dag RHS, dag RHS_su, bit IsCommutable = 0> : + dag RHS, dag RHS_su, bit IsCommutable = 0, + string Suffix = ""> : AVX512_maskable_common_cmp<O, F, _, Outs, Ins, !con((ins _.KRCWM:$mask), Ins), OpcodeStr, AttSrcAsm, IntelSrcAsm, RHS, - (and _.KRCWM:$mask, RHS_su), IsCommutable>; + (and _.KRCWM:$mask, RHS_su), IsCommutable, Suffix>; // Used by conversion 
instructions. multiclass AVX512_maskable_cvt<bits<8> O, Format F, X86VectorVTInfo _, @@ -1937,37 +1945,37 @@ defm VPBLENDMW : blendmask_bw<0x66, "vpblendmw", SchedWriteVarBlend, multiclass avx512_cmp_scalar<X86VectorVTInfo _, SDNode OpNode, SDNode OpNodeSAE, PatFrag OpNode_su, PatFrag OpNodeSAE_su, X86FoldableSchedWrite sched> { - defm rri_Int : AVX512_maskable_cmp<0xC2, MRMSrcReg, _, - (outs _.KRC:$dst), - (ins _.RC:$src1, _.RC:$src2, u8imm:$cc), - "vcmp"#_.Suffix, - "$cc, $src2, $src1", "$src1, $src2, $cc", - (OpNode (_.VT _.RC:$src1), (_.VT _.RC:$src2), timm:$cc), - (OpNode_su (_.VT _.RC:$src1), (_.VT _.RC:$src2), timm:$cc)>, - EVEX, VVVV, VEX_LIG, Sched<[sched]>, SIMD_EXC; + defm rri : AVX512_maskable_cmp<0xC2, MRMSrcReg, _, + (outs _.KRC:$dst), + (ins _.RC:$src1, _.RC:$src2, u8imm:$cc), + "vcmp"#_.Suffix, + "$cc, $src2, $src1", "$src1, $src2, $cc", + (OpNode (_.VT _.RC:$src1), (_.VT _.RC:$src2), timm:$cc), + (OpNode_su (_.VT _.RC:$src1), (_.VT _.RC:$src2), timm:$cc), 0, "_Int">, + EVEX, VVVV, VEX_LIG, Sched<[sched]>, SIMD_EXC; let mayLoad = 1 in - defm rmi_Int : AVX512_maskable_cmp<0xC2, MRMSrcMem, _, - (outs _.KRC:$dst), - (ins _.RC:$src1, _.IntScalarMemOp:$src2, u8imm:$cc), - "vcmp"#_.Suffix, - "$cc, $src2, $src1", "$src1, $src2, $cc", - (OpNode (_.VT _.RC:$src1), (_.ScalarIntMemFrags addr:$src2), - timm:$cc), - (OpNode_su (_.VT _.RC:$src1), (_.ScalarIntMemFrags addr:$src2), - timm:$cc)>, EVEX, VVVV, VEX_LIG, EVEX_CD8<_.EltSize, CD8VT1>, - Sched<[sched.Folded, sched.ReadAfterFold]>, SIMD_EXC; + defm rmi : AVX512_maskable_cmp<0xC2, MRMSrcMem, _, + (outs _.KRC:$dst), + (ins _.RC:$src1, _.IntScalarMemOp:$src2, u8imm:$cc), + "vcmp"#_.Suffix, + "$cc, $src2, $src1", "$src1, $src2, $cc", + (OpNode (_.VT _.RC:$src1), (_.ScalarIntMemFrags addr:$src2), + timm:$cc), + (OpNode_su (_.VT _.RC:$src1), (_.ScalarIntMemFrags addr:$src2), + timm:$cc), 0, "_Int">, EVEX, VVVV, VEX_LIG, EVEX_CD8<_.EltSize, CD8VT1>, + Sched<[sched.Folded, sched.ReadAfterFold]>, SIMD_EXC; let Uses = [MXCSR] in - defm rrib_Int : AVX512_maskable_cmp<0xC2, MRMSrcReg, _, - (outs _.KRC:$dst), - (ins _.RC:$src1, _.RC:$src2, u8imm:$cc), - "vcmp"#_.Suffix, - "$cc, {sae}, $src2, $src1","$src1, $src2, {sae}, $cc", - (OpNodeSAE (_.VT _.RC:$src1), (_.VT _.RC:$src2), - timm:$cc), - (OpNodeSAE_su (_.VT _.RC:$src1), (_.VT _.RC:$src2), - timm:$cc)>, - EVEX, VVVV, VEX_LIG, EVEX_B, Sched<[sched]>; + defm rrib : AVX512_maskable_cmp<0xC2, MRMSrcReg, _, + (outs _.KRC:$dst), + (ins _.RC:$src1, _.RC:$src2, u8imm:$cc), + "vcmp"#_.Suffix, + "$cc, {sae}, $src2, $src1","$src1, $src2, {sae}, $cc", + (OpNodeSAE (_.VT _.RC:$src1), (_.VT _.RC:$src2), + timm:$cc), + (OpNodeSAE_su (_.VT _.RC:$src1), (_.VT _.RC:$src2), + timm:$cc), 0, "_Int">, + EVEX, VVVV, VEX_LIG, EVEX_B, Sched<[sched]>; let isCodeGenOnly = 1 in { let isCommutable = 1 in @@ -5354,17 +5362,17 @@ multiclass avx512_fp_scalar<bits<8> opc, string OpcodeStr,X86VectorVTInfo _, SDPatternOperator OpNode, SDNode VecNode, X86FoldableSchedWrite sched, bit IsCommutable> { let ExeDomain = _.ExeDomain, Uses = [MXCSR], mayRaiseFPException = 1 in { - defm rr_Int : AVX512_maskable_scalar<opc, MRMSrcReg, _, (outs _.RC:$dst), + defm rr : AVX512_maskable_scalar<opc, MRMSrcReg, _, (outs _.RC:$dst), (ins _.RC:$src1, _.RC:$src2), OpcodeStr, "$src2, $src1", "$src1, $src2", - (_.VT (VecNode _.RC:$src1, _.RC:$src2))>, + (_.VT (VecNode _.RC:$src1, _.RC:$src2)), "_Int">, Sched<[sched]>; - defm rm_Int : AVX512_maskable_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst), + defm rm : AVX512_maskable_scalar<opc, MRMSrcMem, _, 
(outs _.RC:$dst), (ins _.RC:$src1, _.IntScalarMemOp:$src2), OpcodeStr, "$src2, $src1", "$src1, $src2", (_.VT (VecNode _.RC:$src1, - (_.ScalarIntMemFrags addr:$src2)))>, + (_.ScalarIntMemFrags addr:$src2))), "_Int">, Sched<[sched.Folded, sched.ReadAfterFold]>; let isCodeGenOnly = 1, Predicates = [HasAVX512] in { def rr : I< opc, MRMSrcReg, (outs _.FRC:$dst), @@ -5387,28 +5395,28 @@ multiclass avx512_fp_scalar<bits<8> opc, string OpcodeStr,X86VectorVTInfo _, multiclass avx512_fp_scalar_round<bits<8> opc, string OpcodeStr,X86VectorVTInfo _, SDNode VecNode, X86FoldableSchedWrite sched> { let ExeDomain = _.ExeDomain, Uses = [MXCSR] in - defm rrb_Int : AVX512_maskable_scalar<opc, MRMSrcReg, _, (outs _.RC:$dst), + defm rrb : AVX512_maskable_scalar<opc, MRMSrcReg, _, (outs _.RC:$dst), (ins _.RC:$src1, _.RC:$src2, AVX512RC:$rc), OpcodeStr, "$rc, $src2, $src1", "$src1, $src2, $rc", (VecNode (_.VT _.RC:$src1), (_.VT _.RC:$src2), - (i32 timm:$rc))>, + (i32 timm:$rc)), "_Int">, EVEX_B, EVEX_RC, Sched<[sched]>; } multiclass avx512_fp_scalar_sae<bits<8> opc, string OpcodeStr,X86VectorVTInfo _, SDPatternOperator OpNode, SDNode VecNode, SDNode SaeNode, X86FoldableSchedWrite sched, bit IsCommutable> { let ExeDomain = _.ExeDomain in { - defm rr_Int : AVX512_maskable_scalar<opc, MRMSrcReg, _, (outs _.RC:$dst), + defm rr : AVX512_maskable_scalar<opc, MRMSrcReg, _, (outs _.RC:$dst), (ins _.RC:$src1, _.RC:$src2), OpcodeStr, "$src2, $src1", "$src1, $src2", - (_.VT (VecNode _.RC:$src1, _.RC:$src2))>, + (_.VT (VecNode _.RC:$src1, _.RC:$src2)), "_Int">, Sched<[sched]>, SIMD_EXC; - defm rm_Int : AVX512_maskable_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst), + defm rm : AVX512_maskable_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst), (ins _.RC:$src1, _.IntScalarMemOp:$src2), OpcodeStr, "$src2, $src1", "$src1, $src2", (_.VT (VecNode _.RC:$src1, - (_.ScalarIntMemFrags addr:$src2)))>, + (_.ScalarIntMemFrags addr:$src2))), "_Int">, Sched<[sched.Folded, sched.ReadAfterFold]>, SIMD_EXC; let isCodeGenOnly = 1, Predicates = [HasAVX512], @@ -5429,10 +5437,10 @@ multiclass avx512_fp_scalar_sae<bits<8> opc, string OpcodeStr,X86VectorVTInfo _, } let Uses = [MXCSR] in - defm rrb_Int : AVX512_maskable_scalar<opc, MRMSrcReg, _, (outs _.RC:$dst), + defm rrb : AVX512_maskable_scalar<opc, MRMSrcReg, _, (outs _.RC:$dst), (ins _.RC:$src1, _.RC:$src2), OpcodeStr, "{sae}, $src2, $src1", "$src1, $src2, {sae}", - (SaeNode (_.VT _.RC:$src1), (_.VT _.RC:$src2))>, + (SaeNode (_.VT _.RC:$src1), (_.VT _.RC:$src2)), "_Int">, EVEX_B, Sched<[sched]>; } } @@ -6835,22 +6843,22 @@ defm VFNMSUB132 : avx512_fma3p_132_f<0x9E, "vfnmsub132", X86any_Fnmsub, multiclass avx512_fma3s_common<bits<8> opc, string OpcodeStr, X86VectorVTInfo _, dag RHS_r, dag RHS_m, dag RHS_b, bit MaskOnlyReg> { let Constraints = "$src1 = $dst", hasSideEffects = 0 in { - defm r_Int: AVX512_maskable_3src_scalar<opc, MRMSrcReg, _, (outs _.RC:$dst), + defm r: AVX512_maskable_3src_scalar<opc, MRMSrcReg, _, (outs _.RC:$dst), (ins _.RC:$src2, _.RC:$src3), OpcodeStr, - "$src3, $src2", "$src2, $src3", (null_frag), 1, 1>, + "$src3, $src2", "$src2, $src3", (null_frag), 1, 1, 0, "_Int">, EVEX, VVVV, Sched<[SchedWriteFMA.Scl]>, SIMD_EXC; let mayLoad = 1 in - defm m_Int: AVX512_maskable_3src_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst), + defm m: AVX512_maskable_3src_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst), (ins _.RC:$src2, _.IntScalarMemOp:$src3), OpcodeStr, - "$src3, $src2", "$src2, $src3", (null_frag), 1, 1>, + "$src3, $src2", "$src2, $src3", (null_frag), 1, 1, 0, "_Int">, EVEX, VVVV, 
Sched<[SchedWriteFMA.Scl.Folded, SchedWriteFMA.Scl.ReadAfterFold, SchedWriteFMA.Scl.ReadAfterFold]>, SIMD_EXC; let Uses = [MXCSR] in - defm rb_Int: AVX512_maskable_3src_scalar<opc, MRMSrcReg, _, (outs _.RC:$dst), + defm rb: AVX512_maskable_3src_scalar<opc, MRMSrcReg, _, (outs _.RC:$dst), (ins _.RC:$src2, _.RC:$src3, AVX512RC:$rc), - OpcodeStr, "$rc, $src3, $src2", "$src2, $src3, $rc", (null_frag), 1, 1>, + OpcodeStr, "$rc, $src3, $src2", "$src2, $src3, $rc", (null_frag), 1, 1, 0, "_Int">, EVEX, VVVV, EVEX_B, EVEX_RC, Sched<[SchedWriteFMA.Scl]>; let isCodeGenOnly = 1, isCommutable = 1 in { @@ -6982,7 +6990,7 @@ multiclass avx512_scalar_fma_patterns<SDPatternOperator Op, SDNode MaskedOp, (_.EltVT (extractelt (_.VT VR128X:$src1), (iPTR 0))), _.FRC:$src3), (_.EltVT (extractelt (_.VT VR128X:$src1), (iPTR 0)))))))), - (!cast<I>(Prefix#"213"#Suffix#"Zr_Intk") + (!cast<I>(Prefix#"213"#Suffix#"Zrk_Int") VR128X:$src1, VK1WM:$mask, (_.VT (COPY_TO_REGCLASS _.FRC:$src2, VR128X)), (_.VT (COPY_TO_REGCLASS _.FRC:$src3, VR128X)))>; @@ -6993,7 +7001,7 @@ multiclass avx512_scalar_fma_patterns<SDPatternOperator Op, SDNode MaskedOp, (_.EltVT (extractelt (_.VT VR128X:$src1), (iPTR 0))), (_.ScalarLdFrag addr:$src3)), (_.EltVT (extractelt (_.VT VR128X:$src1), (iPTR 0)))))))), - (!cast<I>(Prefix#"213"#Suffix#"Zm_Intk") + (!cast<I>(Prefix#"213"#Suffix#"Zmk_Int") VR128X:$src1, VK1WM:$mask, (_.VT (COPY_TO_REGCLASS _.FRC:$src2, VR128X)), addr:$src3)>; @@ -7002,7 +7010,7 @@ multiclass avx512_scalar_fma_patterns<SDPatternOperator Op, SDNode MaskedOp, (MaskedOp (_.EltVT (extractelt (_.VT VR128X:$src1), (iPTR 0))), (_.ScalarLdFrag addr:$src3), _.FRC:$src2), (_.EltVT (extractelt (_.VT VR128X:$src1), (iPTR 0)))))))), - (!cast<I>(Prefix#"132"#Suffix#"Zm_Intk") + (!cast<I>(Prefix#"132"#Suffix#"Zmk_Int") VR128X:$src1, VK1WM:$mask, (_.VT (COPY_TO_REGCLASS _.FRC:$src2, VR128X)), addr:$src3)>; @@ -7011,7 +7019,7 @@ multiclass avx512_scalar_fma_patterns<SDPatternOperator Op, SDNode MaskedOp, (MaskedOp _.FRC:$src2, _.FRC:$src3, (_.EltVT (extractelt (_.VT VR128X:$src1), (iPTR 0)))), (_.EltVT (extractelt (_.VT VR128X:$src1), (iPTR 0)))))))), - (!cast<I>(Prefix#"231"#Suffix#"Zr_Intk") + (!cast<I>(Prefix#"231"#Suffix#"Zrk_Int") VR128X:$src1, VK1WM:$mask, (_.VT (COPY_TO_REGCLASS _.FRC:$src2, VR128X)), (_.VT (COPY_TO_REGCLASS _.FRC:$src3, VR128X)))>; @@ -7021,7 +7029,7 @@ multiclass avx512_scalar_fma_patterns<SDPatternOperator Op, SDNode MaskedOp, (MaskedOp _.FRC:$src2, (_.ScalarLdFrag addr:$src3), (_.EltVT (extractelt (_.VT VR128X:$src1), (iPTR 0)))), (_.EltVT (extractelt (_.VT VR128X:$src1), (iPTR 0)))))))), - (!cast<I>(Prefix#"231"#Suffix#"Zm_Intk") + (!cast<I>(Prefix#"231"#Suffix#"Zmk_Int") VR128X:$src1, VK1WM:$mask, (_.VT (COPY_TO_REGCLASS _.FRC:$src2, VR128X)), addr:$src3)>; @@ -7031,7 +7039,7 @@ multiclass avx512_scalar_fma_patterns<SDPatternOperator Op, SDNode MaskedOp, (_.EltVT (extractelt (_.VT VR128X:$src1), (iPTR 0))), _.FRC:$src3), (_.EltVT ZeroFP)))))), - (!cast<I>(Prefix#"213"#Suffix#"Zr_Intkz") + (!cast<I>(Prefix#"213"#Suffix#"Zrkz_Int") VR128X:$src1, VK1WM:$mask, (_.VT (COPY_TO_REGCLASS _.FRC:$src2, VR128X)), (_.VT (COPY_TO_REGCLASS _.FRC:$src3, VR128X)))>; @@ -7041,7 +7049,7 @@ multiclass avx512_scalar_fma_patterns<SDPatternOperator Op, SDNode MaskedOp, (MaskedOp _.FRC:$src2, _.FRC:$src3, (_.EltVT (extractelt (_.VT VR128X:$src1), (iPTR 0)))), (_.EltVT ZeroFP)))))), - (!cast<I>(Prefix#"231"#Suffix#"Zr_Intkz") + (!cast<I>(Prefix#"231"#Suffix#"Zrkz_Int") VR128X:$src1, VK1WM:$mask, (_.VT (COPY_TO_REGCLASS _.FRC:$src2, 
VR128X)), (_.VT (COPY_TO_REGCLASS _.FRC:$src3, VR128X)))>; @@ -7052,7 +7060,7 @@ multiclass avx512_scalar_fma_patterns<SDPatternOperator Op, SDNode MaskedOp, (_.EltVT (extractelt (_.VT VR128X:$src1), (iPTR 0))), (_.ScalarLdFrag addr:$src3)), (_.EltVT ZeroFP)))))), - (!cast<I>(Prefix#"213"#Suffix#"Zm_Intkz") + (!cast<I>(Prefix#"213"#Suffix#"Zmkz_Int") VR128X:$src1, VK1WM:$mask, (_.VT (COPY_TO_REGCLASS _.FRC:$src2, VR128X)), addr:$src3)>; @@ -7061,7 +7069,7 @@ multiclass avx512_scalar_fma_patterns<SDPatternOperator Op, SDNode MaskedOp, (MaskedOp (_.EltVT (extractelt (_.VT VR128X:$src1), (iPTR 0))), _.FRC:$src2, (_.ScalarLdFrag addr:$src3)), (_.EltVT ZeroFP)))))), - (!cast<I>(Prefix#"132"#Suffix#"Zm_Intkz") + (!cast<I>(Prefix#"132"#Suffix#"Zmkz_Int") VR128X:$src1, VK1WM:$mask, (_.VT (COPY_TO_REGCLASS _.FRC:$src2, VR128X)), addr:$src3)>; @@ -7070,7 +7078,7 @@ multiclass avx512_scalar_fma_patterns<SDPatternOperator Op, SDNode MaskedOp, (MaskedOp _.FRC:$src2, (_.ScalarLdFrag addr:$src3), (_.EltVT (extractelt (_.VT VR128X:$src1), (iPTR 0)))), (_.EltVT ZeroFP)))))), - (!cast<I>(Prefix#"231"#Suffix#"Zm_Intkz") + (!cast<I>(Prefix#"231"#Suffix#"Zmkz_Int") VR128X:$src1, VK1WM:$mask, (_.VT (COPY_TO_REGCLASS _.FRC:$src2, VR128X)), addr:$src3)>; @@ -7097,7 +7105,7 @@ multiclass avx512_scalar_fma_patterns<SDPatternOperator Op, SDNode MaskedOp, (_.EltVT (extractelt (_.VT VR128X:$src1), (iPTR 0))), _.FRC:$src3, (i32 timm:$rc)), (_.EltVT (extractelt (_.VT VR128X:$src1), (iPTR 0)))))))), - (!cast<I>(Prefix#"213"#Suffix#"Zrb_Intk") + (!cast<I>(Prefix#"213"#Suffix#"Zrbk_Int") VR128X:$src1, VK1WM:$mask, (_.VT (COPY_TO_REGCLASS _.FRC:$src2, VR128X)), (_.VT (COPY_TO_REGCLASS _.FRC:$src3, VR128X)), AVX512RC:$rc)>; @@ -7108,7 +7116,7 @@ multiclass avx512_scalar_fma_patterns<SDPatternOperator Op, SDNode MaskedOp, (_.EltVT (extractelt (_.VT VR128X:$src1), (iPTR 0))), (i32 timm:$rc)), (_.EltVT (extractelt (_.VT VR128X:$src1), (iPTR 0)))))))), - (!cast<I>(Prefix#"231"#Suffix#"Zrb_Intk") + (!cast<I>(Prefix#"231"#Suffix#"Zrbk_Int") VR128X:$src1, VK1WM:$mask, (_.VT (COPY_TO_REGCLASS _.FRC:$src2, VR128X)), (_.VT (COPY_TO_REGCLASS _.FRC:$src3, VR128X)), AVX512RC:$rc)>; @@ -7119,7 +7127,7 @@ multiclass avx512_scalar_fma_patterns<SDPatternOperator Op, SDNode MaskedOp, (_.EltVT (extractelt (_.VT VR128X:$src1), (iPTR 0))), _.FRC:$src3, (i32 timm:$rc)), (_.EltVT ZeroFP)))))), - (!cast<I>(Prefix#"213"#Suffix#"Zrb_Intkz") + (!cast<I>(Prefix#"213"#Suffix#"Zrbkz_Int") VR128X:$src1, VK1WM:$mask, (_.VT (COPY_TO_REGCLASS _.FRC:$src2, VR128X)), (_.VT (COPY_TO_REGCLASS _.FRC:$src3, VR128X)), AVX512RC:$rc)>; @@ -7130,7 +7138,7 @@ multiclass avx512_scalar_fma_patterns<SDPatternOperator Op, SDNode MaskedOp, (_.EltVT (extractelt (_.VT VR128X:$src1), (iPTR 0))), (i32 timm:$rc)), (_.EltVT ZeroFP)))))), - (!cast<I>(Prefix#"231"#Suffix#"Zrb_Intkz") + (!cast<I>(Prefix#"231"#Suffix#"Zrbkz_Int") VR128X:$src1, VK1WM:$mask, (_.VT (COPY_TO_REGCLASS _.FRC:$src2, VR128X)), (_.VT (COPY_TO_REGCLASS _.FRC:$src3, VR128X)), AVX512RC:$rc)>; @@ -7628,17 +7636,17 @@ let Uses = [MXCSR], mayRaiseFPException = 1 in multiclass avx512_cvt_fp_scalar<bits<8> opc, string OpcodeStr, X86VectorVTInfo _, X86VectorVTInfo _Src, SDNode OpNode, X86FoldableSchedWrite sched> { - defm rr_Int : AVX512_maskable_scalar<opc, MRMSrcReg, _, (outs _.RC:$dst), + defm rr : AVX512_maskable_scalar<opc, MRMSrcReg, _, (outs _.RC:$dst), (ins _.RC:$src1, _Src.RC:$src2), OpcodeStr, "$src2, $src1", "$src1, $src2", (_.VT (OpNode (_.VT _.RC:$src1), - (_Src.VT _Src.RC:$src2)))>, + (_Src.VT 
_Src.RC:$src2))), "_Int">, EVEX, VVVV, VEX_LIG, Sched<[sched]>; - defm rm_Int : AVX512_maskable_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst), + defm rm : AVX512_maskable_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst), (ins _.RC:$src1, _Src.IntScalarMemOp:$src2), OpcodeStr, "$src2, $src1", "$src1, $src2", (_.VT (OpNode (_.VT _.RC:$src1), - (_Src.ScalarIntMemFrags addr:$src2)))>, + (_Src.ScalarIntMemFrags addr:$src2))), "_Int">, EVEX, VVVV, VEX_LIG, Sched<[sched.Folded, sched.ReadAfterFold]>; @@ -7660,11 +7668,11 @@ multiclass avx512_cvt_fp_sae_scalar<bits<8> opc, string OpcodeStr, X86VectorVTIn X86VectorVTInfo _Src, SDNode OpNodeSAE, X86FoldableSchedWrite sched> { let Uses = [MXCSR] in - defm rrb_Int : AVX512_maskable_scalar<opc, MRMSrcReg, _, (outs _.RC:$dst), + defm rrb : AVX512_maskable_scalar<opc, MRMSrcReg, _, (outs _.RC:$dst), (ins _.RC:$src1, _Src.RC:$src2), OpcodeStr, "{sae}, $src2, $src1", "$src1, $src2, {sae}", (_.VT (OpNodeSAE (_.VT _.RC:$src1), - (_Src.VT _Src.RC:$src2)))>, + (_Src.VT _Src.RC:$src2))), "_Int">, EVEX, VVVV, VEX_LIG, EVEX_B, Sched<[sched]>; } @@ -7673,11 +7681,11 @@ multiclass avx512_cvt_fp_rc_scalar<bits<8> opc, string OpcodeStr, X86VectorVTInf X86VectorVTInfo _Src, SDNode OpNodeRnd, X86FoldableSchedWrite sched> { let Uses = [MXCSR] in - defm rrb_Int : AVX512_maskable_scalar<opc, MRMSrcReg, _, (outs _.RC:$dst), + defm rrb : AVX512_maskable_scalar<opc, MRMSrcReg, _, (outs _.RC:$dst), (ins _.RC:$src1, _Src.RC:$src2, AVX512RC:$rc), OpcodeStr, "$rc, $src2, $src1", "$src1, $src2, $rc", (_.VT (OpNodeRnd (_.VT _.RC:$src1), - (_Src.VT _Src.RC:$src2), (i32 timm:$rc)))>, + (_Src.VT _Src.RC:$src2), (i32 timm:$rc))), "_Int">, EVEX, VVVV, VEX_LIG, Sched<[sched]>, EVEX_B, EVEX_RC; } @@ -9531,25 +9539,25 @@ multiclass avx512_sqrt_packed_all_round<bits<8> opc, string OpcodeStr, multiclass avx512_sqrt_scalar<bits<8> opc, string OpcodeStr, X86FoldableSchedWrite sched, X86VectorVTInfo _, string Name, Predicate prd = HasAVX512> { let ExeDomain = _.ExeDomain, Predicates = [prd] in { - defm r_Int : AVX512_maskable_scalar<opc, MRMSrcReg, _, (outs _.RC:$dst), + defm r : AVX512_maskable_scalar<opc, MRMSrcReg, _, (outs _.RC:$dst), (ins _.RC:$src1, _.RC:$src2), OpcodeStr, "$src2, $src1", "$src1, $src2", (X86fsqrts (_.VT _.RC:$src1), - (_.VT _.RC:$src2))>, + (_.VT _.RC:$src2)), "_Int">, Sched<[sched]>, SIMD_EXC; - defm m_Int : AVX512_maskable_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst), + defm m : AVX512_maskable_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst), (ins _.RC:$src1, _.IntScalarMemOp:$src2), OpcodeStr, "$src2, $src1", "$src1, $src2", (X86fsqrts (_.VT _.RC:$src1), - (_.ScalarIntMemFrags addr:$src2))>, + (_.ScalarIntMemFrags addr:$src2)), "_Int">, Sched<[sched.Folded, sched.ReadAfterFold]>, SIMD_EXC; let Uses = [MXCSR] in - defm rb_Int : AVX512_maskable_scalar<opc, MRMSrcReg, _, (outs _.RC:$dst), + defm rb : AVX512_maskable_scalar<opc, MRMSrcReg, _, (outs _.RC:$dst), (ins _.RC:$src1, _.RC:$src2, AVX512RC:$rc), OpcodeStr, "$rc, $src2, $src1", "$src1, $src2, $rc", (X86fsqrtRnds (_.VT _.RC:$src1), (_.VT _.RC:$src2), - (i32 timm:$rc))>, + (i32 timm:$rc)), "_Int">, EVEX_B, EVEX_RC, Sched<[sched]>; let isCodeGenOnly = 1, hasSideEffects = 0 in { @@ -9596,27 +9604,27 @@ defm VSQRT : avx512_sqrt_scalar_all<0x51, "vsqrt", SchedWriteFSqrtSizes>, VEX_LI multiclass avx512_rndscale_scalar<bits<8> opc, string OpcodeStr, X86FoldableSchedWrite sched, X86VectorVTInfo _> { let ExeDomain = _.ExeDomain in { - defm rri_Int : AVX512_maskable_scalar<opc, MRMSrcReg, _, (outs _.RC:$dst), + defm rri : 
AVX512_maskable_scalar<opc, MRMSrcReg, _, (outs _.RC:$dst), (ins _.RC:$src1, _.RC:$src2, i32u8imm:$src3), OpcodeStr, "$src3, $src2, $src1", "$src1, $src2, $src3", (_.VT (X86RndScales (_.VT _.RC:$src1), (_.VT _.RC:$src2), - (i32 timm:$src3)))>, + (i32 timm:$src3))), "_Int">, Sched<[sched]>, SIMD_EXC; let Uses = [MXCSR] in - defm rrib_Int : AVX512_maskable_scalar<opc, MRMSrcReg, _, (outs _.RC:$dst), + defm rrib : AVX512_maskable_scalar<opc, MRMSrcReg, _, (outs _.RC:$dst), (ins _.RC:$src1, _.RC:$src2, i32u8imm:$src3), OpcodeStr, "$src3, {sae}, $src2, $src1", "$src1, $src2, {sae}, $src3", (_.VT (X86RndScalesSAE (_.VT _.RC:$src1), (_.VT _.RC:$src2), - (i32 timm:$src3)))>, EVEX_B, + (i32 timm:$src3))), "_Int">, EVEX_B, Sched<[sched]>; - defm rmi_Int : AVX512_maskable_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst), + defm rmi : AVX512_maskable_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst), (ins _.RC:$src1, _.IntScalarMemOp:$src2, i32u8imm:$src3), OpcodeStr, "$src3, $src2, $src1", "$src1, $src2, $src3", (_.VT (X86RndScales _.RC:$src1, - (_.ScalarIntMemFrags addr:$src2), (i32 timm:$src3)))>, + (_.ScalarIntMemFrags addr:$src2), (i32 timm:$src3))), "_Int">, Sched<[sched.Folded, sched.ReadAfterFold]>, SIMD_EXC; let isCodeGenOnly = 1, hasSideEffects = 0, Predicates = [HasAVX512] in { @@ -9669,13 +9677,13 @@ multiclass avx512_masked_scalar<SDNode OpNode, string OpcPrefix, SDNode Move, def : Pat<(Move _.VT:$src1, (scalar_to_vector (X86selects_mask Mask, (OpNode (extractelt _.VT:$src2, (iPTR 0))), (extractelt _.VT:$dst, (iPTR 0))))), - (!cast<Instruction>("V"#OpcPrefix#r_Intk) + (!cast<Instruction>("V"#OpcPrefix#rk_Int) _.VT:$dst, OutMask, _.VT:$src2, _.VT:$src1)>; def : Pat<(Move _.VT:$src1, (scalar_to_vector (X86selects_mask Mask, (OpNode (extractelt _.VT:$src2, (iPTR 0))), ZeroFP))), - (!cast<Instruction>("V"#OpcPrefix#r_Intkz) + (!cast<Instruction>("V"#OpcPrefix#rkz_Int) OutMask, _.VT:$src2, _.VT:$src1)>; } } @@ -12174,7 +12182,7 @@ multiclass AVX512_scalar_math_fp_patterns<SDPatternOperator Op, SDNode MaskedOp, (extractelt (_.VT VR128X:$src1), (iPTR 0))), _.FRC:$src2), _.FRC:$src0))), - (!cast<Instruction>("V"#OpcPrefix#"Zrr_Intk") + (!cast<Instruction>("V"#OpcPrefix#"Zrrk_Int") (_.VT (COPY_TO_REGCLASS _.FRC:$src0, VR128X)), VK1WM:$mask, _.VT:$src1, (_.VT (COPY_TO_REGCLASS _.FRC:$src2, VR128X)))>; @@ -12185,7 +12193,7 @@ multiclass AVX512_scalar_math_fp_patterns<SDPatternOperator Op, SDNode MaskedOp, (extractelt (_.VT VR128X:$src1), (iPTR 0))), (_.ScalarLdFrag addr:$src2)), _.FRC:$src0))), - (!cast<Instruction>("V"#OpcPrefix#"Zrm_Intk") + (!cast<Instruction>("V"#OpcPrefix#"Zrmk_Int") (_.VT (COPY_TO_REGCLASS _.FRC:$src0, VR128X)), VK1WM:$mask, _.VT:$src1, addr:$src2)>; @@ -12196,7 +12204,7 @@ multiclass AVX512_scalar_math_fp_patterns<SDPatternOperator Op, SDNode MaskedOp, (MaskedOp (_.EltVT (extractelt (_.VT VR128X:$src1), (iPTR 0))), _.FRC:$src2), (_.EltVT ZeroFP)))), - (!cast<I>("V"#OpcPrefix#"Zrr_Intkz") + (!cast<I>("V"#OpcPrefix#"Zrrkz_Int") VK1WM:$mask, _.VT:$src1, (_.VT (COPY_TO_REGCLASS _.FRC:$src2, VR128X)))>; def : Pat<(MoveNode (_.VT VR128X:$src1), @@ -12205,7 +12213,7 @@ multiclass AVX512_scalar_math_fp_patterns<SDPatternOperator Op, SDNode MaskedOp, (MaskedOp (_.EltVT (extractelt (_.VT VR128X:$src1), (iPTR 0))), (_.ScalarLdFrag addr:$src2)), (_.EltVT ZeroFP)))), - (!cast<I>("V"#OpcPrefix#"Zrm_Intkz") VK1WM:$mask, _.VT:$src1, addr:$src2)>; + (!cast<I>("V"#OpcPrefix#"Zrmkz_Int") VK1WM:$mask, _.VT:$src1, addr:$src2)>; } } diff --git a/llvm/lib/Target/X86/X86InstrFMA3Info.cpp 
b/llvm/lib/Target/X86/X86InstrFMA3Info.cpp index 090ec68..0da4857 100644 --- a/llvm/lib/Target/X86/X86InstrFMA3Info.cpp +++ b/llvm/lib/Target/X86/X86InstrFMA3Info.cpp @@ -27,6 +27,11 @@ using namespace llvm; FMA3GROUP(Name, Suf##k, Attrs | X86InstrFMA3Group::KMergeMasked) \ FMA3GROUP(Name, Suf##kz, Attrs | X86InstrFMA3Group::KZeroMasked) +#define FMA3GROUP_MASKED_INT(Name, Suf, Attrs) \ + FMA3GROUP(Name, Suf##_Int, Attrs) \ + FMA3GROUP(Name, Suf##k_Int, Attrs | X86InstrFMA3Group::KMergeMasked) \ + FMA3GROUP(Name, Suf##kz_Int, Attrs | X86InstrFMA3Group::KZeroMasked) + #define FMA3GROUP_PACKED_WIDTHS_Z(Name, Suf, Attrs) \ FMA3GROUP_MASKED(Name, Suf##Z128m, Attrs) \ FMA3GROUP_MASKED(Name, Suf##Z128r, Attrs) \ @@ -52,9 +57,9 @@ using namespace llvm; #define FMA3GROUP_SCALAR_WIDTHS_Z(Name, Suf, Attrs) \ FMA3GROUP(Name, Suf##Zm, Attrs) \ - FMA3GROUP_MASKED(Name, Suf##Zm_Int, Attrs | X86InstrFMA3Group::Intrinsic) \ + FMA3GROUP_MASKED_INT(Name, Suf##Zm, Attrs | X86InstrFMA3Group::Intrinsic) \ FMA3GROUP(Name, Suf##Zr, Attrs) \ - FMA3GROUP_MASKED(Name, Suf##Zr_Int, Attrs | X86InstrFMA3Group::Intrinsic) \ + FMA3GROUP_MASKED_INT(Name, Suf##Zr, Attrs | X86InstrFMA3Group::Intrinsic) \ #define FMA3GROUP_SCALAR_WIDTHS_ALL(Name, Suf, Attrs) \ FMA3GROUP_SCALAR_WIDTHS_Z(Name, Suf, Attrs) \ @@ -108,11 +113,11 @@ static const X86InstrFMA3Group Groups[] = { #define FMA3GROUP_SCALAR_AVX512_ROUND(Name, Suf, Attrs) \ FMA3GROUP(Name, SDZ##Suf, Attrs) \ - FMA3GROUP_MASKED(Name, SDZ##Suf##_Int, Attrs) \ + FMA3GROUP_MASKED_INT(Name, SDZ##Suf, Attrs) \ FMA3GROUP(Name, SHZ##Suf, Attrs) \ - FMA3GROUP_MASKED(Name, SHZ##Suf##_Int, Attrs) \ + FMA3GROUP_MASKED_INT(Name, SHZ##Suf, Attrs) \ FMA3GROUP(Name, SSZ##Suf, Attrs) \ - FMA3GROUP_MASKED(Name, SSZ##Suf##_Int, Attrs) + FMA3GROUP_MASKED_INT(Name, SSZ##Suf, Attrs) static const X86InstrFMA3Group BroadcastGroups[] = { FMA3GROUP_PACKED_AVX512_ALL(VFMADD, mb, 0) diff --git a/llvm/lib/Target/X86/X86InstrFragmentsSIMD.td b/llvm/lib/Target/X86/X86InstrFragmentsSIMD.td index f6231b7..af0267a 100644 --- a/llvm/lib/Target/X86/X86InstrFragmentsSIMD.td +++ b/llvm/lib/Target/X86/X86InstrFragmentsSIMD.td @@ -390,6 +390,13 @@ def SDTFmaRound : SDTypeProfile<1, 4, [SDTCisSameAs<0,1>, SDTCisSameAs<1,2>, SDTCisSameAs<1,3>, SDTCisFP<0>, SDTCisVT<4, i32>]>; +def SDTFPToxIntSatOp + : SDTypeProfile<1, + 1, [SDTCisVec<0>, SDTCisVec<1>, SDTCisInt<0>, SDTCisFP<1>]>; + +def X86fp2sisat : SDNode<"X86ISD::FP_TO_SINT_SAT", SDTFPToxIntSatOp>; +def X86fp2uisat : SDNode<"X86ISD::FP_TO_UINT_SAT", SDTFPToxIntSatOp>; + def X86PAlignr : SDNode<"X86ISD::PALIGNR", SDTypeProfile<1, 3, [SDTCVecEltisVT<0, i8>, SDTCisSameAs<0,1>, diff --git a/llvm/lib/Target/X86/X86InstrInfo.cpp b/llvm/lib/Target/X86/X86InstrInfo.cpp index 5a6ea11..30a5161 100644 --- a/llvm/lib/Target/X86/X86InstrInfo.cpp +++ b/llvm/lib/Target/X86/X86InstrInfo.cpp @@ -7646,8 +7646,8 @@ static bool isNonFoldablePartialRegisterLoad(const MachineInstr &LoadMI, case X86::CVTSS2SDrr_Int: case X86::VCVTSS2SDrr_Int: case X86::VCVTSS2SDZrr_Int: - case X86::VCVTSS2SDZrr_Intk: - case X86::VCVTSS2SDZrr_Intkz: + case X86::VCVTSS2SDZrrk_Int: + case X86::VCVTSS2SDZrrkz_Int: case X86::CVTSS2SIrr_Int: case X86::CVTSS2SI64rr_Int: case X86::VCVTSS2SIrr_Int: @@ -7700,21 +7700,21 @@ static bool isNonFoldablePartialRegisterLoad(const MachineInstr &LoadMI, case X86::SUBSSrr_Int: case X86::VSUBSSrr_Int: case X86::VSUBSSZrr_Int: - case X86::VADDSSZrr_Intk: - case X86::VADDSSZrr_Intkz: - case X86::VCMPSSZrri_Intk: - case X86::VDIVSSZrr_Intk: - case X86::VDIVSSZrr_Intkz: - 
case X86::VMAXSSZrr_Intk: - case X86::VMAXSSZrr_Intkz: - case X86::VMINSSZrr_Intk: - case X86::VMINSSZrr_Intkz: - case X86::VMULSSZrr_Intk: - case X86::VMULSSZrr_Intkz: - case X86::VSQRTSSZr_Intk: - case X86::VSQRTSSZr_Intkz: - case X86::VSUBSSZrr_Intk: - case X86::VSUBSSZrr_Intkz: + case X86::VADDSSZrrk_Int: + case X86::VADDSSZrrkz_Int: + case X86::VCMPSSZrrik_Int: + case X86::VDIVSSZrrk_Int: + case X86::VDIVSSZrrkz_Int: + case X86::VMAXSSZrrk_Int: + case X86::VMAXSSZrrkz_Int: + case X86::VMINSSZrrk_Int: + case X86::VMINSSZrrkz_Int: + case X86::VMULSSZrrk_Int: + case X86::VMULSSZrrkz_Int: + case X86::VSQRTSSZrk_Int: + case X86::VSQRTSSZrkz_Int: + case X86::VSUBSSZrrk_Int: + case X86::VSUBSSZrrkz_Int: case X86::VFMADDSS4rr_Int: case X86::VFNMADDSS4rr_Int: case X86::VFMSUBSS4rr_Int: @@ -7743,30 +7743,30 @@ static bool isNonFoldablePartialRegisterLoad(const MachineInstr &LoadMI, case X86::VFNMSUB213SSZr_Int: case X86::VFMSUB231SSZr_Int: case X86::VFNMSUB231SSZr_Int: - case X86::VFMADD132SSZr_Intk: - case X86::VFNMADD132SSZr_Intk: - case X86::VFMADD213SSZr_Intk: - case X86::VFNMADD213SSZr_Intk: - case X86::VFMADD231SSZr_Intk: - case X86::VFNMADD231SSZr_Intk: - case X86::VFMSUB132SSZr_Intk: - case X86::VFNMSUB132SSZr_Intk: - case X86::VFMSUB213SSZr_Intk: - case X86::VFNMSUB213SSZr_Intk: - case X86::VFMSUB231SSZr_Intk: - case X86::VFNMSUB231SSZr_Intk: - case X86::VFMADD132SSZr_Intkz: - case X86::VFNMADD132SSZr_Intkz: - case X86::VFMADD213SSZr_Intkz: - case X86::VFNMADD213SSZr_Intkz: - case X86::VFMADD231SSZr_Intkz: - case X86::VFNMADD231SSZr_Intkz: - case X86::VFMSUB132SSZr_Intkz: - case X86::VFNMSUB132SSZr_Intkz: - case X86::VFMSUB213SSZr_Intkz: - case X86::VFNMSUB213SSZr_Intkz: - case X86::VFMSUB231SSZr_Intkz: - case X86::VFNMSUB231SSZr_Intkz: + case X86::VFMADD132SSZrk_Int: + case X86::VFNMADD132SSZrk_Int: + case X86::VFMADD213SSZrk_Int: + case X86::VFNMADD213SSZrk_Int: + case X86::VFMADD231SSZrk_Int: + case X86::VFNMADD231SSZrk_Int: + case X86::VFMSUB132SSZrk_Int: + case X86::VFNMSUB132SSZrk_Int: + case X86::VFMSUB213SSZrk_Int: + case X86::VFNMSUB213SSZrk_Int: + case X86::VFMSUB231SSZrk_Int: + case X86::VFNMSUB231SSZrk_Int: + case X86::VFMADD132SSZrkz_Int: + case X86::VFNMADD132SSZrkz_Int: + case X86::VFMADD213SSZrkz_Int: + case X86::VFNMADD213SSZrkz_Int: + case X86::VFMADD231SSZrkz_Int: + case X86::VFNMADD231SSZrkz_Int: + case X86::VFMSUB132SSZrkz_Int: + case X86::VFNMSUB132SSZrkz_Int: + case X86::VFMSUB213SSZrkz_Int: + case X86::VFNMSUB213SSZrkz_Int: + case X86::VFMSUB231SSZrkz_Int: + case X86::VFNMSUB231SSZrkz_Int: case X86::VFIXUPIMMSSZrri: case X86::VFIXUPIMMSSZrrik: case X86::VFIXUPIMMSSZrrikz: @@ -7791,8 +7791,8 @@ static bool isNonFoldablePartialRegisterLoad(const MachineInstr &LoadMI, case X86::VREDUCESSZrrik: case X86::VREDUCESSZrrikz: case X86::VRNDSCALESSZrri_Int: - case X86::VRNDSCALESSZrri_Intk: - case X86::VRNDSCALESSZrri_Intkz: + case X86::VRNDSCALESSZrrik_Int: + case X86::VRNDSCALESSZrrikz_Int: case X86::VRSQRT14SSZrr: case X86::VRSQRT14SSZrrk: case X86::VRSQRT14SSZrrkz: @@ -7819,8 +7819,8 @@ static bool isNonFoldablePartialRegisterLoad(const MachineInstr &LoadMI, case X86::CVTSD2SSrr_Int: case X86::VCVTSD2SSrr_Int: case X86::VCVTSD2SSZrr_Int: - case X86::VCVTSD2SSZrr_Intk: - case X86::VCVTSD2SSZrr_Intkz: + case X86::VCVTSD2SSZrrk_Int: + case X86::VCVTSD2SSZrrkz_Int: case X86::CVTSD2SIrr_Int: case X86::CVTSD2SI64rr_Int: case X86::VCVTSD2SIrr_Int: @@ -7869,21 +7869,21 @@ static bool isNonFoldablePartialRegisterLoad(const MachineInstr &LoadMI, case X86::SUBSDrr_Int: case 
X86::VSUBSDrr_Int: case X86::VSUBSDZrr_Int: - case X86::VADDSDZrr_Intk: - case X86::VADDSDZrr_Intkz: - case X86::VCMPSDZrri_Intk: - case X86::VDIVSDZrr_Intk: - case X86::VDIVSDZrr_Intkz: - case X86::VMAXSDZrr_Intk: - case X86::VMAXSDZrr_Intkz: - case X86::VMINSDZrr_Intk: - case X86::VMINSDZrr_Intkz: - case X86::VMULSDZrr_Intk: - case X86::VMULSDZrr_Intkz: - case X86::VSQRTSDZr_Intk: - case X86::VSQRTSDZr_Intkz: - case X86::VSUBSDZrr_Intk: - case X86::VSUBSDZrr_Intkz: + case X86::VADDSDZrrk_Int: + case X86::VADDSDZrrkz_Int: + case X86::VCMPSDZrrik_Int: + case X86::VDIVSDZrrk_Int: + case X86::VDIVSDZrrkz_Int: + case X86::VMAXSDZrrk_Int: + case X86::VMAXSDZrrkz_Int: + case X86::VMINSDZrrk_Int: + case X86::VMINSDZrrkz_Int: + case X86::VMULSDZrrk_Int: + case X86::VMULSDZrrkz_Int: + case X86::VSQRTSDZrk_Int: + case X86::VSQRTSDZrkz_Int: + case X86::VSUBSDZrrk_Int: + case X86::VSUBSDZrrkz_Int: case X86::VFMADDSD4rr_Int: case X86::VFNMADDSD4rr_Int: case X86::VFMSUBSD4rr_Int: @@ -7912,30 +7912,30 @@ static bool isNonFoldablePartialRegisterLoad(const MachineInstr &LoadMI, case X86::VFNMSUB213SDZr_Int: case X86::VFMSUB231SDZr_Int: case X86::VFNMSUB231SDZr_Int: - case X86::VFMADD132SDZr_Intk: - case X86::VFNMADD132SDZr_Intk: - case X86::VFMADD213SDZr_Intk: - case X86::VFNMADD213SDZr_Intk: - case X86::VFMADD231SDZr_Intk: - case X86::VFNMADD231SDZr_Intk: - case X86::VFMSUB132SDZr_Intk: - case X86::VFNMSUB132SDZr_Intk: - case X86::VFMSUB213SDZr_Intk: - case X86::VFNMSUB213SDZr_Intk: - case X86::VFMSUB231SDZr_Intk: - case X86::VFNMSUB231SDZr_Intk: - case X86::VFMADD132SDZr_Intkz: - case X86::VFNMADD132SDZr_Intkz: - case X86::VFMADD213SDZr_Intkz: - case X86::VFNMADD213SDZr_Intkz: - case X86::VFMADD231SDZr_Intkz: - case X86::VFNMADD231SDZr_Intkz: - case X86::VFMSUB132SDZr_Intkz: - case X86::VFNMSUB132SDZr_Intkz: - case X86::VFMSUB213SDZr_Intkz: - case X86::VFNMSUB213SDZr_Intkz: - case X86::VFMSUB231SDZr_Intkz: - case X86::VFNMSUB231SDZr_Intkz: + case X86::VFMADD132SDZrk_Int: + case X86::VFNMADD132SDZrk_Int: + case X86::VFMADD213SDZrk_Int: + case X86::VFNMADD213SDZrk_Int: + case X86::VFMADD231SDZrk_Int: + case X86::VFNMADD231SDZrk_Int: + case X86::VFMSUB132SDZrk_Int: + case X86::VFNMSUB132SDZrk_Int: + case X86::VFMSUB213SDZrk_Int: + case X86::VFNMSUB213SDZrk_Int: + case X86::VFMSUB231SDZrk_Int: + case X86::VFNMSUB231SDZrk_Int: + case X86::VFMADD132SDZrkz_Int: + case X86::VFNMADD132SDZrkz_Int: + case X86::VFMADD213SDZrkz_Int: + case X86::VFNMADD213SDZrkz_Int: + case X86::VFMADD231SDZrkz_Int: + case X86::VFNMADD231SDZrkz_Int: + case X86::VFMSUB132SDZrkz_Int: + case X86::VFNMSUB132SDZrkz_Int: + case X86::VFMSUB213SDZrkz_Int: + case X86::VFNMSUB213SDZrkz_Int: + case X86::VFMSUB231SDZrkz_Int: + case X86::VFNMSUB231SDZrkz_Int: case X86::VFIXUPIMMSDZrri: case X86::VFIXUPIMMSDZrrik: case X86::VFIXUPIMMSDZrrikz: @@ -7960,8 +7960,8 @@ static bool isNonFoldablePartialRegisterLoad(const MachineInstr &LoadMI, case X86::VREDUCESDZrrik: case X86::VREDUCESDZrrikz: case X86::VRNDSCALESDZrri_Int: - case X86::VRNDSCALESDZrri_Intk: - case X86::VRNDSCALESDZrri_Intkz: + case X86::VRNDSCALESDZrrik_Int: + case X86::VRNDSCALESDZrrikz_Int: case X86::VRSQRT14SDZrr: case X86::VRSQRT14SDZrrk: case X86::VRSQRT14SDZrrkz: @@ -7989,19 +7989,19 @@ static bool isNonFoldablePartialRegisterLoad(const MachineInstr &LoadMI, case X86::VMINSHZrr_Int: case X86::VMULSHZrr_Int: case X86::VSUBSHZrr_Int: - case X86::VADDSHZrr_Intk: - case X86::VADDSHZrr_Intkz: - case X86::VCMPSHZrri_Intk: - case X86::VDIVSHZrr_Intk: - case X86::VDIVSHZrr_Intkz: - case 
X86::VMAXSHZrr_Intk: - case X86::VMAXSHZrr_Intkz: - case X86::VMINSHZrr_Intk: - case X86::VMINSHZrr_Intkz: - case X86::VMULSHZrr_Intk: - case X86::VMULSHZrr_Intkz: - case X86::VSUBSHZrr_Intk: - case X86::VSUBSHZrr_Intkz: + case X86::VADDSHZrrk_Int: + case X86::VADDSHZrrkz_Int: + case X86::VCMPSHZrrik_Int: + case X86::VDIVSHZrrk_Int: + case X86::VDIVSHZrrkz_Int: + case X86::VMAXSHZrrk_Int: + case X86::VMAXSHZrrkz_Int: + case X86::VMINSHZrrk_Int: + case X86::VMINSHZrrkz_Int: + case X86::VMULSHZrrk_Int: + case X86::VMULSHZrrkz_Int: + case X86::VSUBSHZrrk_Int: + case X86::VSUBSHZrrkz_Int: case X86::VFMADD132SHZr_Int: case X86::VFNMADD132SHZr_Int: case X86::VFMADD213SHZr_Int: @@ -8014,30 +8014,30 @@ static bool isNonFoldablePartialRegisterLoad(const MachineInstr &LoadMI, case X86::VFNMSUB213SHZr_Int: case X86::VFMSUB231SHZr_Int: case X86::VFNMSUB231SHZr_Int: - case X86::VFMADD132SHZr_Intk: - case X86::VFNMADD132SHZr_Intk: - case X86::VFMADD213SHZr_Intk: - case X86::VFNMADD213SHZr_Intk: - case X86::VFMADD231SHZr_Intk: - case X86::VFNMADD231SHZr_Intk: - case X86::VFMSUB132SHZr_Intk: - case X86::VFNMSUB132SHZr_Intk: - case X86::VFMSUB213SHZr_Intk: - case X86::VFNMSUB213SHZr_Intk: - case X86::VFMSUB231SHZr_Intk: - case X86::VFNMSUB231SHZr_Intk: - case X86::VFMADD132SHZr_Intkz: - case X86::VFNMADD132SHZr_Intkz: - case X86::VFMADD213SHZr_Intkz: - case X86::VFNMADD213SHZr_Intkz: - case X86::VFMADD231SHZr_Intkz: - case X86::VFNMADD231SHZr_Intkz: - case X86::VFMSUB132SHZr_Intkz: - case X86::VFNMSUB132SHZr_Intkz: - case X86::VFMSUB213SHZr_Intkz: - case X86::VFNMSUB213SHZr_Intkz: - case X86::VFMSUB231SHZr_Intkz: - case X86::VFNMSUB231SHZr_Intkz: + case X86::VFMADD132SHZrk_Int: + case X86::VFNMADD132SHZrk_Int: + case X86::VFMADD213SHZrk_Int: + case X86::VFNMADD213SHZrk_Int: + case X86::VFMADD231SHZrk_Int: + case X86::VFNMADD231SHZrk_Int: + case X86::VFMSUB132SHZrk_Int: + case X86::VFNMSUB132SHZrk_Int: + case X86::VFMSUB213SHZrk_Int: + case X86::VFNMSUB213SHZrk_Int: + case X86::VFMSUB231SHZrk_Int: + case X86::VFNMSUB231SHZrk_Int: + case X86::VFMADD132SHZrkz_Int: + case X86::VFNMADD132SHZrkz_Int: + case X86::VFMADD213SHZrkz_Int: + case X86::VFNMADD213SHZrkz_Int: + case X86::VFMADD231SHZrkz_Int: + case X86::VFNMADD231SHZrkz_Int: + case X86::VFMSUB132SHZrkz_Int: + case X86::VFNMSUB132SHZrkz_Int: + case X86::VFMSUB213SHZrkz_Int: + case X86::VFNMSUB213SHZrkz_Int: + case X86::VFMSUB231SHZrkz_Int: + case X86::VFNMSUB231SHZrkz_Int: return false; default: return true; @@ -9489,25 +9489,25 @@ bool X86InstrInfo::isHighLatencyDef(int opc) const { case X86::VDIVSDZrm: case X86::VDIVSDZrr: case X86::VDIVSDZrm_Int: - case X86::VDIVSDZrm_Intk: - case X86::VDIVSDZrm_Intkz: + case X86::VDIVSDZrmk_Int: + case X86::VDIVSDZrmkz_Int: case X86::VDIVSDZrr_Int: - case X86::VDIVSDZrr_Intk: - case X86::VDIVSDZrr_Intkz: + case X86::VDIVSDZrrk_Int: + case X86::VDIVSDZrrkz_Int: case X86::VDIVSDZrrb_Int: - case X86::VDIVSDZrrb_Intk: - case X86::VDIVSDZrrb_Intkz: + case X86::VDIVSDZrrbk_Int: + case X86::VDIVSDZrrbkz_Int: case X86::VDIVSSZrm: case X86::VDIVSSZrr: case X86::VDIVSSZrm_Int: - case X86::VDIVSSZrm_Intk: - case X86::VDIVSSZrm_Intkz: + case X86::VDIVSSZrmk_Int: + case X86::VDIVSSZrmkz_Int: case X86::VDIVSSZrr_Int: - case X86::VDIVSSZrr_Intk: - case X86::VDIVSSZrr_Intkz: + case X86::VDIVSSZrrk_Int: + case X86::VDIVSSZrrkz_Int: case X86::VDIVSSZrrb_Int: - case X86::VDIVSSZrrb_Intk: - case X86::VDIVSSZrrb_Intkz: + case X86::VDIVSSZrrbk_Int: + case X86::VDIVSSZrrbkz_Int: case X86::VSQRTPDZ128m: case X86::VSQRTPDZ128mb: case 
X86::VSQRTPDZ128mbk: @@ -9570,26 +9570,26 @@ bool X86InstrInfo::isHighLatencyDef(int opc) const { case X86::VSQRTPSZrkz: case X86::VSQRTSDZm: case X86::VSQRTSDZm_Int: - case X86::VSQRTSDZm_Intk: - case X86::VSQRTSDZm_Intkz: + case X86::VSQRTSDZmk_Int: + case X86::VSQRTSDZmkz_Int: case X86::VSQRTSDZr: case X86::VSQRTSDZr_Int: - case X86::VSQRTSDZr_Intk: - case X86::VSQRTSDZr_Intkz: + case X86::VSQRTSDZrk_Int: + case X86::VSQRTSDZrkz_Int: case X86::VSQRTSDZrb_Int: - case X86::VSQRTSDZrb_Intk: - case X86::VSQRTSDZrb_Intkz: + case X86::VSQRTSDZrbk_Int: + case X86::VSQRTSDZrbkz_Int: case X86::VSQRTSSZm: case X86::VSQRTSSZm_Int: - case X86::VSQRTSSZm_Intk: - case X86::VSQRTSSZm_Intkz: + case X86::VSQRTSSZmk_Int: + case X86::VSQRTSSZmkz_Int: case X86::VSQRTSSZr: case X86::VSQRTSSZr_Int: - case X86::VSQRTSSZr_Intk: - case X86::VSQRTSSZr_Intkz: + case X86::VSQRTSSZrk_Int: + case X86::VSQRTSSZrkz_Int: case X86::VSQRTSSZrb_Int: - case X86::VSQRTSSZrb_Intk: - case X86::VSQRTSSZrb_Intkz: + case X86::VSQRTSSZrbk_Int: + case X86::VSQRTSSZrbkz_Int: case X86::VGATHERDPDYrm: case X86::VGATHERDPDZ128rm: diff --git a/llvm/lib/Target/X86/X86LoadValueInjectionRetHardening.cpp b/llvm/lib/Target/X86/X86LoadValueInjectionRetHardening.cpp index 3b370d8..64728a2 100644 --- a/llvm/lib/Target/X86/X86LoadValueInjectionRetHardening.cpp +++ b/llvm/lib/Target/X86/X86LoadValueInjectionRetHardening.cpp @@ -57,8 +57,6 @@ char X86LoadValueInjectionRetHardeningPass::ID = 0; bool X86LoadValueInjectionRetHardeningPass::runOnMachineFunction( MachineFunction &MF) { - LLVM_DEBUG(dbgs() << "***** " << getPassName() << " : " << MF.getName() - << " *****\n"); const X86Subtarget *Subtarget = &MF.getSubtarget<X86Subtarget>(); if (!Subtarget->useLVIControlFlowIntegrity() || !Subtarget->is64Bit()) return false; // FIXME: support 32-bit @@ -68,6 +66,8 @@ bool X86LoadValueInjectionRetHardeningPass::runOnMachineFunction( if (!F.hasOptNone() && skipFunction(F)) return false; + LLVM_DEBUG(dbgs() << "***** " << getPassName() << " : " << MF.getName() + << " *****\n"); ++NumFunctionsConsidered; const X86RegisterInfo *TRI = Subtarget->getRegisterInfo(); const X86InstrInfo *TII = Subtarget->getInstrInfo(); diff --git a/llvm/lib/Target/X86/X86SchedSapphireRapids.td b/llvm/lib/Target/X86/X86SchedSapphireRapids.td index e04ff68..4f0d366 100644 --- a/llvm/lib/Target/X86/X86SchedSapphireRapids.td +++ b/llvm/lib/Target/X86/X86SchedSapphireRapids.td @@ -669,7 +669,7 @@ def : InstRW<[SPRWriteResGroup12], (instregex "^ADD_F(P?)rST0$", "^VALIGN(D|Q)Z256rri((k|kz)?)$", "^VCMPP(D|H|S)Z(128|256)rri(k?)$", "^VCMPS(D|H|S)Zrri$", - "^VCMPS(D|H|S)Zrr(b?)i_Int(k?)$", + "^VCMPS(D|H|S)Zrr(b?)i(k?)_Int$", "^VFPCLASSP(D|H|S)Z(128|256)ri(k?)$", "^VFPCLASSS(D|H|S)Zri(k?)$", "^VPACK(S|U)S(DW|WB)Yrr$", @@ -977,7 +977,7 @@ def SPRWriteResGroup49 : SchedWriteRes<[SPRPort00, SPRPort02_03_10]> { let NumMicroOps = 2; } def : InstRW<[SPRWriteResGroup49], (instregex "^DIV_F(32|64)m$")>; -def : InstRW<[SPRWriteResGroup49, ReadAfterVecLd], (instregex "^VSQRTSHZm_Int((k|kz)?)$")>; +def : InstRW<[SPRWriteResGroup49, ReadAfterVecLd], (instregex "^VSQRTSHZm((k|kz)?)_Int$")>; def : InstRW<[SPRWriteResGroup49, ReadAfterVecLd], (instrs VSQRTSHZm)>; def SPRWriteResGroup50 : SchedWriteRes<[SPRPort00, SPRPort02_03_10, SPRPort05]> { @@ -1166,11 +1166,11 @@ def : InstRW<[SPRWriteResGroup73, ReadAfterVecXLd], (instregex "^(V?)GF2P8AFFINE def : InstRW<[SPRWriteResGroup73, ReadAfterVecXLd], (instrs VGETEXPPHZ128mbkz, VGF2P8MULBZ128rm)>; def : InstRW<[SPRWriteResGroup73, ReadAfterVecLd], (instregex 
"^V(ADD|SUB)SHZrm$", - "^V(ADD|SUB)SHZrm_Int((k|kz)?)$", + "^V(ADD|SUB)SHZrm((k|kz)?)_Int$", "^VCVTSH2SSZrm((_Int)?)$", "^VM(AX|IN)CSHZrm$", "^VM(AX|IN|UL)SHZrm$", - "^VM(AX|IN|UL)SHZrm_Int((k|kz)?)$")>; + "^VM(AX|IN|UL)SHZrm((k|kz)?)_Int$")>; def : InstRW<[SPRWriteResGroup73, ReadAfterVecYLd], (instregex "^VGF2P8AFFINE((INV)?)QBYrmi$", "^VGF2P8AFFINE((INV)?)QBZ256rm(b?)i$", "^VGF2P8MULB(Y|Z256)rm$")>; @@ -1181,7 +1181,7 @@ def : InstRW<[SPRWriteResGroup73, ReadAfterVecXLd, ReadAfterVecXLd], (instregex "^VFMSUBADD(132|213|231)PHZ128m((b|k|bk|kz)?)$", "^VFMSUBADD(132|213|231)PHZ128mbkz$")>; def : InstRW<[SPRWriteResGroup73, ReadAfterVecLd, ReadAfterVecLd], (instregex "^VF(N?)M(ADD|SUB)(132|213|231)SHZm$", - "^VF(N?)M(ADD|SUB)(132|213|231)SHZm_Int((k|kz)?)$")>; + "^VF(N?)M(ADD|SUB)(132|213|231)SHZm((k|kz)?)_Int$")>; def : InstRW<[SPRWriteResGroup73, ReadAfterVecYLd, ReadAfterVecYLd], (instregex "^VPMADD52(H|L)UQZ256m((b|k|bk|kz)?)$", "^VPMADD52(H|L)UQZ256mbkz$")>; @@ -2301,7 +2301,7 @@ def : InstRW<[SPRWriteResGroup218, ReadAfterVecXLd], (instregex "^(V?)ROUNDS(D|S "^VRNDSCALEP(D|S)Z128rmbik(z?)$", "^VRNDSCALEP(D|S)Z128rmi((kz)?)$", "^VRNDSCALES(D|S)Zrmi$", - "^VRNDSCALES(D|S)Zrmi_Int((k|kz)?)$")>; + "^VRNDSCALES(D|S)Zrmi((k|kz)?)_Int$")>; def SPRWriteResGroup219 : SchedWriteRes<[SPRPort00_01]> { let ReleaseAtCycles = [2]; @@ -2313,7 +2313,7 @@ def : InstRW<[SPRWriteResGroup219], (instregex "^(V?)ROUND(PD|SS)ri$", "^(V?)ROUNDS(D|S)ri_Int$", "^VRNDSCALEP(D|S)Z(128|256)rri((k|kz)?)$", "^VRNDSCALES(D|S)Zrri$", - "^VRNDSCALES(D|S)Zrri(b?)_Int((k|kz)?)$", + "^VRNDSCALES(D|S)Zrri(b?)((k|kz)?)_Int$", "^VROUNDP(D|S)Yri$")>; def SPRWriteResGroup220 : SchedWriteRes<[SPRPort00_06]> { @@ -2530,7 +2530,7 @@ def SPRWriteResGroup249 : SchedWriteRes<[SPRPort01_05]> { let Latency = 4; } def : InstRW<[SPRWriteResGroup249], (instregex "^V(ADD|SUB)P(D|S)Z(128|256)rrkz$", - "^V(ADD|SUB)S(D|S)Zrr(b?)_Intkz$")>; + "^V(ADD|SUB)S(D|S)Zrr(b?)kz_Int$")>; def SPRWriteResGroup250 : SchedWriteRes<[SPRPort00_05]> { let Latency = 3; @@ -2545,11 +2545,11 @@ def SPRWriteResGroup251 : SchedWriteRes<[SPRPort00_01]> { let Latency = 6; } def : InstRW<[SPRWriteResGroup251], (instregex "^V(ADD|SUB)PHZ(128|256)rrk(z?)$", - "^V(ADD|SUB)SHZrr(b?)_Intk(z?)$", + "^V(ADD|SUB)SHZrr(b?)k(z?)_Int$", "^VCVT(T?)PH2(U?)WZ(128|256)rrk(z?)$", "^VCVT(U?)W2PHZ(128|256)rrk(z?)$", "^VF(N?)M(ADD|SUB)(132|213|231)PHZ(128|256)rk(z?)$", - "^VF(N?)M(ADD|SUB)(132|213|231)SHZr(b?)_Intk(z?)$", + "^VF(N?)M(ADD|SUB)(132|213|231)SHZr(b?)k(z?)_Int$", "^VFMADDSUB(132|213|231)PHZ(128|256)rk(z?)$", "^VFMSUBADD(132|213|231)PHZ(128|256)rk(z?)$", "^VGETEXPPHZ(128|256)rk(z?)$", @@ -2560,7 +2560,7 @@ def : InstRW<[SPRWriteResGroup251], (instregex "^V(ADD|SUB)PHZ(128|256)rrk(z?)$" "^VGETMANTSHZrri(k|bkz)$", "^VM(AX|IN)CPHZ(128|256)rrk(z?)$", "^VM(AX|IN|UL)PHZ(128|256)rrk(z?)$", - "^VM(AX|IN|UL)SHZrr(b?)_Intk(z?)$")>; + "^VM(AX|IN|UL)SHZrr(b?)k(z?)_Int$")>; def SPRWriteResGroup252 : SchedWriteRes<[SPRPort00]> { let Latency = 5; @@ -2745,7 +2745,7 @@ def : InstRW<[SPRWriteResGroup263, ReadAfterVecYLd], (instregex "^VCMPP(D|H|S)Z( "^VPTEST(N?)M(B|D|Q|W)Z((256)?)rm(k?)$", "^VPTEST(N?)M(D|Q)Z((256)?)rmb(k?)$")>; def : InstRW<[SPRWriteResGroup263, ReadAfterVecLd], (instregex "^VCMPS(D|H|S)Zrmi$", - "^VCMPS(D|H|S)Zrmi_Int(k?)$", + "^VCMPS(D|H|S)Zrmi(k?)_Int$", "^VFPCLASSS(D|H|S)Zmik$")>; def SPRWriteResGroup264 : SchedWriteRes<[SPRPort00, SPRPort02_03_10]> { @@ -3171,7 +3171,7 @@ def : InstRW<[SPRWriteResGroup314], (instregex "^VCVT(T?)PD2(U?)QQZ(128|256)rr(( 
"^VPLZCNT(D|Q)Z(128|256)rr((k|kz)?)$", "^VPMADD52(H|L)UQZ(128|256)r((k|kz)?)$", "^VSCALEFS(D|S)Zrr((k|kz)?)$", - "^VSCALEFS(D|S)Zrrb_Int((k|kz)?)$")>; + "^VSCALEFS(D|S)Zrrb((k|kz)?)_Int$")>; def : InstRW<[SPRWriteResGroup314, ReadAfterVecLd], (instregex "^VFIXUPIMMS(D|S)Zrrib((k|kz)?)$")>; def SPRWriteResGroup315 : SchedWriteRes<[SPRPort00_01, SPRPort02_03_10, SPRPort05]> { @@ -3300,7 +3300,7 @@ def SPRWriteResGroup331 : SchedWriteRes<[SPRPort00_01, SPRPort02_03_10]> { let NumMicroOps = 2; } def : InstRW<[SPRWriteResGroup331], (instregex "^VCVTPH2PSZ(128|256)rmk(z?)$")>; -def : InstRW<[SPRWriteResGroup331, ReadAfterVecLd], (instregex "^VCVTSH2SSZrm_Intk(z?)$")>; +def : InstRW<[SPRWriteResGroup331, ReadAfterVecLd], (instregex "^VCVTSH2SSZrmk(z?)_Int$")>; def : InstRW<[SPRWriteResGroup331, ReadAfterVecXLd], (instregex "^VPMADDUBSWZ128rmk(z?)$", "^VPMULH((U|RS)?)WZ128rmk(z?)$", "^VPMULLWZ128rmk(z?)$")>; @@ -3460,7 +3460,7 @@ def SPRWriteResGroup353 : SchedWriteRes<[SPRPort00_01, SPRPort00_01_05, SPRPort0 let Latency = 21; let NumMicroOps = 7; } -def : InstRW<[SPRWriteResGroup353, ReadAfterVecLd], (instregex "^VCVTSD2SHZrm_Intk(z?)$")>; +def : InstRW<[SPRWriteResGroup353, ReadAfterVecLd], (instregex "^VCVTSD2SHZrmk(z?)_Int$")>; def SPRWriteResGroup354 : SchedWriteRes<[SPRPort00_01, SPRPort00_01_05, SPRPort05]> { let ReleaseAtCycles = [2, 1, 1]; @@ -3475,7 +3475,7 @@ def SPRWriteResGroup355 : SchedWriteRes<[SPRPort00_01, SPRPort00_01_05, SPRPort0 let Latency = 14; let NumMicroOps = 4; } -def : InstRW<[SPRWriteResGroup355], (instregex "^VCVTSD2SHZrr(b?)_Intk(z?)$")>; +def : InstRW<[SPRWriteResGroup355], (instregex "^VCVTSD2SHZrr(b?)k(z?)_Int$")>; def SPRWriteResGroup356 : SchedWriteRes<[SPRPort00_01, SPRPort02_03_10, SPRPort05]> { let ReleaseAtCycles = [2, 1, 1]; @@ -3489,7 +3489,7 @@ def SPRWriteResGroup357 : SchedWriteRes<[SPRPort00_01, SPRPort02_03_10, SPRPort0 let Latency = 20; let NumMicroOps = 4; } -def : InstRW<[SPRWriteResGroup357, ReadAfterVecLd], (instregex "^VCVTSH2SDZrm_Intk(z?)$")>; +def : InstRW<[SPRWriteResGroup357, ReadAfterVecLd], (instregex "^VCVTSH2SDZrmk(z?)_Int$")>; def SPRWriteResGroup358 : SchedWriteRes<[SPRPort00_01, SPRPort05]> { let ReleaseAtCycles = [2, 1]; @@ -3504,7 +3504,7 @@ def SPRWriteResGroup359 : SchedWriteRes<[SPRPort00_01, SPRPort05]> { let Latency = 13; let NumMicroOps = 3; } -def : InstRW<[SPRWriteResGroup359], (instregex "^VCVTSH2SDZrr(b?)_Intk(z?)$")>; +def : InstRW<[SPRWriteResGroup359], (instregex "^VCVTSH2SDZrr(b?)k(z?)_Int$")>; def SPRWriteResGroup360 : SchedWriteRes<[SPRPort00, SPRPort00_01, SPRPort02_03_10]> { let Latency = 13; @@ -3523,7 +3523,7 @@ def : InstRW<[SPRWriteResGroup361], (instregex "^VCVT(T?)SH2(U?)SI((64)?)Zrr(b?) 
def SPRWriteResGroup362 : SchedWriteRes<[SPRPort00_01]> { let Latency = 8; } -def : InstRW<[SPRWriteResGroup362], (instregex "^VCVTSH2SSZrr(b?)_Intk(z?)$")>; +def : InstRW<[SPRWriteResGroup362], (instregex "^VCVTSH2SSZrr(b?)k(z?)_Int$")>; def SPRWriteResGroup363 : SchedWriteRes<[SPRPort00_01, SPRPort00_01_05, SPRPort02_03_10]> { let Latency = 14; @@ -3536,7 +3536,7 @@ def SPRWriteResGroup364 : SchedWriteRes<[SPRPort00_01, SPRPort00_01_05, SPRPort0 let Latency = 16; let NumMicroOps = 3; } -def : InstRW<[SPRWriteResGroup364, ReadAfterVecLd], (instregex "^VCVTSS2SHZrm_Intk(z?)$")>; +def : InstRW<[SPRWriteResGroup364, ReadAfterVecLd], (instregex "^VCVTSS2SHZrmk(z?)_Int$")>; def SPRWriteResGroup365 : SchedWriteRes<[SPRPort00_01, SPRPort00_01_05]> { let Latency = 6; @@ -3549,7 +3549,7 @@ def SPRWriteResGroup366 : SchedWriteRes<[SPRPort00_01, SPRPort00_01_05]> { let Latency = 9; let NumMicroOps = 2; } -def : InstRW<[SPRWriteResGroup366], (instregex "^VCVTSS2SHZrr(b?)_Intk(z?)$")>; +def : InstRW<[SPRWriteResGroup366], (instregex "^VCVTSS2SHZrr(b?)k(z?)_Int$")>; def SPRWriteResGroup367 : SchedWriteRes<[SPRPort05]> { let Latency = 5; @@ -3667,7 +3667,7 @@ def SPRWriteResGroup380 : SchedWriteRes<[SPRPort00, SPRPort02_03_10]> { let Latency = 21; let NumMicroOps = 2; } -def : InstRW<[SPRWriteResGroup380, ReadAfterVecLd], (instregex "^VDIVSHZrm_Int((k|kz)?)$")>; +def : InstRW<[SPRWriteResGroup380, ReadAfterVecLd], (instregex "^VDIVSHZrm((k|kz)?)_Int$")>; def : InstRW<[SPRWriteResGroup380, ReadAfterVecLd], (instrs VDIVSHZrm)>; def SPRWriteResGroup381 : SchedWriteRes<[SPRPort00]> { @@ -4884,7 +4884,7 @@ def SPRWriteResGroup534 : SchedWriteRes<[SPRPort00_01, SPRPort02_03_10]> { let NumMicroOps = 3; } def : InstRW<[SPRWriteResGroup534, ReadAfterVecXLd], (instregex "^VRNDSCALEPHZ128rm(b?)ik(z?)$", - "^VRNDSCALESHZrmi_Intk(z?)$", + "^VRNDSCALESHZrmik(z?)_Int$", "^VSCALEFPHZ128rm(bk|kz)$", "^VSCALEFPHZ128rm(k|bkz)$")>; def : InstRW<[SPRWriteResGroup534, ReadAfterVecYLd], (instregex "^VRNDSCALEPHZ256rm(b?)ik(z?)$", @@ -4898,9 +4898,9 @@ def SPRWriteResGroup535 : SchedWriteRes<[SPRPort00_01]> { let NumMicroOps = 2; } def : InstRW<[SPRWriteResGroup535], (instregex "^VRNDSCALEPHZ(128|256)rrik(z?)$", - "^VRNDSCALESHZrri(b?)_Intk(z?)$", + "^VRNDSCALESHZrri(b?)k(z?)_Int$", "^VSCALEFPHZ(128|256)rrk(z?)$", - "^VSCALEFSHZrrb_Intk(z?)$", + "^VSCALEFSHZrrbk(z?)_Int$", "^VSCALEFSHZrrk(z?)$")>; def SPRWriteResGroup536 : SchedWriteRes<[SPRPort00, SPRPort02_03_10]> { @@ -4944,7 +4944,7 @@ def SPRWriteResGroup540 : SchedWriteRes<[SPRPort00, SPRPort02_03_10]> { } def : InstRW<[SPRWriteResGroup540, ReadAfterVecXLd], (instregex "^VSQRTPDZ128m(bk|kz)$", "^VSQRTPDZ128m(k|bkz)$")>; -def : InstRW<[SPRWriteResGroup540, ReadAfterVecLd], (instregex "^VSQRTSDZm_Intk(z?)$")>; +def : InstRW<[SPRWriteResGroup540, ReadAfterVecLd], (instregex "^VSQRTSDZmk(z?)_Int$")>; def SPRWriteResGroup541 : SchedWriteRes<[SPRPort00, SPRPort00_05, SPRPort02_03_10]> { let ReleaseAtCycles = [2, 1, 1]; diff --git a/llvm/lib/Target/X86/X86ScheduleZnver4.td b/llvm/lib/Target/X86/X86ScheduleZnver4.td index 38f9b5e..c5478dd 100644 --- a/llvm/lib/Target/X86/X86ScheduleZnver4.td +++ b/llvm/lib/Target/X86/X86ScheduleZnver4.td @@ -1545,7 +1545,7 @@ def Zn4WriteSCALErr: SchedWriteRes<[Zn4FPFMisc23]> { let NumMicroOps = 2; } def : InstRW<[Zn4WriteSCALErr], (instregex - "V(SCALEF|REDUCE)(S|P)(S|D)(Z?|Z128?|Z256?)(rr|rrb|rrkz|rrik|rrikz|rri)(_Int?|_Intkz?)", + "V(SCALEF|REDUCE)(S|P)(S|D)(Z?|Z128?|Z256?)(rr|rrb|rrkz|rrik|rrikz|rri)(_Int?)", 
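The SapphireRapids and Znver4 scheduling regexes above are updated for the renamed masked scalar intrinsics, whose k/kz mask suffix now comes before the _Int tag (VSQRTSHZmkz_Int rather than VSQRTSHZm_Intkz). A minimal standalone sketch of that pattern shape (not part of the patch, using std::regex as a stand-in for TableGen's instregex):

    #include <cassert>
    #include <regex>

    int main() {
      // Pattern lifted from the SapphireRapids model above: optional k/kz, then _Int.
      std::regex NewStyle("^VSQRTSHZm((k|kz)?)_Int$");
      assert(std::regex_match("VSQRTSHZm_Int", NewStyle));    // unmasked intrinsic form
      assert(std::regex_match("VSQRTSHZmk_Int", NewStyle));   // merge-masked
      assert(std::regex_match("VSQRTSHZmkz_Int", NewStyle));  // zero-masked
      assert(!std::regex_match("VSQRTSHZm_Intkz", NewStyle)); // old-style name no longer matches
      return 0;
    }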
"(V?)REDUCE(PD|PS|SD|SS)(Z?|Z128?)(rri|rrikz|rrib)" )>; @@ -1585,7 +1585,7 @@ def : InstRW<[Zn4WriteSHIFTrr], (instregex "(V?)P(ROL|ROR)(D|Q|VD|VQ)(Z?|Z128?|Z256?)(rr|rrk|rrkz)", "(V?)P(ROL|ROR)(D|Q|VD|VQ)(Z256?)(ri|rik|rikz)", "(V?)P(ROL|ROR)(D|Q)(Z?|Z128?)(ri|rik|rikz)", - "VPSHUFBITQMBZ128rr", "VFMSUB231SSZr_Intkz" + "VPSHUFBITQMBZ128rr", "VFMSUB231SSZrkz_Int" )>; def Zn4WriteSHIFTri: SchedWriteRes<[Zn4FPFMisc01]> { diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp index 808f48e..c19bcfc 100644 --- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp +++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp @@ -1650,6 +1650,13 @@ InstructionCost X86TTIImpl::getShuffleCost( return MatchingTypes ? TTI::TCC_Free : SubLT.first; } + // Attempt to match MOVSS (Idx == 0) or INSERTPS pattern. This will have + // been matched by improveShuffleKindFromMask as a SK_InsertSubvector of + // v1f32 (legalised to f32) into a v4f32. + if (LT.first == 1 && LT.second == MVT::v4f32 && SubLT.first == 1 && + SubLT.second == MVT::f32 && (Index == 0 || ST->hasSSE41())) + return 1; + // If the insertion isn't aligned, treat it like a 2-op shuffle. Kind = TTI::SK_PermuteTwoSrc; } @@ -1698,8 +1705,7 @@ InstructionCost X86TTIImpl::getShuffleCost( // We are going to permute multiple sources and the result will be in multiple // destinations. Providing an accurate cost only for splits where the element // type remains the same. - if ((Kind == TTI::SK_PermuteSingleSrc || Kind == TTI::SK_PermuteTwoSrc) && - LT.first != 1) { + if (LT.first != 1) { MVT LegalVT = LT.second; if (LegalVT.isVector() && LegalVT.getVectorElementType().getSizeInBits() == @@ -2227,9 +2233,18 @@ InstructionCost X86TTIImpl::getShuffleCost( { TTI::SK_PermuteTwoSrc, MVT::v4f32, 2 }, // 2*shufps }; - if (ST->hasSSE1()) + if (ST->hasSSE1()) { + if (LT.first == 1 && LT.second == MVT::v4f32 && Mask.size() == 4) { + // SHUFPS: both pairs must come from the same source register. + auto MatchSHUFPS = [](int X, int Y) { + return X < 0 || Y < 0 || ((X & 4) == (Y & 4)); + }; + if (MatchSHUFPS(Mask[0], Mask[1]) && MatchSHUFPS(Mask[2], Mask[3])) + return 1; + } if (const auto *Entry = CostTableLookup(SSE1ShuffleTbl, Kind, LT.second)) return LT.first * Entry->Cost; + } return BaseT::getShuffleCost(Kind, BaseTp, Mask, CostKind, Index, SubTp); } @@ -4789,9 +4804,12 @@ InstructionCost X86TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val, MVT MScalarTy = LT.second.getScalarType(); auto IsCheapPInsrPExtrInsertPS = [&]() { // Assume pinsr/pextr XMM <-> GPR is relatively cheap on all targets. + // Inserting f32 into index0 is just movss. // Also, assume insertps is relatively cheap on all >= SSE41 targets. 
return (MScalarTy == MVT::i16 && ST->hasSSE2()) || (MScalarTy.isInteger() && ST->hasSSE41()) || + (MScalarTy == MVT::f32 && ST->hasSSE1() && Index == 0 && + Opcode == Instruction::InsertElement) || (MScalarTy == MVT::f32 && ST->hasSSE41() && Opcode == Instruction::InsertElement); }; diff --git a/llvm/lib/TargetParser/AArch64TargetParser.cpp b/llvm/lib/TargetParser/AArch64TargetParser.cpp index 50c9a56..7d0b8c3 100644 --- a/llvm/lib/TargetParser/AArch64TargetParser.cpp +++ b/llvm/lib/TargetParser/AArch64TargetParser.cpp @@ -48,17 +48,12 @@ std::optional<AArch64::ArchInfo> AArch64::ArchInfo::findBySubArch(StringRef SubA return {}; } -unsigned AArch64::getFMVPriority(ArrayRef<StringRef> Features) { - constexpr unsigned MaxFMVPriority = 1000; - unsigned Priority = 0; - unsigned NumFeatures = 0; - for (StringRef Feature : Features) { - if (auto Ext = parseFMVExtension(Feature)) { - Priority = std::max(Priority, Ext->Priority); - NumFeatures++; - } - } - return Priority + MaxFMVPriority * NumFeatures; +uint64_t AArch64::getFMVPriority(ArrayRef<StringRef> Features) { + uint64_t Priority = 0; + for (StringRef Feature : Features) + if (std::optional<FMVInfo> Info = parseFMVExtension(Feature)) + Priority |= (1ULL << Info->PriorityBit); + return Priority; } uint64_t AArch64::getCpuSupportsMask(ArrayRef<StringRef> Features) { @@ -73,7 +68,7 @@ uint64_t AArch64::getCpuSupportsMask(ArrayRef<StringRef> Features) { uint64_t FeaturesMask = 0; for (const FMVInfo &Info : getFMVInfo()) if (Info.ID && FeatureBits.Enabled.test(*Info.ID)) - FeaturesMask |= (1ULL << Info.Bit); + FeaturesMask |= (1ULL << Info.FeatureBit); return FeaturesMask; } diff --git a/llvm/lib/TargetParser/Host.cpp b/llvm/lib/TargetParser/Host.cpp index 45b4caf..9d1b7b8b 100644 --- a/llvm/lib/TargetParser/Host.cpp +++ b/llvm/lib/TargetParser/Host.cpp @@ -173,7 +173,7 @@ StringRef sys::detail::getHostCPUNameForARM(StringRef ProcCpuinfoContent) { // Read 32 lines from /proc/cpuinfo, which should contain the CPU part line // in all cases. SmallVector<StringRef, 32> Lines; - ProcCpuinfoContent.split(Lines, "\n"); + ProcCpuinfoContent.split(Lines, '\n'); // Look for the CPU implementer line. StringRef Implementer; @@ -436,7 +436,7 @@ StringRef sys::detail::getHostCPUNameForS390x(StringRef ProcCpuinfoContent) { // The "processor 0:" line comes after a fair amount of other information, // including a cache breakdown, but this should be plenty. SmallVector<StringRef, 32> Lines; - ProcCpuinfoContent.split(Lines, "\n"); + ProcCpuinfoContent.split(Lines, '\n'); // Look for the CPU features. 
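getFMVPriority above now folds each recognised FMV feature into a 64-bit mask by setting that feature's priority bit, presumably so callers can rank candidate function versions with a plain unsigned comparison. A small sketch of the idea (not part of the patch; the feature names and bit positions here are made up, the real ones come from the FMV tables):

    #include <cassert>
    #include <cstdint>
    #include <map>
    #include <string>
    #include <vector>

    // Hypothetical priority-bit assignments for illustration only.
    static const std::map<std::string, unsigned> PriorityBit = {
        {"featA", 5}, {"featB", 12}, {"featC", 40}};

    static uint64_t fmvPriority(const std::vector<std::string> &Features) {
      uint64_t Priority = 0;
      for (const std::string &F : Features)
        if (auto It = PriorityBit.find(F); It != PriorityBit.end())
          Priority |= (1ULL << It->second); // one bit per recognised feature
      return Priority;
    }

    int main() {
      // The version carrying the higher-priority feature wins an unsigned compare.
      assert(fmvPriority({"featC"}) > fmvPriority({"featA", "featB"}));
      return 0;
    }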
SmallVector<StringRef, 32> CPUFeatures; @@ -478,7 +478,7 @@ StringRef sys::detail::getHostCPUNameForS390x(StringRef ProcCpuinfoContent) { StringRef sys::detail::getHostCPUNameForRISCV(StringRef ProcCpuinfoContent) { // There are 24 lines in /proc/cpuinfo SmallVector<StringRef> Lines; - ProcCpuinfoContent.split(Lines, "\n"); + ProcCpuinfoContent.split(Lines, '\n'); // Look for uarch line to determine cpu name StringRef UArch; @@ -1630,7 +1630,7 @@ StringRef sys::getHostCPUName() { #if defined(__linux__) StringRef sys::detail::getHostCPUNameForSPARC(StringRef ProcCpuinfoContent) { SmallVector<StringRef> Lines; - ProcCpuinfoContent.split(Lines, "\n"); + ProcCpuinfoContent.split(Lines, '\n'); // Look for cpu line to determine cpu name StringRef Cpu; @@ -1970,7 +1970,7 @@ const StringMap<bool> sys::getHostCPUFeatures() { return Features; SmallVector<StringRef, 32> Lines; - P->getBuffer().split(Lines, "\n"); + P->getBuffer().split(Lines, '\n'); SmallVector<StringRef, 32> CPUFeatures; diff --git a/llvm/lib/TargetParser/RISCVISAInfo.cpp b/llvm/lib/TargetParser/RISCVISAInfo.cpp index cafc9d3..d6e1eac 100644 --- a/llvm/lib/TargetParser/RISCVISAInfo.cpp +++ b/llvm/lib/TargetParser/RISCVISAInfo.cpp @@ -742,7 +742,8 @@ Error RISCVISAInfo::checkDependency() { bool HasZvl = MinVLen != 0; bool HasZcmt = Exts.count("zcmt") != 0; static constexpr StringLiteral XqciExts[] = { - {"xqcia"}, {"xqcics"}, {"xqcicsr"}, {"xqcilsm"}, {"xqcisls"}}; + {"xqcia"}, {"xqciac"}, {"xqcicli"}, {"xqcicm"}, + {"xqcics"}, {"xqcicsr"}, {"xqcilsm"}, {"xqcisls"}}; if (HasI && HasE) return getIncompatibleError("i", "e"); diff --git a/llvm/lib/ToolDrivers/llvm-lib/LibDriver.cpp b/llvm/lib/ToolDrivers/llvm-lib/LibDriver.cpp index 138d9fc..6ce06b4 100644 --- a/llvm/lib/ToolDrivers/llvm-lib/LibDriver.cpp +++ b/llvm/lib/ToolDrivers/llvm-lib/LibDriver.cpp @@ -171,6 +171,7 @@ static Expected<COFF::MachineTypes> getCOFFFileMachine(MemoryBufferRef MB) { uint16_t Machine = (*Obj)->getMachine(); if (Machine != COFF::IMAGE_FILE_MACHINE_I386 && Machine != COFF::IMAGE_FILE_MACHINE_AMD64 && + Machine != COFF::IMAGE_FILE_MACHINE_R4000 && Machine != COFF::IMAGE_FILE_MACHINE_ARMNT && !COFF::isAnyArm64(Machine)) { return createStringError(inconvertibleErrorCode(), "unknown machine: " + std::to_string(Machine)); @@ -195,6 +196,8 @@ static Expected<COFF::MachineTypes> getBitcodeFileMachine(MemoryBufferRef MB) { case Triple::aarch64: return T.isWindowsArm64EC() ? COFF::IMAGE_FILE_MACHINE_ARM64EC : COFF::IMAGE_FILE_MACHINE_ARM64; + case Triple::mipsel: + return COFF::IMAGE_FILE_MACHINE_R4000; default: return createStringError(inconvertibleErrorCode(), "unknown arch in target triple: " + *TripleStr); diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp index 45ee2d4..fe7b3b1 100644 --- a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp +++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp @@ -181,6 +181,7 @@ static bool foldGuardedFunnelShift(Instruction &I, const DominatorTree &DT) { /// the bit indexes (Mask) needed by a masked compare. If we're matching a chain /// of 'and' ops, then we also need to capture the fact that we saw an /// "and X, 1", so that's an extra return value for that case. 
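MaskOps, moved into an anonymous namespace just below, is the helper behind foldAnyOrAllBitsSet: it collects the bit indexes for a masked compare, i.e. a chain of single-bit tests of one value can be rewritten as a single and-plus-compare against the combined mask. A brute-force check of that equivalence for the all-bits-set form (not part of the patch):

    #include <cassert>
    #include <cstdint>

    int main() {
      const unsigned A = 2, B = 5;                // example bit positions
      const uint8_t Mask = (1u << A) | (1u << B); // combined bit-index mask
      for (unsigned V = 0; V < 256; ++V) {
        uint8_t X = static_cast<uint8_t>(V);
        bool Chain = ((X >> A) & 1) && ((X >> B) & 1); // chain of single-bit tests
        bool Folded = (X & Mask) == Mask;              // single masked compare
        assert(Chain == Folded);
      }
      return 0;
    }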
+namespace { struct MaskOps { Value *Root = nullptr; APInt Mask; @@ -190,6 +191,7 @@ struct MaskOps { MaskOps(unsigned BitWidth, bool MatchAnds) : Mask(APInt::getZero(BitWidth)), MatchAndChain(MatchAnds) {} }; +} // namespace /// This is a recursive helper for foldAnyOrAllBitsSet() that walks through a /// chain of 'and' or 'or' instructions looking for shift ops of a common source @@ -423,11 +425,8 @@ static bool foldSqrt(CallInst *Call, LibFunc Func, TargetTransformInfo &TTI, Arg, 0, SimplifyQuery(Call->getDataLayout(), &TLI, &DT, &AC, Call)))) { IRBuilder<> Builder(Call); - IRBuilderBase::FastMathFlagGuard Guard(Builder); - Builder.setFastMathFlags(Call->getFastMathFlags()); - - Value *NewSqrt = Builder.CreateIntrinsic(Intrinsic::sqrt, Ty, Arg, - /*FMFSource=*/nullptr, "sqrt"); + Value *NewSqrt = + Builder.CreateIntrinsic(Intrinsic::sqrt, Ty, Arg, Call, "sqrt"); Call->replaceAllUsesWith(NewSqrt); // Explicitly erase the old call because a call with side effects is not diff --git a/llvm/lib/Transforms/Coroutines/Coroutines.cpp b/llvm/lib/Transforms/Coroutines/Coroutines.cpp index 240d089..7b59c39 100644 --- a/llvm/lib/Transforms/Coroutines/Coroutines.cpp +++ b/llvm/lib/Transforms/Coroutines/Coroutines.cpp @@ -69,7 +69,6 @@ static const char *const CoroIntrinsics[] = { "llvm.coro.async.context.dealloc", "llvm.coro.async.resume", "llvm.coro.async.size.replace", - "llvm.coro.async.store_resume", "llvm.coro.await.suspend.bool", "llvm.coro.await.suspend.handle", "llvm.coro.await.suspend.void", diff --git a/llvm/lib/Transforms/IPO/FunctionImport.cpp b/llvm/lib/Transforms/IPO/FunctionImport.cpp index fde43bb..c3d0a1a 100644 --- a/llvm/lib/Transforms/IPO/FunctionImport.cpp +++ b/llvm/lib/Transforms/IPO/FunctionImport.cpp @@ -1950,9 +1950,8 @@ Expected<bool> FunctionImporter::importFunctions( SrcModule->setPartialSampleProfileRatio(Index); // Link in the specified functions. - if (renameModuleForThinLTO(*SrcModule, Index, ClearDSOLocalOnDeclarations, - &GlobalsToImport)) - return true; + renameModuleForThinLTO(*SrcModule, Index, ClearDSOLocalOnDeclarations, + &GlobalsToImport); if (PrintImports) { for (const auto *GV : GlobalsToImport) @@ -2026,11 +2025,8 @@ static bool doImportingForModuleForTest( // Next we need to promote to global scope and rename any local values that // are potentially exported to other modules. - if (renameModuleForThinLTO(M, *Index, /*ClearDSOLocalOnDeclarations=*/false, - /*GlobalsToImport=*/nullptr)) { - errs() << "Error renaming module\n"; - return true; - } + renameModuleForThinLTO(M, *Index, /*ClearDSOLocalOnDeclarations=*/false, + /*GlobalsToImport=*/nullptr); // Perform the import now. 
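The foldSqrt change above drops the FastMathFlagGuard and instead passes the original call as the FMFSource argument of CreateIntrinsic, so the new llvm.sqrt call picks up its fast-math flags directly. A minimal sketch of that pattern (not part of the patch, assuming the same IRBuilder overload the diff itself uses):

    #include "llvm/IR/IRBuilder.h"
    #include "llvm/IR/Intrinsics.h"

    using namespace llvm;

    // Emit llvm.sqrt for Call's argument, copying Call's fast-math flags by
    // handing it to CreateIntrinsic as the FMFSource instead of toggling
    // builder-wide fast-math state.
    static Value *emitSqrtWithCallFMF(IRBuilder<> &Builder, CallInst *Call) {
      Value *Arg = Call->getArgOperand(0);
      Type *Ty = Call->getType();
      return Builder.CreateIntrinsic(Intrinsic::sqrt, Ty, Arg,
                                     /*FMFSource=*/Call, "sqrt");
    }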
auto ModuleLoader = [&M](StringRef Identifier) { diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp index 96956481..449d64d 100644 --- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp +++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp @@ -66,19 +66,19 @@ static cl::opt<unsigned> MaxCodeSizeGrowth( "Maximum codesize growth allowed per function")); static cl::opt<unsigned> MinCodeSizeSavings( - "funcspec-min-codesize-savings", cl::init(20), cl::Hidden, cl::desc( - "Reject specializations whose codesize savings are less than this" - "much percent of the original function size")); + "funcspec-min-codesize-savings", cl::init(20), cl::Hidden, + cl::desc("Reject specializations whose codesize savings are less than this " + "much percent of the original function size")); static cl::opt<unsigned> MinLatencySavings( "funcspec-min-latency-savings", cl::init(40), cl::Hidden, - cl::desc("Reject specializations whose latency savings are less than this" + cl::desc("Reject specializations whose latency savings are less than this " "much percent of the original function size")); static cl::opt<unsigned> MinInliningBonus( - "funcspec-min-inlining-bonus", cl::init(300), cl::Hidden, cl::desc( - "Reject specializations whose inlining bonus is less than this" - "much percent of the original function size")); + "funcspec-min-inlining-bonus", cl::init(300), cl::Hidden, + cl::desc("Reject specializations whose inlining bonus is less than this " + "much percent of the original function size")); static cl::opt<bool> SpecializeOnAddress( "funcspec-on-address", cl::init(false), cl::Hidden, cl::desc( diff --git a/llvm/lib/Transforms/IPO/GlobalOpt.cpp b/llvm/lib/Transforms/IPO/GlobalOpt.cpp index 16a80e9..78cd249 100644 --- a/llvm/lib/Transforms/IPO/GlobalOpt.cpp +++ b/llvm/lib/Transforms/IPO/GlobalOpt.cpp @@ -105,7 +105,7 @@ static cl::opt<int> ColdCCRelFreq( "coldcc-rel-freq", cl::Hidden, cl::init(2), cl::desc( "Maximum block frequency, expressed as a percentage of caller's " - "entry frequency, for a call site to be considered cold for enabling" + "entry frequency, for a call site to be considered cold for enabling " "coldcc")); /// Is this global variable possibly used by a leak checker as a root? If so, diff --git a/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp b/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp index 1bf7ff4..016db55 100644 --- a/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp +++ b/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp @@ -122,6 +122,20 @@ static cl::opt<unsigned> cl::desc("Max depth to recursively search for missing " "frames through tail calls.")); +// By default enable cloning of callsites involved with recursive cycles +static cl::opt<bool> AllowRecursiveCallsites( + "memprof-allow-recursive-callsites", cl::init(true), cl::Hidden, + cl::desc("Allow cloning of callsites involved in recursive cycles")); + +// When disabled, try to detect and prevent cloning of recursive contexts. +// This is only necessary until we support cloning through recursive cycles. +// Leave on by default for now, as disabling requires a little bit of compile +// time overhead and doesn't affect correctness, it will just inflate the cold +// hinted bytes reporting a bit when -memprof-report-hinted-sizes is enabled. 
+static cl::opt<bool> AllowRecursiveContexts( + "memprof-allow-recursive-contexts", cl::init(true), cl::Hidden, + cl::desc("Allow cloning of contexts through recursive cycles")); + namespace llvm { cl::opt<bool> EnableMemProfContextDisambiguation( "enable-memprof-context-disambiguation", cl::init(false), cl::Hidden, @@ -1236,9 +1250,13 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::addStackNodesForMIB( StackEntryIdToContextNodeMap[StackId] = StackNode; StackNode->OrigStackOrAllocId = StackId; } - auto Ins = StackIdSet.insert(StackId); - if (!Ins.second) - StackNode->Recursive = true; + // Marking a node recursive will prevent its cloning completely, even for + // non-recursive contexts flowing through it. + if (!AllowRecursiveCallsites) { + auto Ins = StackIdSet.insert(StackId); + if (!Ins.second) + StackNode->Recursive = true; + } StackNode->AllocTypes |= (uint8_t)AllocType; PrevNode->addOrUpdateCallerEdge(StackNode, AllocType, LastContextId); PrevNode = StackNode; @@ -1375,8 +1393,11 @@ static void checkNode(const ContextNode<DerivedCCG, FuncTy, CallTy> *Node, set_union(CallerEdgeContextIds, Edge->ContextIds); } // Node can have more context ids than callers if some contexts terminate at - // node and some are longer. - assert(NodeContextIds == CallerEdgeContextIds || + // node and some are longer. If we are allowing recursive callsites but + // haven't disabled recursive contexts, this will be violated for + // incompletely cloned recursive cycles, so skip the checking in that case. + assert((AllowRecursiveCallsites && AllowRecursiveContexts) || + NodeContextIds == CallerEdgeContextIds || set_is_subset(CallerEdgeContextIds, NodeContextIds)); } if (Node->CalleeEdges.size()) { @@ -3370,6 +3391,21 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::identifyClones( assert(Node->AllocTypes != (uint8_t)AllocationType::None); + DenseSet<uint32_t> RecursiveContextIds; + // If we are allowing recursive callsites, but have also disabled recursive + // contexts, look for context ids that show up in multiple caller edges. + if (AllowRecursiveCallsites && !AllowRecursiveContexts) { + DenseSet<uint32_t> AllCallerContextIds; + for (auto &CE : Node->CallerEdges) { + // Resize to the largest set of caller context ids, since we know the + // final set will be at least that large. + AllCallerContextIds.reserve(CE->getContextIds().size()); + for (auto Id : CE->getContextIds()) + if (!AllCallerContextIds.insert(Id).second) + RecursiveContextIds.insert(Id); + } + } + // Iterate until we find no more opportunities for disambiguating the alloc // types via cloning. In most cases this loop will terminate once the Node // has a single allocation type, in which case no more cloning is needed. @@ -3394,6 +3430,9 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::identifyClones( // allocation. 
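The identifyClones change above computes RecursiveContextIds as the context ids that reach a node through more than one caller edge; when recursive callsites are allowed but recursive contexts are not, those ids are subtracted from each edge's set before cloning. A standalone sketch of that detection step (not part of the patch):

    #include <cassert>
    #include <cstdint>
    #include <unordered_set>
    #include <vector>

    using ContextIdSet = std::unordered_set<uint32_t>;

    // Ids seen on more than one caller edge are treated as recursive contexts.
    static ContextIdSet
    findRecursiveContextIds(const std::vector<ContextIdSet> &CallerEdges) {
      ContextIdSet AllCallerContextIds, Recursive;
      for (const ContextIdSet &Edge : CallerEdges)
        for (uint32_t Id : Edge)
          if (!AllCallerContextIds.insert(Id).second)
            Recursive.insert(Id);
      return Recursive;
    }

    int main() {
      // Context 7 flows in through two different caller edges, i.e. a cycle.
      std::vector<ContextIdSet> Edges = {{1, 7}, {2, 7}, {3}};
      ContextIdSet R = findRecursiveContextIds(Edges);
      assert(R.size() == 1 && R.count(7));
      return 0;
    }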
auto CallerEdgeContextsForAlloc = set_intersection(CallerEdge->getContextIds(), AllocContextIds); + if (!RecursiveContextIds.empty()) + CallerEdgeContextsForAlloc = + set_difference(CallerEdgeContextsForAlloc, RecursiveContextIds); if (CallerEdgeContextsForAlloc.empty()) { ++EI; continue; diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp index b40ab35..67585e9 100644 --- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp +++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp @@ -129,7 +129,7 @@ static cl::opt<bool> PrintModuleBeforeOptimizations( static cl::opt<bool> AlwaysInlineDeviceFunctions( "openmp-opt-inline-device", - cl::desc("Inline all applicible functions on the device."), cl::Hidden, + cl::desc("Inline all applicable functions on the device."), cl::Hidden, cl::init(false)); static cl::opt<bool> diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp index 603beb3..b978c54 100644 --- a/llvm/lib/Transforms/IPO/SampleProfile.cpp +++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp @@ -162,7 +162,7 @@ static cl::opt<bool> ProfileSampleBlockAccurate( static cl::opt<bool> ProfileAccurateForSymsInList( "profile-accurate-for-symsinlist", cl::Hidden, cl::init(true), cl::desc("For symbols in profile symbol list, regard their profiles to " - "be accurate. It may be overriden by profile-sample-accurate. ")); + "be accurate. It may be overridden by profile-sample-accurate. ")); static cl::opt<bool> ProfileMergeInlinee( "sample-profile-merge-inlinee", cl::Hidden, cl::init(true), @@ -193,9 +193,10 @@ static cl::opt<bool> ProfileSizeInline( // and inline the hot functions (that are skipped in this pass). static cl::opt<bool> DisableSampleLoaderInlining( "disable-sample-loader-inlining", cl::Hidden, cl::init(false), - cl::desc("If true, artifically skip inline transformation in sample-loader " - "pass, and merge (or scale) profiles (as configured by " - "--sample-profile-merge-inlinee).")); + cl::desc( + "If true, artificially skip inline transformation in sample-loader " + "pass, and merge (or scale) profiles (as configured by " + "--sample-profile-merge-inlinee).")); namespace llvm { cl::opt<bool> @@ -255,7 +256,7 @@ static cl::opt<unsigned> PrecentMismatchForStalenessError( static cl::opt<bool> CallsitePrioritizedInline( "sample-profile-prioritized-inline", cl::Hidden, - cl::desc("Use call site prioritized inlining for sample profile loader." + cl::desc("Use call site prioritized inlining for sample profile loader. 
" "Currently only CSSPGO is supported.")); static cl::opt<bool> UsePreInlinerDecision( diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp index 7a184a1..73876d0 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -1326,6 +1326,18 @@ Instruction *InstCombinerImpl::foldAddLikeCommutative(Value *LHS, Value *RHS, R->setHasNoUnsignedWrap(NUWOut); return R; } + + // ((X s/ C1) << C2) + X => X s% -C1 where -C1 is 1 << C2 + const APInt *C1, *C2; + if (match(LHS, m_Shl(m_SDiv(m_Specific(RHS), m_APInt(C1)), m_APInt(C2)))) { + APInt One(C2->getBitWidth(), 1); + APInt MinusC1 = -(*C1); + if (MinusC1 == (One << *C2)) { + Constant *NewRHS = ConstantInt::get(RHS->getType(), MinusC1); + return BinaryOperator::CreateSRem(RHS, NewRHS); + } + } + return nullptr; } @@ -1623,17 +1635,7 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { // X % C0 + (( X / C0 ) % C1) * C0 => X % (C0 * C1) if (Value *V = SimplifyAddWithRemainder(I)) return replaceInstUsesWith(I, V); - // ((X s/ C1) << C2) + X => X s% -C1 where -C1 is 1 << C2 - const APInt *C1, *C2; - if (match(LHS, m_Shl(m_SDiv(m_Specific(RHS), m_APInt(C1)), m_APInt(C2)))) { - APInt one(C2->getBitWidth(), 1); - APInt minusC1 = -(*C1); - if (minusC1 == (one << *C2)) { - Constant *NewRHS = ConstantInt::get(RHS->getType(), minusC1); - return BinaryOperator::CreateSRem(RHS, NewRHS); - } - } - + const APInt *C1; // (A & 2^C1) + A => A & (2^C1 - 1) iff bit C1 in A is a sign bit if (match(&I, m_c_Add(m_And(m_Value(A), m_APInt(C1)), m_Deferred(A))) && C1->isPowerOf2() && (ComputeNumSignBits(A) > C1->countl_zero())) { @@ -2845,12 +2847,11 @@ Instruction *InstCombinerImpl::hoistFNegAboveFMulFDiv(Value *FNegOp, // Make sure to preserve flags and metadata on the call. if (II->getIntrinsicID() == Intrinsic::ldexp) { FastMathFlags FMF = FMFSource.getFastMathFlags() | II->getFastMathFlags(); - IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); - Builder.setFastMathFlags(FMF); - - CallInst *New = Builder.CreateCall( - II->getCalledFunction(), - {Builder.CreateFNeg(II->getArgOperand(0)), II->getArgOperand(1)}); + CallInst *New = + Builder.CreateCall(II->getCalledFunction(), + {Builder.CreateFNegFMF(II->getArgOperand(0), FMF), + II->getArgOperand(1)}); + New->setFastMathFlags(FMF); New->copyMetadata(*II); return New; } @@ -2932,12 +2933,8 @@ Instruction *InstCombinerImpl::visitFNeg(UnaryOperator &I) { // flags the copysign doesn't also have. FastMathFlags FMF = I.getFastMathFlags(); FMF &= cast<FPMathOperator>(OneUse)->getFastMathFlags(); - - IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); - Builder.setFastMathFlags(FMF); - - Value *NegY = Builder.CreateFNeg(Y); - Value *NewCopySign = Builder.CreateCopySign(X, NegY); + Value *NegY = Builder.CreateFNegFMF(Y, FMF); + Value *NewCopySign = Builder.CreateCopySign(X, NegY, FMF); return replaceInstUsesWith(I, NewCopySign); } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index e576eea4..f82a557 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -39,11 +39,11 @@ static Value *getNewICmpValue(unsigned Code, bool Sign, Value *LHS, Value *RHS, /// This is the complement of getFCmpCode, which turns an opcode and two /// operands into either a FCmp instruction, or a true/false constant. 
static Value *getFCmpValue(unsigned Code, Value *LHS, Value *RHS, - InstCombiner::BuilderTy &Builder) { + InstCombiner::BuilderTy &Builder, FMFSource FMF) { FCmpInst::Predicate NewPred; if (Constant *TorF = getPredForFCmpCode(Code, LHS->getType(), NewPred)) return TorF; - return Builder.CreateFCmp(NewPred, LHS, RHS); + return Builder.CreateFCmpFMF(NewPred, LHS, RHS, FMF); } /// Emit a computation of: (V >= Lo && V < Hi) if Inside is true, otherwise @@ -513,7 +513,8 @@ static Value *foldLogOpOfMaskedICmpsAsymmetric( /// into a single (icmp(A & X) ==/!= Y). static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, bool IsLogical, - InstCombiner::BuilderTy &Builder) { + InstCombiner::BuilderTy &Builder, + const SimplifyQuery &Q) { Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr; ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); std::optional<std::pair<unsigned, unsigned>> MaskPair = @@ -586,93 +587,107 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, return Builder.CreateICmp(NewCC, NewAnd2, A); } - // Remaining cases assume at least that B and D are constant, and depend on - // their actual values. This isn't strictly necessary, just a "handle the - // easy cases for now" decision. const APInt *ConstB, *ConstD; - if (!match(B, m_APInt(ConstB)) || !match(D, m_APInt(ConstD))) - return nullptr; - - if (Mask & (Mask_NotAllZeros | BMask_NotAllOnes)) { - // (icmp ne (A & B), 0) & (icmp ne (A & D), 0) and - // (icmp ne (A & B), B) & (icmp ne (A & D), D) - // -> (icmp ne (A & B), 0) or (icmp ne (A & D), 0) - // Only valid if one of the masks is a superset of the other (check "B&D" is - // the same as either B or D). - APInt NewMask = *ConstB & *ConstD; - if (NewMask == *ConstB) - return LHS; - else if (NewMask == *ConstD) - return RHS; - } - - if (Mask & AMask_NotAllOnes) { - // (icmp ne (A & B), B) & (icmp ne (A & D), D) - // -> (icmp ne (A & B), A) or (icmp ne (A & D), A) - // Only valid if one of the masks is a superset of the other (check "B|D" is - // the same as either B or D). - APInt NewMask = *ConstB | *ConstD; - if (NewMask == *ConstB) - return LHS; - else if (NewMask == *ConstD) - return RHS; - } - - if (Mask & (BMask_Mixed | BMask_NotMixed)) { - // Mixed: - // (icmp eq (A & B), C) & (icmp eq (A & D), E) - // We already know that B & C == C && D & E == E. - // If we can prove that (B & D) & (C ^ E) == 0, that is, the bits of - // C and E, which are shared by both the mask B and the mask D, don't - // contradict, then we can transform to - // -> (icmp eq (A & (B|D)), (C|E)) - // Currently, we only handle the case of B, C, D, and E being constant. - // We can't simply use C and E because we might actually handle - // (icmp ne (A & B), B) & (icmp eq (A & D), D) - // with B and D, having a single bit set. - - // NotMixed: - // (icmp ne (A & B), C) & (icmp ne (A & D), E) - // -> (icmp ne (A & (B & D)), (C & E)) - // Check the intersection (B & D) for inequality. - // Assume that (B & D) == B || (B & D) == D, i.e B/D is a subset of D/B - // and (B & D) & (C ^ E) == 0, bits of C and E, which are shared by both the - // B and the D, don't contradict. - // Note that we can assume (~B & C) == 0 && (~D & E) == 0, previous - // operation should delete these icmps if it hadn't been met. 
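// A small brute-force check (not from this patch) of one of the masked-compare
// simplifications handled in foldLogOpOfMaskedICmps here: when B's bits are a
// subset of D's, i.e. (B & D) == B, the conjunction
//   ((A & B) != 0) && ((A & D) != 0)
// collapses to its left-hand side.
#include <cassert>

int main() {
  for (unsigned A = 0; A < 256; ++A)
    for (unsigned B = 0; B < 256; ++B)
      for (unsigned D = 0; D < 256; ++D) {
        if ((B & D) != B)
          continue; // only the subset case folds this way
        bool Both = ((A & B) != 0) && ((A & D) != 0);
        assert(Both == ((A & B) != 0));
      }
  return 0;
}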
- - const APInt *OldConstC, *OldConstE; - if (!match(C, m_APInt(OldConstC)) || !match(E, m_APInt(OldConstE))) - return nullptr; - - auto FoldBMixed = [&](ICmpInst::Predicate CC, bool IsNot) -> Value * { - CC = IsNot ? CmpInst::getInversePredicate(CC) : CC; - const APInt ConstC = PredL != CC ? *ConstB ^ *OldConstC : *OldConstC; - const APInt ConstE = PredR != CC ? *ConstD ^ *OldConstE : *OldConstE; + if (match(B, m_APInt(ConstB)) && match(D, m_APInt(ConstD))) { + if (Mask & (Mask_NotAllZeros | BMask_NotAllOnes)) { + // (icmp ne (A & B), 0) & (icmp ne (A & D), 0) and + // (icmp ne (A & B), B) & (icmp ne (A & D), D) + // -> (icmp ne (A & B), 0) or (icmp ne (A & D), 0) + // Only valid if one of the masks is a superset of the other (check "B&D" + // is the same as either B or D). + APInt NewMask = *ConstB & *ConstD; + if (NewMask == *ConstB) + return LHS; + if (NewMask == *ConstD) + return RHS; + } - if (((*ConstB & *ConstD) & (ConstC ^ ConstE)).getBoolValue()) - return IsNot ? nullptr : ConstantInt::get(LHS->getType(), !IsAnd); + if (Mask & AMask_NotAllOnes) { + // (icmp ne (A & B), B) & (icmp ne (A & D), D) + // -> (icmp ne (A & B), A) or (icmp ne (A & D), A) + // Only valid if one of the masks is a superset of the other (check "B|D" + // is the same as either B or D). + APInt NewMask = *ConstB | *ConstD; + if (NewMask == *ConstB) + return LHS; + if (NewMask == *ConstD) + return RHS; + } - if (IsNot && !ConstB->isSubsetOf(*ConstD) && !ConstD->isSubsetOf(*ConstB)) + if (Mask & (BMask_Mixed | BMask_NotMixed)) { + // Mixed: + // (icmp eq (A & B), C) & (icmp eq (A & D), E) + // We already know that B & C == C && D & E == E. + // If we can prove that (B & D) & (C ^ E) == 0, that is, the bits of + // C and E, which are shared by both the mask B and the mask D, don't + // contradict, then we can transform to + // -> (icmp eq (A & (B|D)), (C|E)) + // Currently, we only handle the case of B, C, D, and E being constant. + // We can't simply use C and E because we might actually handle + // (icmp ne (A & B), B) & (icmp eq (A & D), D) + // with B and D, having a single bit set. + + // NotMixed: + // (icmp ne (A & B), C) & (icmp ne (A & D), E) + // -> (icmp ne (A & (B & D)), (C & E)) + // Check the intersection (B & D) for inequality. + // Assume that (B & D) == B || (B & D) == D, i.e B/D is a subset of D/B + // and (B & D) & (C ^ E) == 0, bits of C and E, which are shared by both + // the B and the D, don't contradict. Note that we can assume (~B & C) == + // 0 && (~D & E) == 0, previous operation should delete these icmps if it + // hadn't been met. + + const APInt *OldConstC, *OldConstE; + if (!match(C, m_APInt(OldConstC)) || !match(E, m_APInt(OldConstE))) return nullptr; - APInt BD, CE; - if (IsNot) { - BD = *ConstB & *ConstD; - CE = ConstC & ConstE; - } else { - BD = *ConstB | *ConstD; - CE = ConstC | ConstE; - } - Value *NewAnd = Builder.CreateAnd(A, BD); - Value *CEVal = ConstantInt::get(A->getType(), CE); - return Builder.CreateICmp(CC, CEVal, NewAnd); - }; + auto FoldBMixed = [&](ICmpInst::Predicate CC, bool IsNot) -> Value * { + CC = IsNot ? CmpInst::getInversePredicate(CC) : CC; + const APInt ConstC = PredL != CC ? *ConstB ^ *OldConstC : *OldConstC; + const APInt ConstE = PredR != CC ? *ConstD ^ *OldConstE : *OldConstE; + + if (((*ConstB & *ConstD) & (ConstC ^ ConstE)).getBoolValue()) + return IsNot ? 
nullptr : ConstantInt::get(LHS->getType(), !IsAnd); + + if (IsNot && !ConstB->isSubsetOf(*ConstD) && + !ConstD->isSubsetOf(*ConstB)) + return nullptr; + + APInt BD, CE; + if (IsNot) { + BD = *ConstB & *ConstD; + CE = ConstC & ConstE; + } else { + BD = *ConstB | *ConstD; + CE = ConstC | ConstE; + } + Value *NewAnd = Builder.CreateAnd(A, BD); + Value *CEVal = ConstantInt::get(A->getType(), CE); + return Builder.CreateICmp(CC, CEVal, NewAnd); + }; + + if (Mask & BMask_Mixed) + return FoldBMixed(NewCC, false); + if (Mask & BMask_NotMixed) // can be else also + return FoldBMixed(NewCC, true); + } + } - if (Mask & BMask_Mixed) - return FoldBMixed(NewCC, false); - if (Mask & BMask_NotMixed) // can be else also - return FoldBMixed(NewCC, true); + // (icmp eq (A & B), 0) | (icmp eq (A & D), 0) + // -> (icmp ne (A & (B|D)), (B|D)) + // (icmp ne (A & B), 0) & (icmp ne (A & D), 0) + // -> (icmp eq (A & (B|D)), (B|D)) + // iff B and D is known to be a power of two + if (Mask & Mask_NotAllZeros && + isKnownToBeAPowerOfTwo(B, /*OrZero=*/false, /*Depth=*/0, Q) && + isKnownToBeAPowerOfTwo(D, /*OrZero=*/false, /*Depth=*/0, Q)) { + // If this is a logical and/or, then we must prevent propagation of a + // poison value from the RHS by inserting freeze. + if (IsLogical) + D = Builder.CreateFreeze(D); + Value *Mask = Builder.CreateOr(B, D); + Value *Masked = Builder.CreateAnd(A, Mask); + return Builder.CreateICmp(NewCC, Masked, Mask); } return nullptr; } @@ -775,46 +790,6 @@ foldAndOrOfICmpsWithPow2AndWithZero(InstCombiner::BuilderTy &Builder, return Builder.CreateICmp(Pred, And, Op); } -// Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2) -// Fold (!iszero(A & K1) & !iszero(A & K2)) -> (A & (K1 | K2)) == (K1 | K2) -Value *InstCombinerImpl::foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, - ICmpInst *RHS, - Instruction *CxtI, - bool IsAnd, - bool IsLogical) { - CmpInst::Predicate Pred = IsAnd ? CmpInst::ICMP_NE : CmpInst::ICMP_EQ; - if (LHS->getPredicate() != Pred || RHS->getPredicate() != Pred) - return nullptr; - - if (!match(LHS->getOperand(1), m_Zero()) || - !match(RHS->getOperand(1), m_Zero())) - return nullptr; - - Value *L1, *L2, *R1, *R2; - if (match(LHS->getOperand(0), m_And(m_Value(L1), m_Value(L2))) && - match(RHS->getOperand(0), m_And(m_Value(R1), m_Value(R2)))) { - if (L1 == R2 || L2 == R2) - std::swap(R1, R2); - if (L2 == R1) - std::swap(L1, L2); - - if (L1 == R1 && - isKnownToBeAPowerOfTwo(L2, false, 0, CxtI) && - isKnownToBeAPowerOfTwo(R2, false, 0, CxtI)) { - // If this is a logical and/or, then we must prevent propagation of a - // poison value from the RHS by inserting freeze. - if (IsLogical) - R2 = Builder.CreateFreeze(R2); - Value *Mask = Builder.CreateOr(L2, R2); - Value *Masked = Builder.CreateAnd(L1, Mask); - auto NewPred = IsAnd ? 
CmpInst::ICMP_EQ : CmpInst::ICMP_NE; - return Builder.CreateICmp(NewPred, Masked, Mask); - } - } - - return nullptr; -} - /// General pattern: /// X & Y /// @@ -1429,12 +1404,8 @@ static Value *matchIsFiniteTest(InstCombiner::BuilderTy &Builder, FCmpInst *LHS, !matchUnorderedInfCompare(PredR, RHS0, RHS1)) return nullptr; - IRBuilder<>::FastMathFlagGuard FMFG(Builder); - FastMathFlags FMF = LHS->getFastMathFlags(); - FMF &= RHS->getFastMathFlags(); - Builder.setFastMathFlags(FMF); - - return Builder.CreateFCmp(FCmpInst::getOrderedPredicate(PredR), RHS0, RHS1); + return Builder.CreateFCmpFMF(FCmpInst::getOrderedPredicate(PredR), RHS0, RHS1, + FMFSource::intersect(LHS, RHS)); } Value *InstCombinerImpl::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, @@ -1470,12 +1441,8 @@ Value *InstCombinerImpl::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, // Intersect the fast math flags. // TODO: We can union the fast math flags unless this is a logical select. - IRBuilder<>::FastMathFlagGuard FMFG(Builder); - FastMathFlags FMF = LHS->getFastMathFlags(); - FMF &= RHS->getFastMathFlags(); - Builder.setFastMathFlags(FMF); - - return getFCmpValue(NewPred, LHS0, LHS1, Builder); + return getFCmpValue(NewPred, LHS0, LHS1, Builder, + FMFSource::intersect(LHS, RHS)); } // This transform is not valid for a logical select. @@ -1492,10 +1459,8 @@ Value *InstCombinerImpl::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, // Ignore the constants because they are obviously not NANs: // (fcmp ord x, 0.0) & (fcmp ord y, 0.0) -> (fcmp ord x, y) // (fcmp uno x, 0.0) | (fcmp uno y, 0.0) -> (fcmp uno x, y) - IRBuilder<>::FastMathFlagGuard FMFG(Builder); - Builder.setFastMathFlags(LHS->getFastMathFlags() & - RHS->getFastMathFlags()); - return Builder.CreateFCmp(PredL, LHS0, RHS0); + return Builder.CreateFCmpFMF(PredL, LHS0, RHS0, + FMFSource::intersect(LHS, RHS)); } } @@ -1557,15 +1522,14 @@ Value *InstCombinerImpl::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, std::swap(PredL, PredR); } if (IsLessThanOrLessEqual(IsAnd ? PredL : PredR)) { - BuilderTy::FastMathFlagGuard Guard(Builder); FastMathFlags NewFlag = LHS->getFastMathFlags(); if (!IsLogicalSelect) NewFlag |= RHS->getFastMathFlags(); - Builder.setFastMathFlags(NewFlag); - Value *FAbs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, LHS0); - return Builder.CreateFCmp(PredL, FAbs, - ConstantFP::get(LHS0->getType(), *LHSC)); + Value *FAbs = + Builder.CreateUnaryIntrinsic(Intrinsic::fabs, LHS0, NewFlag); + return Builder.CreateFCmpFMF( + PredL, FAbs, ConstantFP::get(LHS0->getType(), *LHSC), NewFlag); } } @@ -2372,6 +2336,26 @@ static Value *simplifyAndOrWithOpReplaced(Value *V, Value *Op, Value *RepOp, return IC.Builder.CreateBinOp(I->getOpcode(), NewOp0, NewOp1); } +/// Reassociate and/or expressions to see if we can fold the inner and/or ops. +/// TODO: Make this recursive; it's a little tricky because an arbitrary +/// number of and/or instructions might have to be created. +Value *InstCombinerImpl::reassociateBooleanAndOr(Value *LHS, Value *X, Value *Y, + Instruction &I, bool IsAnd, + bool RHSIsLogical) { + Instruction::BinaryOps Opcode = IsAnd ? Instruction::And : Instruction::Or; + // LHS bop (X lop Y) --> (LHS bop X) lop Y + // LHS bop (X bop Y) --> (LHS bop X) bop Y + if (Value *Res = foldBooleanAndOr(LHS, X, I, IsAnd, /*IsLogical=*/false)) + return RHSIsLogical ? 
Builder.CreateLogicalOp(Opcode, Res, Y) + : Builder.CreateBinOp(Opcode, Res, Y); + // LHS bop (X bop Y) --> X bop (LHS bop Y) + // LHS bop (X lop Y) --> X lop (LHS bop Y) + if (Value *Res = foldBooleanAndOr(LHS, Y, I, IsAnd, /*IsLogical=*/false)) + return RHSIsLogical ? Builder.CreateLogicalOp(Opcode, X, Res) + : Builder.CreateBinOp(Opcode, X, Res); + return nullptr; +} + // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches // here. We should standardize that construct where it is needed or choose some // other way to ensure that commutated variants of patterns are not missed. @@ -2755,31 +2739,17 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { foldBooleanAndOr(Op0, Op1, I, /*IsAnd=*/true, /*IsLogical=*/false)) return replaceInstUsesWith(I, Res); - // TODO: Make this recursive; it's a little tricky because an arbitrary - // number of 'and' instructions might have to be created. if (match(Op1, m_OneUse(m_LogicalAnd(m_Value(X), m_Value(Y))))) { bool IsLogical = isa<SelectInst>(Op1); - // Op0 & (X && Y) --> (Op0 && X) && Y - if (Value *Res = foldBooleanAndOr(Op0, X, I, /* IsAnd */ true, IsLogical)) - return replaceInstUsesWith(I, IsLogical ? Builder.CreateLogicalAnd(Res, Y) - : Builder.CreateAnd(Res, Y)); - // Op0 & (X && Y) --> X && (Op0 & Y) - if (Value *Res = foldBooleanAndOr(Op0, Y, I, /* IsAnd */ true, - /* IsLogical */ false)) - return replaceInstUsesWith(I, IsLogical ? Builder.CreateLogicalAnd(X, Res) - : Builder.CreateAnd(X, Res)); + if (auto *V = reassociateBooleanAndOr(Op0, X, Y, I, /*IsAnd=*/true, + /*RHSIsLogical=*/IsLogical)) + return replaceInstUsesWith(I, V); } if (match(Op0, m_OneUse(m_LogicalAnd(m_Value(X), m_Value(Y))))) { bool IsLogical = isa<SelectInst>(Op0); - // (X && Y) & Op1 --> (X && Op1) && Y - if (Value *Res = foldBooleanAndOr(X, Op1, I, /* IsAnd */ true, IsLogical)) - return replaceInstUsesWith(I, IsLogical ? Builder.CreateLogicalAnd(Res, Y) - : Builder.CreateAnd(Res, Y)); - // (X && Y) & Op1 --> X && (Y & Op1) - if (Value *Res = foldBooleanAndOr(Y, Op1, I, /* IsAnd */ true, - /* IsLogical */ false)) - return replaceInstUsesWith(I, IsLogical ? Builder.CreateLogicalAnd(X, Res) - : Builder.CreateAnd(X, Res)); + if (auto *V = reassociateBooleanAndOr(Op1, X, Y, I, /*IsAnd=*/true, + /*RHSIsLogical=*/IsLogical)) + return replaceInstUsesWith(I, V); } if (Instruction *FoldedFCmps = reassociateFCmps(I, Builder)) @@ -3330,12 +3300,6 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsLogical) { const SimplifyQuery Q = SQ.getWithInstruction(&I); - // Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2) - // Fold (!iszero(A & K1) & !iszero(A & K2)) -> (A & (K1 | K2)) == (K1 | K2) - // if K1 and K2 are a one-bit mask. 
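// A standalone check (not from this patch) of the one-bit-mask fold whose
// dedicated helper is deleted above and is now covered by the generalized
// power-of-two case added to foldLogOpOfMaskedICmps:
//   ((A & K1) == 0) || ((A & K2) == 0)  <=>  (A & (K1 | K2)) != (K1 | K2)
// when K1 and K2 each have exactly one bit set.
#include <cassert>

int main() {
  for (unsigned A = 0; A < 256; ++A)
    for (unsigned I = 0; I < 8; ++I)
      for (unsigned J = 0; J < 8; ++J) {
        unsigned K1 = 1u << I, K2 = 1u << J;
        bool AnyMaskZero = ((A & K1) == 0) || ((A & K2) == 0);
        bool Folded = (A & (K1 | K2)) != (K1 | K2);
        assert(AnyMaskZero == Folded);
      }
  return 0;
}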
- if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, &I, IsAnd, IsLogical)) - return V; - ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0); Value *LHS1 = LHS->getOperand(1), *RHS1 = RHS->getOperand(1); @@ -3362,7 +3326,7 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, // handle (roughly): // (icmp ne (A & B), C) | (icmp ne (A & D), E) // (icmp eq (A & B), C) & (icmp eq (A & D), E) - if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, IsAnd, IsLogical, Builder)) + if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, IsAnd, IsLogical, Builder, Q)) return V; if (Value *V = @@ -3840,31 +3804,17 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { foldBooleanAndOr(Op0, Op1, I, /*IsAnd=*/false, /*IsLogical=*/false)) return replaceInstUsesWith(I, Res); - // TODO: Make this recursive; it's a little tricky because an arbitrary - // number of 'or' instructions might have to be created. if (match(Op1, m_OneUse(m_LogicalOr(m_Value(X), m_Value(Y))))) { bool IsLogical = isa<SelectInst>(Op1); - // Op0 | (X || Y) --> (Op0 || X) || Y - if (Value *Res = foldBooleanAndOr(Op0, X, I, /* IsAnd */ false, IsLogical)) - return replaceInstUsesWith(I, IsLogical ? Builder.CreateLogicalOr(Res, Y) - : Builder.CreateOr(Res, Y)); - // Op0 | (X || Y) --> X || (Op0 | Y) - if (Value *Res = foldBooleanAndOr(Op0, Y, I, /* IsAnd */ false, - /* IsLogical */ false)) - return replaceInstUsesWith(I, IsLogical ? Builder.CreateLogicalOr(X, Res) - : Builder.CreateOr(X, Res)); + if (auto *V = reassociateBooleanAndOr(Op0, X, Y, I, /*IsAnd=*/false, + /*RHSIsLogical=*/IsLogical)) + return replaceInstUsesWith(I, V); } if (match(Op0, m_OneUse(m_LogicalOr(m_Value(X), m_Value(Y))))) { bool IsLogical = isa<SelectInst>(Op0); - // (X || Y) | Op1 --> (X || Op1) || Y - if (Value *Res = foldBooleanAndOr(X, Op1, I, /* IsAnd */ false, IsLogical)) - return replaceInstUsesWith(I, IsLogical ? Builder.CreateLogicalOr(Res, Y) - : Builder.CreateOr(Res, Y)); - // (X || Y) | Op1 --> X || (Y | Op1) - if (Value *Res = foldBooleanAndOr(Y, Op1, I, /* IsAnd */ false, - /* IsLogical */ false)) - return replaceInstUsesWith(I, IsLogical ? Builder.CreateLogicalOr(X, Res) - : Builder.CreateOr(X, Res)); + if (auto *V = reassociateBooleanAndOr(Op1, X, Y, I, /*IsAnd=*/false, + /*RHSIsLogical=*/IsLogical)) + return replaceInstUsesWith(I, V); } if (Instruction *FoldedFCmps = reassociateFCmps(I, Builder)) @@ -4981,8 +4931,8 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { // (A & B) ^ (A | C) --> A ? ~B : C -- There are 4 commuted variants. 
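// A quick truth-table check (not from this patch) of the i1 identity behind the
// commuted-variant matching just below:
//   (A & B) ^ (A | C)  ==  A ? !B : C   for boolean values.
#include <cassert>

int main() {
  for (int A = 0; A <= 1; ++A)
    for (int B = 0; B <= 1; ++B)
      for (int C = 0; C <= 1; ++C) {
        bool Xor = ((A & B) != 0) != ((A | C) != 0); // xor of the two i1 values
        bool Sel = A ? (B == 0) : (C != 0);          // A ? !B : C
        assert(Xor == Sel);
      }
  return 0;
}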
if (I.getType()->isIntOrIntVectorTy(1) && - match(Op0, m_OneUse(m_LogicalAnd(m_Value(A), m_Value(B)))) && - match(Op1, m_OneUse(m_LogicalOr(m_Value(C), m_Value(D))))) { + match(&I, m_c_Xor(m_OneUse(m_LogicalAnd(m_Value(A), m_Value(B))), + m_OneUse(m_LogicalOr(m_Value(C), m_Value(D)))))) { bool NeedFreeze = isa<SelectInst>(Op0) && isa<SelectInst>(Op1) && B == D; if (B == C || B == D) std::swap(A, B); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index fd38738..c55c40c 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -839,6 +839,35 @@ InstCombinerImpl::foldIntrinsicWithOverflowCommon(IntrinsicInst *II) { if (OptimizeOverflowCheck(WO->getBinaryOp(), WO->isSigned(), WO->getLHS(), WO->getRHS(), *WO, OperationResult, OverflowResult)) return createOverflowTuple(WO, OperationResult, OverflowResult); + + // See whether we can optimize the overflow check with assumption information. + for (User *U : WO->users()) { + if (!match(U, m_ExtractValue<1>(m_Value()))) + continue; + + for (auto &AssumeVH : AC.assumptionsFor(U)) { + if (!AssumeVH) + continue; + CallInst *I = cast<CallInst>(AssumeVH); + if (!match(I->getArgOperand(0), m_Not(m_Specific(U)))) + continue; + if (!isValidAssumeForContext(I, II, /*DT=*/nullptr, + /*AllowEphemerals=*/true)) + continue; + Value *Result = + Builder.CreateBinOp(WO->getBinaryOp(), WO->getLHS(), WO->getRHS()); + Result->takeName(WO); + if (auto *Inst = dyn_cast<Instruction>(Result)) { + if (WO->isSigned()) + Inst->setHasNoSignedWrap(); + else + Inst->setHasNoUnsignedWrap(); + } + return createOverflowTuple(WO, Result, + ConstantInt::getFalse(U->getType())); + } + } + return nullptr; } @@ -2644,8 +2673,11 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { // Propagate sign argument through nested calls: // copysign Mag, (copysign ?, X) --> copysign Mag, X Value *X; - if (match(Sign, m_Intrinsic<Intrinsic::copysign>(m_Value(), m_Value(X)))) - return replaceOperand(*II, 1, X); + if (match(Sign, m_Intrinsic<Intrinsic::copysign>(m_Value(), m_Value(X)))) { + Value *CopySign = + Builder.CreateCopySign(Mag, X, FMFSource::intersect(II, Sign)); + return replaceInstUsesWith(*II, CopySign); + } // Clear sign-bit of constant magnitude: // copysign -MagC, X --> copysign MagC, X diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index 0b93799..4ec1af3 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -1852,15 +1852,13 @@ Instruction *InstCombinerImpl::visitFPTrunc(FPTruncInst &FPT) { Value *X; Instruction *Op = dyn_cast<Instruction>(FPT.getOperand(0)); if (Op && Op->hasOneUse()) { - IRBuilder<>::FastMathFlagGuard FMFG(Builder); FastMathFlags FMF = FPT.getFastMathFlags(); if (auto *FPMO = dyn_cast<FPMathOperator>(Op)) FMF &= FPMO->getFastMathFlags(); - Builder.setFastMathFlags(FMF); if (match(Op, m_FNeg(m_Value(X)))) { - Value *InnerTrunc = Builder.CreateFPTrunc(X, Ty); - Value *Neg = Builder.CreateFNeg(InnerTrunc); + Value *InnerTrunc = Builder.CreateFPTruncFMF(X, Ty, FMF); + Value *Neg = Builder.CreateFNegFMF(InnerTrunc, FMF); return replaceInstUsesWith(FPT, Neg); } @@ -1870,15 +1868,17 @@ Instruction *InstCombinerImpl::visitFPTrunc(FPTruncInst &FPT) { if (match(Op, m_Select(m_Value(Cond), m_FPExt(m_Value(X)), m_Value(Y))) && X->getType() == Ty) { // fptrunc (select Cond, 
(fpext X), Y --> select Cond, X, (fptrunc Y) - Value *NarrowY = Builder.CreateFPTrunc(Y, Ty); - Value *Sel = Builder.CreateSelect(Cond, X, NarrowY, "narrow.sel", Op); + Value *NarrowY = Builder.CreateFPTruncFMF(Y, Ty, FMF); + Value *Sel = + Builder.CreateSelectFMF(Cond, X, NarrowY, FMF, "narrow.sel", Op); return replaceInstUsesWith(FPT, Sel); } if (match(Op, m_Select(m_Value(Cond), m_Value(Y), m_FPExt(m_Value(X)))) && X->getType() == Ty) { // fptrunc (select Cond, Y, (fpext X) --> select Cond, (fptrunc Y), X - Value *NarrowY = Builder.CreateFPTrunc(Y, Ty); - Value *Sel = Builder.CreateSelect(Cond, NarrowY, X, "narrow.sel", Op); + Value *NarrowY = Builder.CreateFPTruncFMF(Y, Ty, FMF); + Value *Sel = + Builder.CreateSelectFMF(Cond, NarrowY, X, FMF, "narrow.sel", Op); return replaceInstUsesWith(FPT, Sel); } } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index d6fdade..2e45725 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -747,6 +747,8 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, ConstantExpr::getPointerBitCastOrAddrSpaceCast( cast<Constant>(RHS), Base->getType())); } else if (GEPOperator *GEPRHS = dyn_cast<GEPOperator>(RHS)) { + GEPNoWrapFlags NW = GEPLHS->getNoWrapFlags() & GEPRHS->getNoWrapFlags(); + // If the base pointers are different, but the indices are the same, just // compare the base pointer. if (PtrBase != GEPRHS->getOperand(0)) { @@ -764,7 +766,8 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, // If all indices are the same, just compare the base pointers. Type *BaseType = GEPLHS->getOperand(0)->getType(); - if (IndicesTheSame && CmpInst::makeCmpResultType(BaseType) == I.getType()) + if (IndicesTheSame && + CmpInst::makeCmpResultType(BaseType) == I.getType() && CanFold(NW)) return new ICmpInst(Cond, GEPLHS->getOperand(0), GEPRHS->getOperand(0)); // If we're comparing GEPs with two base pointers that only differ in type @@ -804,7 +807,6 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, return transformToIndexedCompare(GEPLHS, RHS, Cond, DL, *this); } - GEPNoWrapFlags NW = GEPLHS->getNoWrapFlags() & GEPRHS->getNoWrapFlags(); if (GEPLHS->getNumOperands() == GEPRHS->getNumOperands() && GEPLHS->getSourceElementType() == GEPRHS->getSourceElementType()) { // If the GEPs only differ by one index, compare it. @@ -2483,9 +2485,8 @@ Instruction *InstCombinerImpl::foldICmpShlConstant(ICmpInst &Cmp, // icmp ule i64 (shl X, 32), 8589934592 -> // icmp ule i32 (trunc X, i32), 2 -> // icmp ult i32 (trunc X, i32), 3 - if (auto FlippedStrictness = - InstCombiner::getFlippedStrictnessPredicateAndConstant( - Pred, ConstantInt::get(ShType->getContext(), C))) { + if (auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant( + Pred, ConstantInt::get(ShType->getContext(), C))) { CmpPred = FlippedStrictness->first; RHSC = cast<ConstantInt>(FlippedStrictness->second)->getValue(); } @@ -3089,12 +3090,12 @@ Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp, unsigned BW = C.getBitWidth(); std::bitset<4> Table; auto ComputeTable = [&](bool Op0Val, bool Op1Val) { - int Res = 0; + APInt Res(BW, 0); if (Op0Val) - Res += isa<ZExtInst>(Ext0) ? 1 : -1; + Res += APInt(BW, isa<ZExtInst>(Ext0) ? 1 : -1, /*isSigned=*/true); if (Op1Val) - Res += isa<ZExtInst>(Ext1) ? 
1 : -1; - return ICmpInst::compare(APInt(BW, Res, true), C, Pred); + Res += APInt(BW, isa<ZExtInst>(Ext1) ? 1 : -1, /*isSigned=*/true); + return ICmpInst::compare(Res, C, Pred); }; Table[0] = ComputeTable(false, false); @@ -3278,8 +3279,7 @@ bool InstCombinerImpl::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS, if (PredB == ICmpInst::ICMP_SGT && isa<Constant>(RHS2)) { // x sgt C-1 <--> x sge C <--> not(x slt C) auto FlippedStrictness = - InstCombiner::getFlippedStrictnessPredicateAndConstant( - PredB, cast<Constant>(RHS2)); + getFlippedStrictnessPredicateAndConstant(PredB, cast<Constant>(RHS2)); if (!FlippedStrictness) return false; assert(FlippedStrictness->first == ICmpInst::ICMP_SGE && @@ -6906,79 +6906,6 @@ Instruction *InstCombinerImpl::foldICmpUsingBoolRange(ICmpInst &I) { return nullptr; } -std::optional<std::pair<CmpPredicate, Constant *>> -InstCombiner::getFlippedStrictnessPredicateAndConstant(CmpPredicate Pred, - Constant *C) { - assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) && - "Only for relational integer predicates."); - - Type *Type = C->getType(); - bool IsSigned = ICmpInst::isSigned(Pred); - - CmpInst::Predicate UnsignedPred = ICmpInst::getUnsignedPredicate(Pred); - bool WillIncrement = - UnsignedPred == ICmpInst::ICMP_ULE || UnsignedPred == ICmpInst::ICMP_UGT; - - // Check if the constant operand can be safely incremented/decremented - // without overflowing/underflowing. - auto ConstantIsOk = [WillIncrement, IsSigned](ConstantInt *C) { - return WillIncrement ? !C->isMaxValue(IsSigned) : !C->isMinValue(IsSigned); - }; - - Constant *SafeReplacementConstant = nullptr; - if (auto *CI = dyn_cast<ConstantInt>(C)) { - // Bail out if the constant can't be safely incremented/decremented. - if (!ConstantIsOk(CI)) - return std::nullopt; - } else if (auto *FVTy = dyn_cast<FixedVectorType>(Type)) { - unsigned NumElts = FVTy->getNumElements(); - for (unsigned i = 0; i != NumElts; ++i) { - Constant *Elt = C->getAggregateElement(i); - if (!Elt) - return std::nullopt; - - if (isa<UndefValue>(Elt)) - continue; - - // Bail out if we can't determine if this constant is min/max or if we - // know that this constant is min/max. - auto *CI = dyn_cast<ConstantInt>(Elt); - if (!CI || !ConstantIsOk(CI)) - return std::nullopt; - - if (!SafeReplacementConstant) - SafeReplacementConstant = CI; - } - } else if (isa<VectorType>(C->getType())) { - // Handle scalable splat - Value *SplatC = C->getSplatValue(); - auto *CI = dyn_cast_or_null<ConstantInt>(SplatC); - // Bail out if the constant can't be safely incremented/decremented. - if (!CI || !ConstantIsOk(CI)) - return std::nullopt; - } else { - // ConstantExpr? - return std::nullopt; - } - - // It may not be safe to change a compare predicate in the presence of - // undefined elements, so replace those elements with the first safe constant - // that we found. - // TODO: in case of poison, it is safe; let's replace undefs only. - if (C->containsUndefOrPoisonElement()) { - assert(SafeReplacementConstant && "Replacement constant not set"); - C = Constant::replaceUndefsWith(C, SafeReplacementConstant); - } - - CmpInst::Predicate NewPred = CmpInst::getFlippedStrictnessPredicate(Pred); - - // Increment or decrement the constant. - Constant *OneOrNegOne = ConstantInt::get(Type, WillIncrement ? 
1 : -1, true); - Constant *NewC = ConstantExpr::getAdd(C, OneOrNegOne); - - return std::make_pair(NewPred, NewC); -} - /// If we have an icmp le or icmp ge instruction with a constant operand, turn /// it into the appropriate icmp lt or icmp gt instruction. This transform /// allows them to be folded in visitICmpInst. @@ -6994,8 +6921,7 @@ static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) { if (!Op1C) return nullptr; - auto FlippedStrictness = - InstCombiner::getFlippedStrictnessPredicateAndConstant(Pred, Op1C); + auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(Pred, Op1C); if (!FlippedStrictness) return nullptr; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index 3a074ee..f6992119 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -429,12 +429,12 @@ private: Value *foldBooleanAndOr(Value *LHS, Value *RHS, Instruction &I, bool IsAnd, bool IsLogical); + Value *reassociateBooleanAndOr(Value *LHS, Value *X, Value *Y, Instruction &I, + bool IsAnd, bool RHSIsLogical); + Instruction * canonicalizeConditionalNegationViaMathToSelect(BinaryOperator &i); - Value *foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS, - Instruction *CxtI, bool IsAnd, - bool IsLogical = false); Value *matchSelectFromAndOr(Value *A, Value *B, Value *C, Value *D, bool InvertFalseVal = false); Value *getSelectCondition(Value *A, Value *B, bool ABIsTheSame); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index f85a3c9..0c34cf0 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -121,21 +121,17 @@ static Value *foldMulSelectToNegate(BinaryOperator &I, // fmul OtherOp, (select Cond, 1.0, -1.0) --> select Cond, OtherOp, -OtherOp if (match(&I, m_c_FMul(m_OneUse(m_Select(m_Value(Cond), m_SpecificFP(1.0), m_SpecificFP(-1.0))), - m_Value(OtherOp)))) { - IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); - Builder.setFastMathFlags(I.getFastMathFlags()); - return Builder.CreateSelect(Cond, OtherOp, Builder.CreateFNeg(OtherOp)); - } + m_Value(OtherOp)))) + return Builder.CreateSelectFMF(Cond, OtherOp, + Builder.CreateFNegFMF(OtherOp, &I), &I); // fmul (select Cond, -1.0, 1.0), OtherOp --> select Cond, -OtherOp, OtherOp // fmul OtherOp, (select Cond, -1.0, 1.0) --> select Cond, -OtherOp, OtherOp if (match(&I, m_c_FMul(m_OneUse(m_Select(m_Value(Cond), m_SpecificFP(-1.0), m_SpecificFP(1.0))), - m_Value(OtherOp)))) { - IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); - Builder.setFastMathFlags(I.getFastMathFlags()); - return Builder.CreateSelect(Cond, Builder.CreateFNeg(OtherOp), OtherOp); - } + m_Value(OtherOp)))) + return Builder.CreateSelectFMF(Cond, Builder.CreateFNegFMF(OtherOp, &I), + OtherOp, &I); return nullptr; } @@ -590,11 +586,9 @@ Instruction *InstCombinerImpl::foldFPSignBitOps(BinaryOperator &I) { // fabs(X) / fabs(Y) --> fabs(X / Y) if (match(Op0, m_FAbs(m_Value(X))) && match(Op1, m_FAbs(m_Value(Y))) && (Op0->hasOneUse() || Op1->hasOneUse())) { - IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); - Builder.setFastMathFlags(I.getFastMathFlags()); - Value *XY = Builder.CreateBinOp(Opcode, X, Y); - Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, XY); - Fabs->takeName(&I); + Value *XY = Builder.CreateBinOpFMF(Opcode, X, Y, &I); + Value *Fabs = + 
Builder.CreateUnaryIntrinsic(Intrinsic::fabs, XY, &I, I.getName()); return replaceInstUsesWith(I, Fabs); } @@ -685,8 +679,6 @@ Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) { match(Op0, m_AllowReassoc(m_BinOp(Op0BinOp)))) { // Everything in this scope folds I with Op0, intersecting their FMF. FastMathFlags FMF = I.getFastMathFlags() & Op0BinOp->getFastMathFlags(); - IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); - Builder.setFastMathFlags(FMF); Constant *C1; if (match(Op0, m_OneUse(m_FDiv(m_Constant(C1), m_Value(X))))) { // (C1 / X) * C --> (C * C1) / X @@ -718,7 +710,7 @@ Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) { // (X + C1) * C --> (X * C) + (C * C1) if (Constant *CC1 = ConstantFoldBinaryOpOperands(Instruction::FMul, C, C1, DL)) { - Value *XC = Builder.CreateFMul(X, C); + Value *XC = Builder.CreateFMulFMF(X, C, FMF); return BinaryOperator::CreateFAddFMF(XC, CC1, FMF); } } @@ -726,7 +718,7 @@ Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) { // (C1 - X) * C --> (C * C1) - (X * C) if (Constant *CC1 = ConstantFoldBinaryOpOperands(Instruction::FMul, C, C1, DL)) { - Value *XC = Builder.CreateFMul(X, C); + Value *XC = Builder.CreateFMulFMF(X, C, FMF); return BinaryOperator::CreateFSubFMF(CC1, XC, FMF); } } @@ -740,9 +732,7 @@ Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) { FastMathFlags FMF = I.getFastMathFlags() & DivOp->getFastMathFlags(); if (FMF.allowReassoc()) { // Sink division: (X / Y) * Z --> (X * Z) / Y - IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); - Builder.setFastMathFlags(FMF); - auto *NewFMul = Builder.CreateFMul(X, Z); + auto *NewFMul = Builder.CreateFMulFMF(X, Z, FMF); return BinaryOperator::CreateFDivFMF(NewFMul, Y, FMF); } } @@ -2066,14 +2056,18 @@ static Instruction *simplifyIRemMulShl(BinaryOperator &I, bool ShiftByX = false; // If V is not nullptr, it will be matched using m_Specific. - auto MatchShiftOrMulXC = [](Value *Op, Value *&V, APInt &C) -> bool { + auto MatchShiftOrMulXC = [](Value *Op, Value *&V, APInt &C, + bool &PreserveNSW) -> bool { const APInt *Tmp = nullptr; if ((!V && match(Op, m_Mul(m_Value(V), m_APInt(Tmp)))) || (V && match(Op, m_Mul(m_Specific(V), m_APInt(Tmp))))) C = *Tmp; else if ((!V && match(Op, m_Shl(m_Value(V), m_APInt(Tmp)))) || - (V && match(Op, m_Shl(m_Specific(V), m_APInt(Tmp))))) + (V && match(Op, m_Shl(m_Specific(V), m_APInt(Tmp))))) { C = APInt(Tmp->getBitWidth(), 1) << *Tmp; + // We cannot preserve NSW when shifting by BW - 1. + PreserveNSW = Tmp->ult(Tmp->getBitWidth() - 1); + } if (Tmp != nullptr) return true; @@ -2095,7 +2089,9 @@ static Instruction *simplifyIRemMulShl(BinaryOperator &I, return false; }; - if (MatchShiftOrMulXC(Op0, X, Y) && MatchShiftOrMulXC(Op1, X, Z)) { + bool Op0PreserveNSW = true, Op1PreserveNSW = true; + if (MatchShiftOrMulXC(Op0, X, Y, Op0PreserveNSW) && + MatchShiftOrMulXC(Op1, X, Z, Op1PreserveNSW)) { // pass } else if (MatchShiftCX(Op0, Y, X) && MatchShiftCX(Op1, Z, X)) { ShiftByX = true; @@ -2108,7 +2104,7 @@ static Instruction *simplifyIRemMulShl(BinaryOperator &I, OverflowingBinaryOperator *BO0 = cast<OverflowingBinaryOperator>(Op0); // TODO: We may be able to deduce more about nsw/nuw of BO0/BO1 based on Y >= // Z or Z >= Y. - bool BO0HasNSW = BO0->hasNoSignedWrap(); + bool BO0HasNSW = Op0PreserveNSW && BO0->hasNoSignedWrap(); bool BO0HasNUW = BO0->hasNoUnsignedWrap(); bool BO0NoWrap = IsSRem ? 
BO0HasNSW : BO0HasNUW; @@ -2131,7 +2127,7 @@ static Instruction *simplifyIRemMulShl(BinaryOperator &I, }; OverflowingBinaryOperator *BO1 = cast<OverflowingBinaryOperator>(Op1); - bool BO1HasNSW = BO1->hasNoSignedWrap(); + bool BO1HasNSW = Op1PreserveNSW && BO1->hasNoSignedWrap(); bool BO1HasNUW = BO1->hasNoUnsignedWrap(); bool BO1NoWrap = IsSRem ? BO1HasNSW : BO1HasNUW; // (rem (mul X, Y), (mul nuw/nsw X, Z)) diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp index 1fcf1c5..80308bf 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -765,30 +765,14 @@ Instruction *InstCombinerImpl::foldPHIArgLoadIntoPHI(PHINode &PN) { NewPN->addIncoming(InVal, PN.getIncomingBlock(0)); LoadInst *NewLI = new LoadInst(FirstLI->getType(), NewPN, "", IsVolatile, LoadAlignment); - - unsigned KnownIDs[] = { - LLVMContext::MD_tbaa, - LLVMContext::MD_range, - LLVMContext::MD_invariant_load, - LLVMContext::MD_alias_scope, - LLVMContext::MD_noalias, - LLVMContext::MD_nonnull, - LLVMContext::MD_align, - LLVMContext::MD_dereferenceable, - LLVMContext::MD_dereferenceable_or_null, - LLVMContext::MD_access_group, - LLVMContext::MD_noundef, - }; - - for (unsigned ID : KnownIDs) - NewLI->setMetadata(ID, FirstLI->getMetadata(ID)); + NewLI->copyMetadata(*FirstLI); // Add all operands to the new PHI and combine TBAA metadata. for (auto Incoming : drop_begin(zip(PN.blocks(), PN.incoming_values()))) { BasicBlock *BB = std::get<0>(Incoming); Value *V = std::get<1>(Incoming); LoadInst *LI = cast<LoadInst>(V); - combineMetadata(NewLI, LI, KnownIDs, true); + combineMetadataForCSE(NewLI, LI, true); Value *NewInVal = LI->getOperand(0); if (NewInVal != InVal) InVal = nullptr; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 3d251d6..1eca177 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -1225,8 +1225,12 @@ static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal, // zext/trunc) have one use (ending at the select), the cttz/ctlz result will // not be used if the input is zero. Relax to 'zero is poison' for that case. if (II->hasOneUse() && SelectArg->hasOneUse() && - !match(II->getArgOperand(1), m_One())) + !match(II->getArgOperand(1), m_One())) { II->setArgOperand(1, ConstantInt::getTrue(II->getContext())); + // noundef attribute on the intrinsic may no longer be valid. + II->dropUBImplyingAttrsAndMetadata(); + IC.addToWorklist(II); + } return nullptr; } @@ -1685,8 +1689,7 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp, return nullptr; // Check the constant we'd have with flipped-strictness predicate. 
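// A small standalone check (not from this patch) of what "flipped strictness"
// means for the getFlippedStrictnessPredicateAndConstant calls below: a
// relational compare against a constant can trade strictness by nudging the
// constant, e.g. for signed 8-bit values
//   X > C   <=>  X >= C + 1   (C != INT8_MAX)
//   X < C   <=>  X <= C - 1   (C != INT8_MIN)
#include <cassert>

int main() {
  for (int C = -128; C <= 127; ++C)
    for (int X = -128; X <= 127; ++X) {
      if (C != 127)
        assert((X > C) == (X >= C + 1));
      if (C != -128)
        assert((X < C) == (X <= C - 1));
    }
  return 0;
}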
- auto FlippedStrictness = - InstCombiner::getFlippedStrictnessPredicateAndConstant(Pred, C0); + auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(Pred, C0); if (!FlippedStrictness) return nullptr; @@ -1966,8 +1969,7 @@ static Value *foldSelectWithConstOpToBinOp(ICmpInst *Cmp, Value *TrueVal, Value *RHS; SelectPatternFlavor SPF; const DataLayout &DL = BOp->getDataLayout(); - auto Flipped = - InstCombiner::getFlippedStrictnessPredicateAndConstant(Predicate, C1); + auto Flipped = getFlippedStrictnessPredicateAndConstant(Predicate, C1); if (C3 == ConstantFoldBinaryOpOperands(Opcode, C1, C2, DL)) { SPF = getSelectPattern(Predicate).Flavor; @@ -2819,9 +2821,9 @@ static Instruction *foldSelectWithSRem(SelectInst &SI, InstCombinerImpl &IC, // %cnd = icmp slt i32 %rem, 0 // %add = add i32 %rem, %n // %sel = select i1 %cnd, i32 %add, i32 %rem - if (match(TrueVal, m_Add(m_Specific(RemRes), m_Value(Remainder))) && + if (match(TrueVal, m_c_Add(m_Specific(RemRes), m_Value(Remainder))) && match(RemRes, m_SRem(m_Value(Op), m_Specific(Remainder))) && - IC.isKnownToBeAPowerOfTwo(Remainder, /*OrZero*/ true) && + IC.isKnownToBeAPowerOfTwo(Remainder, /*OrZero=*/true) && FalseVal == RemRes) return FoldToBitwiseAnd(Remainder); @@ -3769,22 +3771,9 @@ static Value *foldSelectIntoAddConstant(SelectInst &SI, if (!SIFOp || !SIFOp->hasNoSignedZeros() || !SIFOp->hasNoNaNs()) return nullptr; - // select((fcmp Pred, X, 0), (fadd X, C), C) - // => fadd((select (fcmp Pred, X, 0), X, 0), C) - // - // Pred := OGT, OGE, OLT, OLE, UGT, UGE, ULT, and ULE - Instruction *FAdd; - Constant *C; - Value *X, *Z; - CmpPredicate Pred; - - // Note: OneUse check for `Cmp` is necessary because it makes sure that other - // InstCombine folds don't undo this transformation and cause an infinite - // loop. Furthermore, it could also increase the operation count. - if (match(&SI, m_Select(m_OneUse(m_FCmp(Pred, m_Value(X), m_Value(Z))), - m_OneUse(m_Instruction(FAdd)), m_Constant(C))) || - match(&SI, m_Select(m_OneUse(m_FCmp(Pred, m_Value(X), m_Value(Z))), - m_Constant(C), m_OneUse(m_Instruction(FAdd))))) { + auto TryFoldIntoAddConstant = + [&Builder, &SI](CmpInst::Predicate Pred, Value *X, Value *Z, + Instruction *FAdd, Constant *C, bool Swapped) -> Value * { // Only these relational predicates can be transformed into maxnum/minnum // intrinsic. if (!CmpInst::isRelational(Pred) || !match(Z, m_AnyZeroFP())) @@ -3793,7 +3782,8 @@ static Value *foldSelectIntoAddConstant(SelectInst &SI, if (!match(FAdd, m_FAdd(m_Specific(X), m_Specific(C)))) return nullptr; - Value *NewSelect = Builder.CreateSelect(SI.getCondition(), X, Z, "", &SI); + Value *NewSelect = Builder.CreateSelect(SI.getCondition(), Swapped ? Z : X, + Swapped ? X : Z, "", &SI); NewSelect->takeName(&SI); Value *NewFAdd = Builder.CreateFAdd(NewSelect, C); @@ -3808,7 +3798,27 @@ static Value *foldSelectIntoAddConstant(SelectInst &SI, cast<Instruction>(NewSelect)->setFastMathFlags(NewFMF); return NewFAdd; - } + }; + + // select((fcmp Pred, X, 0), (fadd X, C), C) + // => fadd((select (fcmp Pred, X, 0), X, 0), C) + // + // Pred := OGT, OGE, OLT, OLE, UGT, UGE, ULT, and ULE + Instruction *FAdd; + Constant *C; + Value *X, *Z; + CmpPredicate Pred; + + // Note: OneUse check for `Cmp` is necessary because it makes sure that other + // InstCombine folds don't undo this transformation and cause an infinite + // loop. Furthermore, it could also increase the operation count. 
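// An illustrative sketch (not from this patch) of the select/fadd rewrite that
// the TryFoldIntoAddConstant lambda above performs, shown with plain doubles:
//   select(x > 0.0, x + c, c)  ==  select(x > 0.0, x, 0.0) + c
#include <cassert>

int main() {
  const double Xs[] = {-2.5, -1.0, 0.0, 0.5, 3.0};
  const double Cs[] = {-4.0, 0.25, 7.0};
  for (double X : Xs)
    for (double C : Cs) {
      double Original = (X > 0.0) ? (X + C) : C;
      double Rewritten = ((X > 0.0) ? X : 0.0) + C;
      assert(Original == Rewritten);
    }
  return 0;
}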
+ if (match(&SI, m_Select(m_OneUse(m_FCmp(Pred, m_Value(X), m_Value(Z))), + m_OneUse(m_Instruction(FAdd)), m_Constant(C)))) + return TryFoldIntoAddConstant(Pred, X, Z, FAdd, C, /*Swapped=*/false); + + if (match(&SI, m_Select(m_OneUse(m_FCmp(Pred, m_Value(X), m_Value(Z))), + m_Constant(C), m_OneUse(m_Instruction(FAdd))))) + return TryFoldIntoAddConstant(Pred, X, Z, FAdd, C, /*Swapped=*/true); return nullptr; } @@ -3902,12 +3912,11 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { // (X ugt Y) ? X : Y -> (X ole Y) ? Y : X if (FCmp->hasOneUse() && FCmpInst::isUnordered(Pred)) { FCmpInst::Predicate InvPred = FCmp->getInversePredicate(); - IRBuilder<>::FastMathFlagGuard FMFG(Builder); // FIXME: The FMF should propagate from the select, not the fcmp. - Builder.setFastMathFlags(FCmp->getFastMathFlags()); - Value *NewCond = Builder.CreateFCmp(InvPred, Cmp0, Cmp1, - FCmp->getName() + ".inv"); - Value *NewSel = Builder.CreateSelect(NewCond, FalseVal, TrueVal); + Value *NewCond = Builder.CreateFCmpFMF(InvPred, Cmp0, Cmp1, FCmp, + FCmp->getName() + ".inv"); + Value *NewSel = + Builder.CreateSelectFMF(NewCond, FalseVal, TrueVal, FCmp); return replaceInstUsesWith(SI, NewSel); } } @@ -4072,15 +4081,11 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { CmpInst::Predicate MinMaxPred = getMinMaxPred(SPF, SPR.Ordered); Value *Cmp; - if (CmpInst::isIntPredicate(MinMaxPred)) { + if (CmpInst::isIntPredicate(MinMaxPred)) Cmp = Builder.CreateICmp(MinMaxPred, LHS, RHS); - } else { - IRBuilder<>::FastMathFlagGuard FMFG(Builder); - auto FMF = - cast<FPMathOperator>(SI.getCondition())->getFastMathFlags(); - Builder.setFastMathFlags(FMF); - Cmp = Builder.CreateFCmp(MinMaxPred, LHS, RHS); - } + else + Cmp = Builder.CreateFCmpFMF(MinMaxPred, LHS, RHS, + cast<Instruction>(SI.getCondition())); Value *NewSI = Builder.CreateSelect(Cmp, LHS, RHS, SI.getName(), &SI); if (!IsCastNeeded) diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index 934156f..2fb60ef 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -939,12 +939,11 @@ Instruction *InstCombinerImpl::foldBinOpShiftWithShift(BinaryOperator &I) { m_OneUse(m_Shift(m_Value(Y), m_Value(Shift))))) return nullptr; if (!match(I.getOperand(1 - ShOpnum), - m_BinOp(m_Value(ShiftedX), m_Value(Mask)))) + m_c_BinOp(m_CombineAnd( + m_OneUse(m_Shift(m_Value(X), m_Specific(Shift))), + m_Value(ShiftedX)), + m_Value(Mask)))) return nullptr; - - if (!match(ShiftedX, m_OneUse(m_Shift(m_Value(X), m_Specific(Shift))))) - return nullptr; - // Make sure we are matching instruction shifts and not ConstantExpr auto *IY = dyn_cast<Instruction>(I.getOperand(ShOpnum)); auto *IX = dyn_cast<Instruction>(ShiftedX); @@ -1822,12 +1821,29 @@ Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN, continue; } - // If the only use of phi is comparing it with a constant then we can - // put this comparison in the incoming BB directly after a ucmp/scmp call - // because we know that it will simplify to a single icmp. - const APInt *Ignored; - if (isa<CmpIntrinsic>(InVal) && InVal->hasOneUser() && - match(&I, m_ICmp(m_Specific(PN), m_APInt(Ignored)))) { + // Handle some cases that can't be fully simplified, but where we know that + // the two instructions will fold into one. 
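// A trivial standalone check (not from this patch) of the two cases the
// WillFold lambda below recognizes: an icmp of a ucmp/scmp result against a
// constant folds to a single icmp, and icmp eq of a zero-extended i1 against 0
// is just the boolean's negation.
#include <cassert>
#include <cstdint>

int main() {
  // icmp eq (zext i1 %b to i32), 0  is simply  !%b.
  for (int B = 0; B <= 1; ++B) {
    uint32_t Ext = static_cast<uint32_t>(B);
    assert((Ext == 0u) == (B == 0));
  }
  // e.g. (scmp(A, B) == -1)  is  (A < B).
  auto SCmp = [](int A, int B) { return A < B ? -1 : (A > B ? 1 : 0); };
  for (int A = -3; A <= 3; ++A)
    for (int B = -3; B <= 3; ++B)
      assert((SCmp(A, B) == -1) == (A < B));
  return 0;
}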
+ auto WillFold = [&]() { + if (!InVal->hasOneUser()) + return false; + + // icmp of ucmp/scmp with constant will fold to icmp. + const APInt *Ignored; + if (isa<CmpIntrinsic>(InVal) && + match(&I, m_ICmp(m_Specific(PN), m_APInt(Ignored)))) + return true; + + // icmp eq zext(bool), 0 will fold to !bool. + if (isa<ZExtInst>(InVal) && + cast<ZExtInst>(InVal)->getSrcTy()->isIntOrIntVectorTy(1) && + match(&I, + m_SpecificICmp(ICmpInst::ICMP_EQ, m_Specific(PN), m_Zero()))) + return true; + + return false; + }; + + if (WillFold()) { OpsToMoveUseToIncomingBB.push_back(i); NewPhiValues.push_back(nullptr); continue; @@ -2782,6 +2798,7 @@ static Instruction *foldGEPOfPhi(GetElementPtrInst &GEP, PHINode *PN, // loop iteration). if (Op1 == &GEP) return nullptr; + GEPNoWrapFlags NW = Op1->getNoWrapFlags(); int DI = -1; @@ -2838,6 +2855,8 @@ static Instruction *foldGEPOfPhi(GetElementPtrInst &GEP, PHINode *PN, } } } + + NW &= Op2->getNoWrapFlags(); } // If not all GEPs are identical we'll have to create a new PHI node. @@ -2847,6 +2866,8 @@ static Instruction *foldGEPOfPhi(GetElementPtrInst &GEP, PHINode *PN, return nullptr; auto *NewGEP = cast<GetElementPtrInst>(Op1->clone()); + NewGEP->setNoWrapFlags(NW); + if (DI == -1) { // All the GEPs feeding the PHI are identical. Clone one down into our // BB so that it can be merged with the current GEP. diff --git a/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp b/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp index f9be7f9..6e86ffd 100644 --- a/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp +++ b/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp @@ -61,7 +61,7 @@ enum : uint32_t { }; static cl::opt<std::string> DefaultGCOVVersion("default-gcov-version", - cl::init("408*"), cl::Hidden, + cl::init("0000"), cl::Hidden, cl::ValueRequired); static cl::opt<bool> AtomicCounter("gcov-atomic-counter", cl::Hidden, @@ -154,6 +154,7 @@ private: GCOVOptions Options; llvm::endianness Endian; raw_ostream *os; + int Version = 0; // Checksum, produced by hash of EdgeDestinations SmallVector<uint32_t, 4> FileChecksums; @@ -334,12 +335,9 @@ namespace { : GCOVRecord(P), SP(SP), EndLine(EndLine), Ident(Ident), Version(Version), EntryBlock(P, 0), ReturnBlock(P, 1) { LLVM_DEBUG(dbgs() << "Function: " << getFunctionName(SP) << "\n"); - bool ExitBlockBeforeBody = Version >= 48; - uint32_t i = ExitBlockBeforeBody ? 2 : 1; + uint32_t i = 2; for (BasicBlock &BB : *F) Blocks.insert(std::make_pair(&BB, GCOVBlock(P, i++))); - if (!ExitBlockBeforeBody) - ReturnBlock.Number = i; std::string FunctionNameAndLine; raw_string_ostream FNLOS(FunctionNameAndLine); @@ -363,44 +361,28 @@ namespace { void writeOut(uint32_t CfgChecksum) { write(GCOV_TAG_FUNCTION); SmallString<128> Filename = getFilename(SP); - uint32_t BlockLen = - 2 + (Version >= 47) + wordsOfString(getFunctionName(SP)); - if (Version < 80) - BlockLen += wordsOfString(Filename) + 1; - else - BlockLen += 1 + wordsOfString(Filename) + 3 + (Version >= 90); + uint32_t BlockLen = 3 + wordsOfString(getFunctionName(SP)); + BlockLen += 1 + wordsOfString(Filename) + 4; write(BlockLen); write(Ident); write(FuncChecksum); - if (Version >= 47) - write(CfgChecksum); + write(CfgChecksum); writeString(getFunctionName(SP)); - if (Version < 80) { - writeString(Filename); - write(SP->getLine()); - } else { - write(SP->isArtificial()); // artificial - writeString(Filename); - write(SP->getLine()); // start_line - write(0); // start_column - // EndLine is the last line with !dbg. 
It is not the } line as in GCC, - // but good enough. - write(EndLine); - if (Version >= 90) - write(0); // end_column - } + + write(SP->isArtificial()); // artificial + writeString(Filename); + write(SP->getLine()); // start_line + write(0); // start_column + // EndLine is the last line with !dbg. It is not the } line as in GCC, + // but good enough. + write(EndLine); + write(0); // end_column // Emit count of blocks. write(GCOV_TAG_BLOCKS); - if (Version < 80) { - write(Blocks.size() + 2); - for (int i = Blocks.size() + 2; i; --i) - write(0); - } else { - write(1); - write(Blocks.size() + 2); - } + write(1); + write(Blocks.size() + 2); LLVM_DEBUG(dbgs() << (Blocks.size() + 1) << " blocks\n"); // Emit edges between blocks. @@ -767,7 +749,6 @@ bool GCOVProfiler::emitProfileNotes( function_ref<BlockFrequencyInfo *(Function &F)> GetBFI, function_ref<BranchProbabilityInfo *(Function &F)> GetBPI, function_ref<const TargetLibraryInfo &(Function &F)> GetTLI) { - int Version; { uint8_t c3 = Options.Version[0]; uint8_t c2 = Options.Version[1]; @@ -775,6 +756,11 @@ bool GCOVProfiler::emitProfileNotes( Version = c3 >= 'A' ? (c3 - 'A') * 100 + (c2 - '0') * 10 + c1 - '0' : (c3 - '0') * 10 + c1 - '0'; } + // Emit .gcno files that are compatible with GCC 11.1. + if (Version < 111) { + Version = 111; + memcpy(Options.Version, "B11*", 4); + } bool EmitGCDA = Options.EmitData; for (unsigned i = 0, e = CUNode->getNumOperands(); i != e; ++i) { @@ -973,10 +959,8 @@ bool GCOVProfiler::emitProfileNotes( out.write(Tmp, 4); } write(Stamp); - if (Version >= 90) - writeString(""); // unuseful current_working_directory - if (Version >= 80) - write(0); // unuseful has_unexecuted_blocks + writeString("."); // unuseful current_working_directory + write(0); // unuseful has_unexecuted_blocks for (auto &Func : Funcs) Func->writeOut(Stamp); diff --git a/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp index 530061e..2031728 100644 --- a/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp @@ -192,7 +192,7 @@ static cl::opt<bool> cl::Hidden); static cl::opt<int> ClHotPercentileCutoff("hwasan-percentile-cutoff-hot", - cl::desc("Hot percentile cuttoff.")); + cl::desc("Hot percentile cutoff.")); static cl::opt<float> ClRandomSkipRate("hwasan-random-rate", diff --git a/llvm/lib/Transforms/Instrumentation/LowerAllowCheckPass.cpp b/llvm/lib/Transforms/Instrumentation/LowerAllowCheckPass.cpp index 2418030..f27798c 100644 --- a/llvm/lib/Transforms/Instrumentation/LowerAllowCheckPass.cpp +++ b/llvm/lib/Transforms/Instrumentation/LowerAllowCheckPass.cpp @@ -30,7 +30,7 @@ using namespace llvm; static cl::opt<int> HotPercentileCutoff("lower-allow-check-percentile-cutoff-hot", - cl::desc("Hot percentile cuttoff.")); + cl::desc("Hot percentile cutoff.")); static cl::opt<float> RandomRate("lower-allow-check-random-rate", diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp index 471086c..db4d62e 100644 --- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp +++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp @@ -158,11 +158,11 @@ STATISTIC(NumCoveredBlocks, "Number of basic blocks that were executed"); // Command line option to specify the file to read profile from. This is // mainly used for testing. 
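// A standalone sketch (not from this patch) of the version decoding used in the
// GCOVProfiling.cpp hunk above, copied from the expression in emitProfileNotes.
// It shows why the new "0000" default (decoding to 0) and anything below 111 get
// bumped to "B11*", the GCC 11.1 format.
#include <cassert>

static int decodeGCOVVersion(const char V[4]) {
  char c3 = V[0], c2 = V[1], c1 = V[2];
  return c3 >= 'A' ? (c3 - 'A') * 100 + (c2 - '0') * 10 + c1 - '0'
                   : (c3 - '0') * 10 + c1 - '0';
}

int main() {
  assert(decodeGCOVVersion("0000") == 0);   // new default, always bumped
  assert(decodeGCOVVersion("408*") == 48);  // old default (GCC 4.8 era)
  assert(decodeGCOVVersion("B11*") == 111); // GCC 11.1
  return 0;
}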
-static cl::opt<std::string> - PGOTestProfileFile("pgo-test-profile-file", cl::init(""), cl::Hidden, - cl::value_desc("filename"), - cl::desc("Specify the path of profile data file. This is" - "mainly for test purpose.")); +static cl::opt<std::string> PGOTestProfileFile( + "pgo-test-profile-file", cl::init(""), cl::Hidden, + cl::value_desc("filename"), + cl::desc("Specify the path of profile data file. This is " + "mainly for test purpose.")); static cl::opt<std::string> PGOTestProfileRemappingFile( "pgo-test-profile-remapping-file", cl::init(""), cl::Hidden, cl::value_desc("filename"), @@ -186,7 +186,7 @@ static cl::opt<unsigned> MaxNumAnnotations( // to write to the metadata for a single memop intrinsic. static cl::opt<unsigned> MaxNumMemOPAnnotations( "memop-max-annotations", cl::init(4), cl::Hidden, - cl::desc("Max number of preicise value annotations for a single memop" + cl::desc("Max number of precise value annotations for a single memop" "intrinsic")); // Command line option to control appending FunctionHash to the name of a COMDAT @@ -291,13 +291,13 @@ static cl::opt<bool> PGOVerifyHotBFI( cl::desc("Print out the non-match BFI count if a hot raw profile count " "becomes non-hot, or a cold raw profile count becomes hot. " "The print is enabled under -Rpass-analysis=pgo, or " - "internal option -pass-remakrs-analysis=pgo.")); + "internal option -pass-remarks-analysis=pgo.")); static cl::opt<bool> PGOVerifyBFI( "pgo-verify-bfi", cl::init(false), cl::Hidden, cl::desc("Print out mismatched BFI counts after setting profile metadata " "The print is enabled under -Rpass-analysis=pgo, or " - "internal option -pass-remakrs-analysis=pgo.")); + "internal option -pass-remarks-analysis=pgo.")); static cl::opt<unsigned> PGOVerifyBFIRatio( "pgo-verify-bfi-ratio", cl::init(2), cl::Hidden, diff --git a/llvm/lib/Transforms/Instrumentation/TypeSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/TypeSanitizer.cpp index 1961095..9cd81f3 100644 --- a/llvm/lib/Transforms/Instrumentation/TypeSanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/TypeSanitizer.cpp @@ -70,7 +70,7 @@ namespace { /// violations. struct TypeSanitizer { TypeSanitizer(Module &M); - bool run(Function &F, const TargetLibraryInfo &TLI); + bool sanitizeFunction(Function &F, const TargetLibraryInfo &TLI); void instrumentGlobals(Module &M); private: @@ -510,7 +510,8 @@ void collectMemAccessInfo( } } -bool TypeSanitizer::run(Function &F, const TargetLibraryInfo &TLI) { +bool TypeSanitizer::sanitizeFunction(Function &F, + const TargetLibraryInfo &TLI) { // This is required to prevent instrumenting call to __tysan_init from within // the module constructor. 
if (&F == TysanCtorFunction.getCallee() || &F == TysanGlobalsSetTypeFunction) @@ -876,15 +877,8 @@ bool TypeSanitizer::instrumentMemInst(Value *V, Instruction *ShadowBase, return true; } -PreservedAnalyses TypeSanitizerPass::run(Function &F, - FunctionAnalysisManager &FAM) { - TypeSanitizer TySan(*F.getParent()); - TySan.run(F, FAM.getResult<TargetLibraryAnalysis>(F)); - return PreservedAnalyses::none(); -} - -PreservedAnalyses ModuleTypeSanitizerPass::run(Module &M, - ModuleAnalysisManager &AM) { +PreservedAnalyses TypeSanitizerPass::run(Module &M, + ModuleAnalysisManager &MAM) { Function *TysanCtorFunction; std::tie(TysanCtorFunction, std::ignore) = createSanitizerCtorAndInitFunctions(M, kTysanModuleCtorName, @@ -894,5 +888,12 @@ PreservedAnalyses ModuleTypeSanitizerPass::run(Module &M, TypeSanitizer TySan(M); TySan.instrumentGlobals(M); appendToGlobalCtors(M, TysanCtorFunction, 0); + + auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + for (Function &F : M) { + const TargetLibraryInfo &TLI = FAM.getResult<TargetLibraryAnalysis>(F); + TySan.sanitizeFunction(F, TLI); + } + return PreservedAnalyses::none(); } diff --git a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp index ead07ed..91a3c3f 100644 --- a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp +++ b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp @@ -216,7 +216,7 @@ struct StackEntry { StackEntry(unsigned NumIn, unsigned NumOut, bool IsSigned, SmallVector<Value *, 2> ValuesToRelease) : NumIn(NumIn), NumOut(NumOut), IsSigned(IsSigned), - ValuesToRelease(ValuesToRelease) {} + ValuesToRelease(std::move(ValuesToRelease)) {} }; struct ConstraintTy { @@ -726,8 +726,8 @@ ConstraintInfo::getConstraint(CmpInst::Predicate Pred, Value *Op0, Value *Op1, } for (const auto &KV : VariablesB) { - if (SubOverflow(R[GetOrAddIndex(KV.Variable)], KV.Coefficient, - R[GetOrAddIndex(KV.Variable)])) + auto &Coeff = R[GetOrAddIndex(KV.Variable)]; + if (SubOverflow(Coeff, KV.Coefficient, Coeff)) return {}; auto I = KnownNonNegativeVariables.insert({KV.Variable, KV.IsKnownNonNegative}); @@ -759,9 +759,9 @@ ConstraintInfo::getConstraint(CmpInst::Predicate Pred, Value *Op0, Value *Op1, if (!KV.second || (!Value2Index.contains(KV.first) && !NewIndexMap.contains(KV.first))) continue; - SmallVector<int64_t, 8> C(Value2Index.size() + NewVariables.size() + 1, 0); + auto &C = Res.ExtraInfo.emplace_back( + Value2Index.size() + NewVariables.size() + 1, 0); C[GetOrAddIndex(KV.first)] = -1; - Res.ExtraInfo.push_back(C); } return Res; } @@ -1591,53 +1591,52 @@ void ConstraintInfo::addFact(CmpInst::Predicate Pred, Value *A, Value *B, LLVM_DEBUG(dbgs() << "Adding '"; dumpUnpackedICmp(dbgs(), Pred, A, B); dbgs() << "'\n"); - bool Added = false; auto &CSToUse = getCS(R.IsSigned); if (R.Coefficients.empty()) return; - Added |= CSToUse.addVariableRowFill(R.Coefficients); + bool Added = CSToUse.addVariableRowFill(R.Coefficients); + if (!Added) + return; // If R has been added to the system, add the new variables and queue it for // removal once it goes out-of-scope. 
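// A minimal sketch (not from this patch, using std::vector in place of
// SmallVector) of the pattern adopted in getConstraint above: build the extra
// row directly inside the container with emplace_back(size, value) rather than
// filling a local vector and pushing a copy of it.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<std::vector<int64_t>> ExtraInfo;
  auto &C = ExtraInfo.emplace_back(8, 0); // zero-filled row, constructed in place
  C[3] = -1;                              // then set the single interesting entry
  assert(ExtraInfo.size() == 1 && ExtraInfo[0][3] == -1);
  return 0;
}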
- if (Added) { - SmallVector<Value *, 2> ValuesToRelease; - auto &Value2Index = getValue2Index(R.IsSigned); - for (Value *V : NewVariables) { - Value2Index.insert({V, Value2Index.size() + 1}); - ValuesToRelease.push_back(V); - } - - LLVM_DEBUG({ - dbgs() << " constraint: "; - dumpConstraint(R.Coefficients, getValue2Index(R.IsSigned)); - dbgs() << "\n"; - }); + SmallVector<Value *, 2> ValuesToRelease; + auto &Value2Index = getValue2Index(R.IsSigned); + for (Value *V : NewVariables) { + Value2Index.insert({V, Value2Index.size() + 1}); + ValuesToRelease.push_back(V); + } - DFSInStack.emplace_back(NumIn, NumOut, R.IsSigned, - std::move(ValuesToRelease)); - - if (!R.IsSigned) { - for (Value *V : NewVariables) { - ConstraintTy VarPos(SmallVector<int64_t, 8>(Value2Index.size() + 1, 0), - false, false, false); - VarPos.Coefficients[Value2Index[V]] = -1; - CSToUse.addVariableRow(VarPos.Coefficients); - DFSInStack.emplace_back(NumIn, NumOut, R.IsSigned, - SmallVector<Value *, 2>()); - } - } + LLVM_DEBUG({ + dbgs() << " constraint: "; + dumpConstraint(R.Coefficients, getValue2Index(R.IsSigned)); + dbgs() << "\n"; + }); - if (R.isEq()) { - // Also add the inverted constraint for equality constraints. - for (auto &Coeff : R.Coefficients) - Coeff *= -1; - CSToUse.addVariableRowFill(R.Coefficients); + DFSInStack.emplace_back(NumIn, NumOut, R.IsSigned, + std::move(ValuesToRelease)); + if (!R.IsSigned) { + for (Value *V : NewVariables) { + ConstraintTy VarPos(SmallVector<int64_t, 8>(Value2Index.size() + 1, 0), + false, false, false); + VarPos.Coefficients[Value2Index[V]] = -1; + CSToUse.addVariableRow(VarPos.Coefficients); DFSInStack.emplace_back(NumIn, NumOut, R.IsSigned, SmallVector<Value *, 2>()); } } + + if (R.isEq()) { + // Also add the inverted constraint for equality constraints. 
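
// A standalone sketch (illustrative, not part of this patch) of the idea stated in the
// comment above: a system that only stores rows of the form sum(c_i * x_i) <= c0 can
// model the equality x == y by adding a row together with its coefficient-wise negation.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int64_t> Row = {0, 1, -1}; // encodes x - y <= 0
  std::vector<int64_t> Inverted = Row;
  for (int64_t &Coeff : Inverted)
    Coeff *= -1;                         // now encodes y - x <= 0
  for (int64_t Coeff : Inverted)
    std::printf("%lld ", static_cast<long long>(Coeff));
  std::printf("\n");                     // prints: 0 -1 1
}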
+ for (auto &Coeff : R.Coefficients) + Coeff *= -1; + CSToUse.addVariableRowFill(R.Coefficients); + + DFSInStack.emplace_back(NumIn, NumOut, R.IsSigned, + SmallVector<Value *, 2>()); + } } static bool replaceSubOverflowUses(IntrinsicInst *II, Value *A, Value *B, diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp index ba1c224..3c82eed 100644 --- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp +++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp @@ -128,7 +128,7 @@ static cl::opt<bool, true> static cl::opt<bool> UseLIRCodeSizeHeurs( "use-lir-code-size-heurs", - cl::desc("Use loop idiom recognition code size heuristics when compiling" + cl::desc("Use loop idiom recognition code size heuristics when compiling " "with -Os/-Oz"), cl::init(true), cl::Hidden); diff --git a/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp b/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp index 260cc72..0903488 100644 --- a/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp +++ b/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp @@ -104,7 +104,7 @@ static cl::opt<unsigned> UnrollMaxPercentThresholdBoost( static cl::opt<unsigned> UnrollMaxIterationsCountToAnalyze( "unroll-max-iteration-count-to-analyze", cl::init(10), cl::Hidden, - cl::desc("Don't allow loop unrolling to simulate more than this number of" + cl::desc("Don't allow loop unrolling to simulate more than this number of " "iterations when checking full unroll profitability")); static cl::opt<unsigned> UnrollCount( diff --git a/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp b/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp index f58dcb5..6e91c4f 100644 --- a/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp +++ b/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp @@ -95,7 +95,7 @@ static const char *LICMVersioningMetaData = "llvm.loop.licm_versioning.disable"; /// invariant instructions in a loop. static cl::opt<float> LVInvarThreshold("licm-versioning-invariant-threshold", - cl::desc("LoopVersioningLICM's minimum allowed percentage" + cl::desc("LoopVersioningLICM's minimum allowed percentage " "of possible invariant instructions per loop"), cl::init(25), cl::Hidden); diff --git a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp index bb98b3d..5f7cb92 100644 --- a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp +++ b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp @@ -345,10 +345,14 @@ static bool writtenBetween(MemorySSA *MSSA, BatchAAResults &AA, static void combineAAMetadata(Instruction *ReplInst, Instruction *I) { // FIXME: MD_tbaa_struct and MD_mem_parallel_loop_access should also be // handled here, but combineMetadata doesn't support them yet - unsigned KnownIDs[] = {LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope, - LLVMContext::MD_noalias, - LLVMContext::MD_invariant_group, - LLVMContext::MD_access_group}; + unsigned KnownIDs[] = { + LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope, + LLVMContext::MD_noalias, LLVMContext::MD_invariant_group, + LLVMContext::MD_access_group, LLVMContext::MD_prof, + LLVMContext::MD_memprof, LLVMContext::MD_callsite}; + // FIXME: https://github.com/llvm/llvm-project/issues/121495 + // Use custom AA metadata combining handling instead of combineMetadata, which + // is meant for CSE and will drop any metadata not in the KnownIDs list. 
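
// A standalone sketch (illustrative, not part of this patch) of the allow-list behavior
// the FIXME above refers to: a CSE-style merge keeps only kinds that are explicitly
// listed, so newly introduced kinds (here, prof/memprof/callsite) must be added to
// KnownIDs or they are silently dropped. Plain std::map/std::set stand in for metadata.
#include <cstdio>
#include <iterator>
#include <map>
#include <set>
#include <string>

int main() {
  std::map<std::string, std::string> Metadata = {
      {"tbaa", "!0"}, {"memprof", "!1"}, {"custom", "!2"}};
  std::set<std::string> KnownIDs = {"tbaa", "memprof"};
  for (auto It = Metadata.begin(); It != Metadata.end();)
    It = KnownIDs.count(It->first) ? std::next(It) : Metadata.erase(It);
  for (const auto &[Kind, Node] : Metadata)
    std::printf("%s ", Kind.c_str());    // prints: memprof tbaa ("custom" was dropped)
  std::printf("\n");
}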
combineMetadata(ReplInst, I, KnownIDs, true); } diff --git a/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp b/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp index 1d4f561..b499ef8 100644 --- a/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp +++ b/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp @@ -28,8 +28,8 @@ using namespace llvm; namespace llvm { cl::opt<bool> ShouldPreserveAllAttributes( "assume-preserve-all", cl::init(false), cl::Hidden, - cl::desc("enable preservation of all attrbitues. even those that are " - "unlikely to be usefull")); + cl::desc("enable preservation of all attributes. even those that are " + "unlikely to be useful")); cl::opt<bool> EnableKnowledgeRetention( "enable-knowledge-retention", cl::init(false), cl::Hidden, diff --git a/llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp b/llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp index 47bb319..d47f1b4 100644 --- a/llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp +++ b/llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp @@ -48,6 +48,21 @@ static void insertCall(Function &CurFn, StringRef Func, /*isVarArg=*/false)), {GV}, "", InsertionPt); Call->setDebugLoc(DL); + } else if (TargetTriple.isRISCV() || TargetTriple.isAArch64() || + TargetTriple.isLoongArch()) { + // On RISC-V, AArch64, and LoongArch, the `_mcount` function takes + // `__builtin_return_address(0)` as an argument since + // `__builtin_return_address(1)` is not available on these platforms. + Instruction *RetAddr = CallInst::Create( + Intrinsic::getOrInsertDeclaration(&M, Intrinsic::returnaddress), + ConstantInt::get(Type::getInt32Ty(C), 0), "", InsertionPt); + RetAddr->setDebugLoc(DL); + + FunctionCallee Fn = M.getOrInsertFunction( + Func, FunctionType::get(Type::getVoidTy(C), PointerType::getUnqual(C), + false)); + CallInst *Call = CallInst::Create(Fn, RetAddr, "", InsertionPt); + Call->setDebugLoc(DL); } else { FunctionCallee Fn = M.getOrInsertFunction(Func, Type::getVoidTy(C)); CallInst *Call = CallInst::Create(Fn, "", InsertionPt); @@ -88,6 +103,12 @@ static bool runOnFunction(Function &F, bool PostInlining) { if (F.hasFnAttribute(Attribute::Naked)) return false; + // available_externally functions may not have definitions external to the + // module (e.g. gnu::always_inline). Instrumenting them might lead to linker + // errors if they are optimized out. Skip them like GCC. + if (F.hasAvailableExternallyLinkage()) + return false; + StringRef EntryAttr = PostInlining ? 
"instrument-function-entry-inlined" : "instrument-function-entry"; diff --git a/llvm/lib/Transforms/Utils/FunctionImportUtils.cpp b/llvm/lib/Transforms/Utils/FunctionImportUtils.cpp index 766c750..ae1af943 100644 --- a/llvm/lib/Transforms/Utils/FunctionImportUtils.cpp +++ b/llvm/lib/Transforms/Utils/FunctionImportUtils.cpp @@ -331,15 +331,12 @@ void FunctionImportGlobalProcessing::processGlobalsForThinLTO() { } } -bool FunctionImportGlobalProcessing::run() { - processGlobalsForThinLTO(); - return false; -} +void FunctionImportGlobalProcessing::run() { processGlobalsForThinLTO(); } -bool llvm::renameModuleForThinLTO(Module &M, const ModuleSummaryIndex &Index, +void llvm::renameModuleForThinLTO(Module &M, const ModuleSummaryIndex &Index, bool ClearDSOLocalOnDeclarations, SetVector<GlobalValue *> *GlobalsToImport) { FunctionImportGlobalProcessing ThinLTOProcessing(M, Index, GlobalsToImport, ClearDSOLocalOnDeclarations); - return ThinLTOProcessing.run(); + ThinLTOProcessing.run(); } diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp index a3af96d..1e4061c 100644 --- a/llvm/lib/Transforms/Utils/Local.cpp +++ b/llvm/lib/Transforms/Utils/Local.cpp @@ -3308,6 +3308,9 @@ bool llvm::removeUnreachableBlocks(Function &F, DomTreeUpdater *DTU, return Changed; } +// FIXME: https://github.com/llvm/llvm-project/issues/121495 +// Once external callers of this function are removed, either inline into +// combineMetadataForCSE, or internalize and remove KnownIDs parameter. void llvm::combineMetadata(Instruction *K, const Instruction *J, ArrayRef<unsigned> KnownIDs, bool DoesKMove) { SmallVector<std::pair<unsigned, MDNode *>, 4> Metadata; @@ -3320,6 +3323,10 @@ void llvm::combineMetadata(Instruction *K, const Instruction *J, switch (Kind) { default: + // FIXME: https://github.com/llvm/llvm-project/issues/121495 + // Change to removing only explicitly listed other metadata, and assert + // on unknown metadata, to avoid inadvertently dropping newly added + // metadata types. K->setMetadata(Kind, nullptr); // Remove unknown metadata break; case LLVMContext::MD_dbg: @@ -3379,6 +3386,12 @@ void llvm::combineMetadata(Instruction *K, const Instruction *J, K->setMetadata(Kind, MDNode::getMostGenericAlignmentOrDereferenceable(JMD, KMD)); break; + case LLVMContext::MD_memprof: + K->setMetadata(Kind, MDNode::getMergedMemProfMetadata(KMD, JMD)); + break; + case LLVMContext::MD_callsite: + K->setMetadata(Kind, MDNode::getMergedCallsiteMetadata(KMD, JMD)); + break; case LLVMContext::MD_preserve_access_index: // Preserve !preserve.access.index in K. break; @@ -3442,7 +3455,9 @@ void llvm::combineMetadataForCSE(Instruction *K, const Instruction *J, LLVMContext::MD_nontemporal, LLVMContext::MD_noundef, LLVMContext::MD_mmra, - LLVMContext::MD_noalias_addrspace}; + LLVMContext::MD_noalias_addrspace, + LLVMContext::MD_memprof, + LLVMContext::MD_callsite}; combineMetadata(K, J, KnownIDs, KDominatesJ); } diff --git a/llvm/lib/Transforms/Utils/LoopSimplify.cpp b/llvm/lib/Transforms/Utils/LoopSimplify.cpp index d829864..b3f9f76 100644 --- a/llvm/lib/Transforms/Utils/LoopSimplify.cpp +++ b/llvm/lib/Transforms/Utils/LoopSimplify.cpp @@ -778,7 +778,7 @@ INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) INITIALIZE_PASS_END(LoopSimplify, "loop-simplify", "Canonicalize natural loops", - false, true) + false, false) // Publicly exposed interface to pass... 
char &llvm::LoopSimplifyID = LoopSimplify::ID; diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index febc568..e367b01 100644 --- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -96,8 +96,9 @@ using namespace PatternMatch; cl::opt<bool> llvm::RequireAndPreserveDomTree( "simplifycfg-require-and-preserve-domtree", cl::Hidden, - cl::desc("Temorary development switch used to gradually uplift SimplifyCFG " - "into preserving DomTree,")); + cl::desc( + "Temporary development switch used to gradually uplift SimplifyCFG " + "into preserving DomTree,")); // Chosen as 2 so as to be cheap, but still to have enough power to fold // a select, so the "clamp" idiom (of a min followed by a max) will be caught. @@ -126,7 +127,7 @@ static cl::opt<bool> HoistLoadsStoresWithCondFaulting( static cl::opt<unsigned> HoistLoadsStoresWithCondFaultingThreshold( "hoist-loads-stores-with-cond-faulting-threshold", cl::Hidden, cl::init(6), - cl::desc("Control the maximal conditonal load/store that we are willing " + cl::desc("Control the maximal conditional load/store that we are willing " "to speculatively execute to eliminate conditional branch " "(default = 6)")); @@ -2153,12 +2154,9 @@ bool SimplifyCFGOpt::hoistSuccIdenticalTerminatorToSwitchOrIf( SelectInst *&SI = InsertedSelects[std::make_pair(BB1V, BB2V)]; if (!SI) { // Propagate fast-math-flags from phi node to its replacement select. - IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); - if (isa<FPMathOperator>(PN)) - Builder.setFastMathFlags(PN.getFastMathFlags()); - - SI = cast<SelectInst>(Builder.CreateSelect( + SI = cast<SelectInst>(Builder.CreateSelectFMF( BI->getCondition(), BB1V, BB2V, + isa<FPMathOperator>(PN) ? &PN : nullptr, BB1V->getName() + "." + BB2V->getName(), BI)); } @@ -3898,16 +3896,14 @@ static bool foldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI, IRBuilder<NoFolder> Builder(DomBI); // Propagate fast-math-flags from phi nodes to replacement selects. - IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); while (PHINode *PN = dyn_cast<PHINode>(BB->begin())) { - if (isa<FPMathOperator>(PN)) - Builder.setFastMathFlags(PN->getFastMathFlags()); - // Change the PHI node into a select instruction. Value *TrueVal = PN->getIncomingValueForBlock(IfTrue); Value *FalseVal = PN->getIncomingValueForBlock(IfFalse); - Value *Sel = Builder.CreateSelect(IfCond, TrueVal, FalseVal, "", DomBI); + Value *Sel = Builder.CreateSelectFMF(IfCond, TrueVal, FalseVal, + isa<FPMathOperator>(PN) ? PN : nullptr, + "", DomBI); PN->replaceAllUsesWith(Sel); Sel->takeName(PN); PN->eraseFromParent(); diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp index 737818b..2b2b467 100644 --- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -2005,28 +2005,21 @@ Value *LibCallSimplifier::optimizeCAbs(CallInst *CI, IRBuilderBase &B) { AbsOp = Real; } - if (AbsOp) { - IRBuilderBase::FastMathFlagGuard Guard(B); - B.setFastMathFlags(CI->getFastMathFlags()); - + if (AbsOp) return copyFlags( - *CI, B.CreateUnaryIntrinsic(Intrinsic::fabs, AbsOp, nullptr, "cabs")); - } + *CI, B.CreateUnaryIntrinsic(Intrinsic::fabs, AbsOp, CI, "cabs")); if (!CI->isFast()) return nullptr; } // Propagate fast-math flags from the existing call to new instructions. 
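
// A standalone sketch (illustrative; these are not the real llvm::IRBuilder classes) of
// the pattern being removed below: instead of setting builder-wide fast-math state and
// relying on an RAII guard to restore it, the new *FMF-style helpers take the flag
// source per call, so no hidden builder state needs saving.
#include <cstdio>

struct Flags { bool Fast = false; };

struct ToyBuilder {
  Flags FMF;                                          // builder-wide state
  int createMul() const { return FMF.Fast ? 1 : 0; }
  int createMulFMF(const Flags &Src) const { return Src.Fast ? 1 : 0; }
};

struct ToyFlagGuard {                                 // save/restore, like a flag guard
  ToyBuilder &B;
  Flags Saved;
  ToyFlagGuard(ToyBuilder &Builder) : B(Builder), Saved(Builder.FMF) {}
  ~ToyFlagGuard() { B.FMF = Saved; }
};

int main() {
  ToyBuilder B;
  Flags CallSiteFlags{true};
  {
    ToyFlagGuard Guard(B);                            // old style: mutate, then restore
    B.FMF = CallSiteFlags;
    std::printf("%d\n", B.createMul());               // 1
  }
  std::printf("%d %d\n", B.createMulFMF(CallSiteFlags), B.createMul()); // 1 0
}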
- IRBuilderBase::FastMathFlagGuard Guard(B); - B.setFastMathFlags(CI->getFastMathFlags()); - - Value *RealReal = B.CreateFMul(Real, Real); - Value *ImagImag = B.CreateFMul(Imag, Imag); - - return copyFlags(*CI, B.CreateUnaryIntrinsic(Intrinsic::sqrt, - B.CreateFAdd(RealReal, ImagImag), - nullptr, "cabs")); + Value *RealReal = B.CreateFMulFMF(Real, Real, CI); + Value *ImagImag = B.CreateFMulFMF(Imag, Imag, CI); + return copyFlags( + *CI, B.CreateUnaryIntrinsic(Intrinsic::sqrt, + B.CreateFAddFMF(RealReal, ImagImag, CI), CI, + "cabs")); } // Return a properly extended integer (DstWidth bits wide) if the operation is @@ -2480,15 +2473,13 @@ Value *LibCallSimplifier::optimizeFMinFMax(CallInst *CI, IRBuilderBase &B) { // "Ideally, fmax would be sensitive to the sign of zero, for example // fmax(-0.0, +0.0) would return +0; however, implementation in software // might be impractical." - IRBuilderBase::FastMathFlagGuard Guard(B); FastMathFlags FMF = CI->getFastMathFlags(); FMF.setNoSignedZeros(); - B.setFastMathFlags(FMF); Intrinsic::ID IID = Callee->getName().starts_with("fmin") ? Intrinsic::minnum : Intrinsic::maxnum; return copyFlags(*CI, B.CreateBinaryIntrinsic(IID, CI->getArgOperand(0), - CI->getArgOperand(1))); + CI->getArgOperand(1), FMF)); } Value *LibCallSimplifier::optimizeLog(CallInst *Log, IRBuilderBase &B) { @@ -2783,20 +2774,18 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) { // Fast math flags for any created instructions should match the sqrt // and multiply. - IRBuilderBase::FastMathFlagGuard Guard(B); - B.setFastMathFlags(I->getFastMathFlags()); // If we found a repeated factor, hoist it out of the square root and // replace it with the fabs of that factor. Value *FabsCall = - B.CreateUnaryIntrinsic(Intrinsic::fabs, RepeatOp, nullptr, "fabs"); + B.CreateUnaryIntrinsic(Intrinsic::fabs, RepeatOp, I, "fabs"); if (OtherOp) { // If we found a non-repeated factor, we still need to get its square // root. We then multiply that by the value that was simplified out // of the square root calculation. 
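
// A standalone numeric check (illustrative, not part of this patch) of the identity the
// transform above relies on for a non-negative remaining factor Y (the transform itself
// is only applied under fast-math flags): sqrt(X * X * Y) == fabs(X) * sqrt(Y).
#include <cmath>
#include <cstdio>

int main() {
  double X = -3.0, Y = 2.0;
  double Original = std::sqrt(X * X * Y);
  double Hoisted = std::fabs(X) * std::sqrt(Y);
  std::printf("%.17g\n%.17g\n", Original, Hoisted); // both ~4.2426406871192848 (may
                                                    // differ in the last ulp)
}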
Value *SqrtCall = - B.CreateUnaryIntrinsic(Intrinsic::sqrt, OtherOp, nullptr, "sqrt"); - return copyFlags(*CI, B.CreateFMul(FabsCall, SqrtCall)); + B.CreateUnaryIntrinsic(Intrinsic::sqrt, OtherOp, I, "sqrt"); + return copyFlags(*CI, B.CreateFMulFMF(FabsCall, SqrtCall, I)); } return copyFlags(*CI, FabsCall); } @@ -2951,26 +2940,23 @@ static Value *optimizeSymmetricCall(CallInst *CI, bool IsEven, Value *Src = CI->getArgOperand(0); if (match(Src, m_OneUse(m_FNeg(m_Value(X))))) { - IRBuilderBase::FastMathFlagGuard Guard(B); - B.setFastMathFlags(CI->getFastMathFlags()); - - auto *CallInst = copyFlags(*CI, B.CreateCall(CI->getCalledFunction(), {X})); + auto *Call = B.CreateCall(CI->getCalledFunction(), {X}); + Call->copyFastMathFlags(CI); + auto *CallInst = copyFlags(*CI, Call); if (IsEven) { // Even function: f(-x) = f(x) return CallInst; } // Odd function: f(-x) = -f(x) - return B.CreateFNeg(CallInst); + return B.CreateFNegFMF(CallInst, CI); } // Even function: f(abs(x)) = f(x), f(copysign(x, y)) = f(x) if (IsEven && (match(Src, m_FAbs(m_Value(X))) || match(Src, m_CopySign(m_Value(X), m_Value())))) { - IRBuilderBase::FastMathFlagGuard Guard(B); - B.setFastMathFlags(CI->getFastMathFlags()); - - auto *CallInst = copyFlags(*CI, B.CreateCall(CI->getCalledFunction(), {X})); - return CallInst; + auto *Call = B.CreateCall(CI->getCalledFunction(), {X}); + Call->copyFastMathFlags(CI); + return copyFlags(*CI, Call); } return nullptr; diff --git a/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp b/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp index 02ec1d5..9e81573 100644 --- a/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp @@ -324,6 +324,11 @@ private: Instruction *ChainElem, Instruction *ChainBegin, const DenseMap<Instruction *, APInt /*OffsetFromLeader*/> &ChainOffsets); + /// Merges the equivalence classes if they have underlying objects that differ + /// by one level of indirection (i.e., one is a getelementptr and the other is + /// the base pointer in that getelementptr). + void mergeEquivalenceClasses(EquivalenceClassMap &EQClasses) const; + /// Collects loads and stores grouped by "equivalence class", where: /// - all elements in an eq class are a load or all are a store, /// - they all load/store the same element size (it's OK to have e.g. i8 and @@ -1305,6 +1310,119 @@ std::optional<APInt> Vectorizer::getConstantOffsetSelects( return std::nullopt; } +void Vectorizer::mergeEquivalenceClasses(EquivalenceClassMap &EQClasses) const { + if (EQClasses.size() < 2) // There is nothing to merge. + return; + + // The reduced key has all elements of the ECClassKey except the underlying + // object. Check that EqClassKey has 4 elements and define the reduced key. + static_assert(std::tuple_size_v<EqClassKey> == 4, + "EqClassKey has changed - EqClassReducedKey needs changes too"); + using EqClassReducedKey = + std::tuple<std::tuple_element_t<1, EqClassKey> /* AddrSpace */, + std::tuple_element_t<2, EqClassKey> /* Element size */, + std::tuple_element_t<3, EqClassKey> /* IsLoad; */>; + using ECReducedKeyToUnderlyingObjectMap = + MapVector<EqClassReducedKey, + SmallPtrSet<std::tuple_element_t<0, EqClassKey>, 4>>; + + // Form a map from the reduced key (without the underlying object) to the + // underlying objects: 1 reduced key to many underlying objects, to form + // groups of potentially merge-able equivalence classes. 
+ ECReducedKeyToUnderlyingObjectMap RedKeyToUOMap; + bool FoundPotentiallyOptimizableEC = false; + for (const auto &EC : EQClasses) { + const auto &Key = EC.first; + EqClassReducedKey RedKey{std::get<1>(Key), std::get<2>(Key), + std::get<3>(Key)}; + RedKeyToUOMap[RedKey].insert(std::get<0>(Key)); + if (RedKeyToUOMap[RedKey].size() > 1) + FoundPotentiallyOptimizableEC = true; + } + if (!FoundPotentiallyOptimizableEC) + return; + + LLVM_DEBUG({ + dbgs() << "LSV: mergeEquivalenceClasses: before merging:\n"; + for (const auto &EC : EQClasses) { + dbgs() << " Key: {" << EC.first << "}\n"; + for (const auto &Inst : EC.second) + dbgs() << " Inst: " << *Inst << '\n'; + } + }); + LLVM_DEBUG({ + dbgs() << "LSV: mergeEquivalenceClasses: RedKeyToUOMap:\n"; + for (const auto &RedKeyToUO : RedKeyToUOMap) { + dbgs() << " Reduced key: {" << std::get<0>(RedKeyToUO.first) << ", " + << std::get<1>(RedKeyToUO.first) << ", " + << static_cast<int>(std::get<2>(RedKeyToUO.first)) << "} --> " + << RedKeyToUO.second.size() << " underlying objects:\n"; + for (auto UObject : RedKeyToUO.second) + dbgs() << " " << *UObject << '\n'; + } + }); + + using UObjectToUObjectMap = DenseMap<const Value *, const Value *>; + + // Compute the ultimate targets for a set of underlying objects. + auto GetUltimateTargets = + [](SmallPtrSetImpl<const Value *> &UObjects) -> UObjectToUObjectMap { + UObjectToUObjectMap IndirectionMap; + for (const auto *UObject : UObjects) { + const unsigned MaxLookupDepth = 1; // look for 1-level indirections only + const auto *UltimateTarget = getUnderlyingObject(UObject, MaxLookupDepth); + if (UltimateTarget != UObject) + IndirectionMap[UObject] = UltimateTarget; + } + UObjectToUObjectMap UltimateTargetsMap; + for (const auto *UObject : UObjects) { + auto Target = UObject; + auto It = IndirectionMap.find(Target); + for (; It != IndirectionMap.end(); It = IndirectionMap.find(Target)) + Target = It->second; + UltimateTargetsMap[UObject] = Target; + } + return UltimateTargetsMap; + }; + + // For each item in RedKeyToUOMap, if it has more than one underlying object, + // try to merge the equivalence classes. + for (auto &[RedKey, UObjects] : RedKeyToUOMap) { + if (UObjects.size() < 2) + continue; + auto UTMap = GetUltimateTargets(UObjects); + for (const auto &[UObject, UltimateTarget] : UTMap) { + if (UObject == UltimateTarget) + continue; + + EqClassKey KeyFrom{UObject, std::get<0>(RedKey), std::get<1>(RedKey), + std::get<2>(RedKey)}; + EqClassKey KeyTo{UltimateTarget, std::get<0>(RedKey), std::get<1>(RedKey), + std::get<2>(RedKey)}; + // The entry for KeyFrom is guarantted to exist, unlike KeyTo. Thus, + // request the reference to the instructions vector for KeyTo first. 
+ const auto &VecTo = EQClasses[KeyTo]; + const auto &VecFrom = EQClasses[KeyFrom]; + SmallVector<Instruction *, 8> MergedVec; + std::merge(VecFrom.begin(), VecFrom.end(), VecTo.begin(), VecTo.end(), + std::back_inserter(MergedVec), + [](Instruction *A, Instruction *B) { + return A && B && A->comesBefore(B); + }); + EQClasses[KeyTo] = std::move(MergedVec); + EQClasses.erase(KeyFrom); + } + } + LLVM_DEBUG({ + dbgs() << "LSV: mergeEquivalenceClasses: after merging:\n"; + for (const auto &EC : EQClasses) { + dbgs() << " Key: {" << EC.first << "}\n"; + for (const auto &Inst : EC.second) + dbgs() << " Inst: " << *Inst << '\n'; + } + }); +} + EquivalenceClassMap Vectorizer::collectEquivalenceClasses(BasicBlock::iterator Begin, BasicBlock::iterator End) { @@ -1377,6 +1495,7 @@ Vectorizer::collectEquivalenceClasses(BasicBlock::iterator Begin, .emplace_back(&I); } + mergeEquivalenceClasses(Ret); return Ret; } diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h index 650a485..bc44ec1 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h +++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h @@ -231,17 +231,21 @@ public: new VPInstruction(Ptr, Offset, GEPNoWrapFlags::inBounds(), DL, Name)); } + /// Convert the input value \p Current to the corresponding value of an + /// induction with \p Start and \p Step values, using \p Start + \p Current * + /// \p Step. VPDerivedIVRecipe *createDerivedIV(InductionDescriptor::InductionKind Kind, FPMathOperator *FPBinOp, VPValue *Start, - VPCanonicalIVPHIRecipe *CanonicalIV, - VPValue *Step, const Twine &Name = "") { + VPValue *Current, VPValue *Step, + const Twine &Name = "") { return tryInsertInstruction( - new VPDerivedIVRecipe(Kind, FPBinOp, Start, CanonicalIV, Step, Name)); + new VPDerivedIVRecipe(Kind, FPBinOp, Start, Current, Step, Name)); } VPScalarCastRecipe *createScalarCast(Instruction::CastOps Opcode, VPValue *Op, - Type *ResultTy) { - return tryInsertInstruction(new VPScalarCastRecipe(Opcode, Op, ResultTy)); + Type *ResultTy, DebugLoc DL) { + return tryInsertInstruction( + new VPScalarCastRecipe(Opcode, Op, ResultTy, DL)); } VPWidenCastRecipe *createWidenCast(Instruction::CastOps Opcode, VPValue *Op, diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index af6fce4..47866da 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -479,7 +479,8 @@ public: AC(AC), ORE(ORE), VF(VecWidth), MinProfitableTripCount(MinProfitableTripCount), UF(UnrollFactor), Builder(PSE.getSE()->getContext()), Legal(LVL), Cost(CM), BFI(BFI), - PSI(PSI), RTChecks(RTChecks), Plan(Plan) { + PSI(PSI), RTChecks(RTChecks), Plan(Plan), + VectorPHVPB(Plan.getEntry()->getSingleSuccessor()) { // Query this against the original loop and save it here because the profile // of the original loop header may change as the transformation happens. OptForSizeBasedOnProfile = llvm::shouldOptimizeForSize( @@ -517,22 +518,6 @@ public: /// Fix the non-induction PHIs in \p Plan. void fixNonInductionPHIs(VPTransformState &State); - /// Create a ResumePHI VPInstruction for the induction \p InductionPhiIRI to - /// resume iteration count in the scalar epilogue from where the vectorized - /// loop left off, and add it to the scalar preheader of VPlan. Also creates - /// the induction resume value, and the value for the bypass block, if needed. 
- /// \p Step is the SCEV-expanded induction step to use. In cases where the - /// loop skeleton is more complicated (i.e., epilogue vectorization) and the - /// resume values can come from an additional bypass block, - /// \p MainVectorTripCount provides the trip count of the main vector loop, - /// used to compute the resume value reaching the scalar loop preheader - /// directly from this additional bypass block. - void createInductionResumeVPValue(VPIRInstruction *InductionPhiIRI, - const InductionDescriptor &ID, Value *Step, - ArrayRef<BasicBlock *> BypassBlocks, - VPBuilder &ScalarPHBuilder, - Value *MainVectorTripCount = nullptr); - /// Returns the original loop trip count. Value *getTripCount() const { return TripCount; } @@ -588,23 +573,21 @@ protected: /// vector loop preheader, middle block and scalar preheader. void createVectorLoopSkeleton(StringRef Prefix); - /// Create new phi nodes for the induction variables to resume iteration count - /// in the scalar epilogue, from where the vectorized loop left off. - /// In cases where the loop skeleton is more complicated (i.e. epilogue - /// vectorization), \p MainVectorTripCount provides the trip count of the main - /// loop, used to compute these resume values. If \p IVSubset is provided, it - /// contains the phi nodes for which resume values are needed, because they - /// will generate wide induction phis in the epilogue loop. - void - createInductionResumeVPValues(const SCEV2ValueTy &ExpandedSCEVs, - Value *MainVectorTripCount = nullptr, - SmallPtrSetImpl<PHINode *> *IVSubset = nullptr); + /// Create and record the values for induction variables to resume coming from + /// the additional bypass block. + void createInductionAdditionalBypassValues(const SCEV2ValueTy &ExpandedSCEVs, + Value *MainVectorTripCount); /// Allow subclasses to override and print debug traces before/after vplan /// execution, when trace information is requested. virtual void printDebugTracesAtStart() {} virtual void printDebugTracesAtEnd() {} + /// Introduces a new VPIRBasicBlock for \p CheckIRBB to Plan between the + /// vector preheader and its predecessor, also connecting the new block to the + /// scalar preheader. + void introduceCheckBlockInVPlan(BasicBlock *CheckIRBB); + /// The original loop. Loop *OrigLoop; @@ -699,6 +682,10 @@ protected: BasicBlock *AdditionalBypassBlock = nullptr; VPlan &Plan; + + /// The vector preheader block of \p Plan, used as target for check blocks + /// introduced during skeleton creation. + VPBlockBase *VectorPHVPB; }; /// Encapsulate information regarding vectorization of a loop and its epilogue. @@ -1744,7 +1731,8 @@ private: bool needsExtract(Value *V, ElementCount VF) const { Instruction *I = dyn_cast<Instruction>(V); if (VF.isScalar() || !I || !TheLoop->contains(I) || - TheLoop->isLoopInvariant(I)) + TheLoop->isLoopInvariant(I) || + getWideningDecision(I, VF) == CM_Scalarize) return false; // Assume we can vectorize V (and hence we need extraction) if the @@ -2406,12 +2394,12 @@ void InnerLoopVectorizer::scalarizeInstruction(const Instruction *Instr, // End if-block. VPRegionBlock *Parent = RepRecipe->getParent()->getParent(); bool IfPredicateInstr = Parent ? 
Parent->isReplicator() : false; - assert((Parent || all_of(RepRecipe->operands(), - [](VPValue *Op) { - return Op->isDefinedOutsideLoopRegions(); - })) && - "Expected a recipe is either within a region or all of its operands " - "are defined outside the vectorized region."); + assert( + (Parent || !RepRecipe->getParent()->getPlan()->getVectorLoopRegion() || + all_of(RepRecipe->operands(), + [](VPValue *Op) { return Op->isDefinedOutsideLoopRegions(); })) && + "Expected a recipe is either within a region or all of its operands " + "are defined outside the vectorized region."); if (IfPredicateInstr) PredicatedInstructions.push_back(Cloned); } @@ -2466,19 +2454,15 @@ InnerLoopVectorizer::getOrCreateVectorTripCount(BasicBlock *InsertBlock) { return VectorTripCount; } -/// Introduces a new VPIRBasicBlock for \p CheckIRBB to \p Plan between the -/// vector preheader and its predecessor, also connecting the new block to the -/// scalar preheader. -static void introduceCheckBlockInVPlan(VPlan &Plan, BasicBlock *CheckIRBB) { +void InnerLoopVectorizer::introduceCheckBlockInVPlan(BasicBlock *CheckIRBB) { VPBlockBase *ScalarPH = Plan.getScalarPreheader(); - VPBlockBase *VectorPH = Plan.getVectorPreheader(); - VPBlockBase *PreVectorPH = VectorPH->getSinglePredecessor(); + VPBlockBase *PreVectorPH = VectorPHVPB->getSinglePredecessor(); if (PreVectorPH->getNumSuccessors() != 1) { assert(PreVectorPH->getNumSuccessors() == 2 && "Expected 2 successors"); assert(PreVectorPH->getSuccessors()[0] == ScalarPH && "Unexpected successor"); - VPIRBasicBlock *CheckVPIRBB = VPIRBasicBlock::fromBasicBlock(CheckIRBB); - VPBlockUtils::insertOnEdge(PreVectorPH, VectorPH, CheckVPIRBB); + VPIRBasicBlock *CheckVPIRBB = Plan.createVPIRBasicBlock(CheckIRBB); + VPBlockUtils::insertOnEdge(PreVectorPH, VectorPHVPB, CheckVPIRBB); PreVectorPH = CheckVPIRBB; } VPBlockUtils::connectBlocks(PreVectorPH, ScalarPH); @@ -2567,7 +2551,7 @@ void InnerLoopVectorizer::emitIterationCountCheck(BasicBlock *Bypass) { LoopBypassBlocks.push_back(TCCheckBlock); // TODO: Wrap LoopVectorPreHeader in VPIRBasicBlock here. - introduceCheckBlockInVPlan(Plan, TCCheckBlock); + introduceCheckBlockInVPlan(TCCheckBlock); } BasicBlock *InnerLoopVectorizer::emitSCEVChecks(BasicBlock *Bypass) { @@ -2585,7 +2569,7 @@ BasicBlock *InnerLoopVectorizer::emitSCEVChecks(BasicBlock *Bypass) { LoopBypassBlocks.push_back(SCEVCheckBlock); AddedSafetyChecks = true; - introduceCheckBlockInVPlan(Plan, SCEVCheckBlock); + introduceCheckBlockInVPlan(SCEVCheckBlock); return SCEVCheckBlock; } @@ -2622,10 +2606,25 @@ BasicBlock *InnerLoopVectorizer::emitMemRuntimeChecks(BasicBlock *Bypass) { AddedSafetyChecks = true; - introduceCheckBlockInVPlan(Plan, MemCheckBlock); + introduceCheckBlockInVPlan(MemCheckBlock); return MemCheckBlock; } +/// Replace \p VPBB with a VPIRBasicBlock wrapping \p IRBB. All recipes from \p +/// VPBB are moved to the end of the newly created VPIRBasicBlock. VPBB must +/// have a single predecessor, which is rewired to the new VPIRBasicBlock. All +/// successors of VPBB, if any, are rewired to the new VPIRBasicBlock. +static void replaceVPBBWithIRVPBB(VPBasicBlock *VPBB, BasicBlock *IRBB) { + VPIRBasicBlock *IRVPBB = VPBB->getPlan()->createVPIRBasicBlock(IRBB); + for (auto &R : make_early_inc_range(*VPBB)) { + assert(!R.isPhi() && "Tried to move phi recipe to end of block"); + R.moveBefore(*IRVPBB, IRVPBB->end()); + } + + VPBlockUtils::reassociateBlocks(VPBB, IRVPBB); + // VPBB is now dead and will be cleaned up when the plan gets destroyed. 
+} + void InnerLoopVectorizer::createVectorLoopSkeleton(StringRef Prefix) { LoopVectorPreHeader = OrigLoop->getLoopPreheader(); assert(LoopVectorPreHeader && "Invalid loop structure"); @@ -2636,64 +2635,11 @@ void InnerLoopVectorizer::createVectorLoopSkeleton(StringRef Prefix) { LoopMiddleBlock = SplitBlock(LoopVectorPreHeader, LoopVectorPreHeader->getTerminator(), DT, LI, nullptr, Twine(Prefix) + "middle.block"); + replaceVPBBWithIRVPBB(Plan.getMiddleBlock(), LoopMiddleBlock); LoopScalarPreHeader = SplitBlock(LoopMiddleBlock, LoopMiddleBlock->getTerminator(), DT, LI, nullptr, Twine(Prefix) + "scalar.ph"); -} - -void InnerLoopVectorizer::createInductionResumeVPValue( - VPIRInstruction *InductionPhiRI, const InductionDescriptor &II, Value *Step, - ArrayRef<BasicBlock *> BypassBlocks, VPBuilder &ScalarPHBuilder, - Value *MainVectorTripCount) { - // TODO: Move to LVP or general VPlan construction, once no IR values are - // generated. - auto *OrigPhi = cast<PHINode>(&InductionPhiRI->getInstruction()); - Value *VectorTripCount = getOrCreateVectorTripCount(LoopVectorPreHeader); - assert(VectorTripCount && "Expected valid arguments"); - - Instruction *OldInduction = Legal->getPrimaryInduction(); - // For the primary induction the end values are known. - Value *EndValue = VectorTripCount; - Value *EndValueFromAdditionalBypass = MainVectorTripCount; - // Otherwise compute them accordingly. - if (OrigPhi != OldInduction) { - IRBuilder<> B(LoopVectorPreHeader->getTerminator()); - - // Fast-math-flags propagate from the original induction instruction. - if (isa_and_nonnull<FPMathOperator>(II.getInductionBinOp())) - B.setFastMathFlags(II.getInductionBinOp()->getFastMathFlags()); - - EndValue = emitTransformedIndex(B, VectorTripCount, II.getStartValue(), - Step, II.getKind(), II.getInductionBinOp()); - EndValue->setName("ind.end"); - - // Compute the end value for the additional bypass (if applicable). - if (MainVectorTripCount) { - B.SetInsertPoint(getAdditionalBypassBlock(), - getAdditionalBypassBlock()->getFirstInsertionPt()); - EndValueFromAdditionalBypass = - emitTransformedIndex(B, MainVectorTripCount, II.getStartValue(), Step, - II.getKind(), II.getInductionBinOp()); - EndValueFromAdditionalBypass->setName("ind.end"); - } - } - - auto *ResumePhiRecipe = ScalarPHBuilder.createNaryOp( - VPInstruction::ResumePhi, - {Plan.getOrAddLiveIn(EndValue), Plan.getOrAddLiveIn(II.getStartValue())}, - OrigPhi->getDebugLoc(), "bc.resume.val"); - assert(InductionPhiRI->getNumOperands() == 0 && - "InductionPhiRI should not have any operands"); - InductionPhiRI->addOperand(ResumePhiRecipe); - - if (EndValueFromAdditionalBypass) { - // Store the bypass value here, as it needs to be added as operand to its - // scalar preheader phi node after the epilogue skeleton has been created. - // TODO: Directly add as extra operand to the VPResumePHI recipe. - assert(!Induction2AdditionalBypassValue.contains(OrigPhi) && - "entry for OrigPhi already exits"); - Induction2AdditionalBypassValue[OrigPhi] = EndValueFromAdditionalBypass; - } + replaceVPBBWithIRVPBB(Plan.getScalarPreheader(), LoopScalarPreHeader); } /// Return the expanded step for \p ID using \p ExpandedSCEVs to look up SCEV @@ -2733,46 +2679,40 @@ static void addFullyUnrolledInstructionsToIgnore( } } -void InnerLoopVectorizer::createInductionResumeVPValues( - const SCEV2ValueTy &ExpandedSCEVs, Value *MainVectorTripCount, - SmallPtrSetImpl<PHINode *> *IVSubset) { - // We are going to resume the execution of the scalar loop. 
- // Go over all of the induction variable PHIs of the scalar loop header and - // fix their starting values, which depend on the counter of the last - // iteration of the vectorized loop. If we come from one of the - // LoopBypassBlocks then we need to start from the original start value. - // Otherwise we provide the trip count from the main vector loop. - VPBasicBlock *ScalarPHVPBB = Plan.getScalarPreheader(); - VPBuilder ScalarPHBuilder(ScalarPHVPBB, ScalarPHVPBB->begin()); - bool HasCanonical = false; - for (VPRecipeBase &R : *Plan.getScalarHeader()) { - auto *PhiR = cast<VPIRInstruction>(&R); - auto *Phi = dyn_cast<PHINode>(&PhiR->getInstruction()); - if (!Phi) - break; - if (!Legal->getInductionVars().contains(Phi) || - (IVSubset && !IVSubset->contains(Phi))) - continue; - const InductionDescriptor &II = Legal->getInductionVars().find(Phi)->second; - createInductionResumeVPValue(PhiR, II, getExpandedStep(II, ExpandedSCEVs), - LoopBypassBlocks, ScalarPHBuilder, - MainVectorTripCount); - auto *ConstStart = dyn_cast<ConstantInt>(II.getStartValue()); - auto *ConstStep = II.getConstIntStepValue(); - if (Phi->getType() == VectorTripCount->getType() && ConstStart && - ConstStart->isZero() && ConstStep && ConstStep->isOne()) - HasCanonical = true; - } - - if (!IVSubset || HasCanonical) - return; - // When vectorizing the epilogue, create a resume phi for the canonical IV if - // no suitable resume phi was already created. - ScalarPHBuilder.createNaryOp( - VPInstruction::ResumePhi, - {&Plan.getVectorTripCount(), - Plan.getOrAddLiveIn(ConstantInt::get(VectorTripCount->getType(), 0))}, - {}, "vec.epilog.resume.val"); +void InnerLoopVectorizer::createInductionAdditionalBypassValues( + const SCEV2ValueTy &ExpandedSCEVs, Value *MainVectorTripCount) { + assert(MainVectorTripCount && "Must have bypass information"); + + Instruction *OldInduction = Legal->getPrimaryInduction(); + IRBuilder<> BypassBuilder(getAdditionalBypassBlock(), + getAdditionalBypassBlock()->getFirstInsertionPt()); + for (const auto &InductionEntry : Legal->getInductionVars()) { + PHINode *OrigPhi = InductionEntry.first; + const InductionDescriptor &II = InductionEntry.second; + Value *Step = getExpandedStep(II, ExpandedSCEVs); + // For the primary induction the additional bypass end value is known. + // Otherwise it is computed. + Value *EndValueFromAdditionalBypass = MainVectorTripCount; + if (OrigPhi != OldInduction) { + auto *BinOp = II.getInductionBinOp(); + // Fast-math-flags propagate from the original induction instruction. + if (isa_and_nonnull<FPMathOperator>(BinOp)) + BypassBuilder.setFastMathFlags(BinOp->getFastMathFlags()); + + // Compute the end value for the additional bypass. + EndValueFromAdditionalBypass = + emitTransformedIndex(BypassBuilder, MainVectorTripCount, + II.getStartValue(), Step, II.getKind(), BinOp); + EndValueFromAdditionalBypass->setName("ind.end"); + } + + // Store the bypass value here, as it needs to be added as operand to its + // scalar preheader phi node after the epilogue skeleton has been created. + // TODO: Directly add as extra operand to the VPResumePHI recipe. + assert(!Induction2AdditionalBypassValue.contains(OrigPhi) && + "entry for OrigPhi already exits"); + Induction2AdditionalBypassValue[OrigPhi] = EndValueFromAdditionalBypass; + } } BasicBlock *InnerLoopVectorizer::createVectorizedLoopSkeleton( @@ -2832,9 +2772,6 @@ BasicBlock *InnerLoopVectorizer::createVectorizedLoopSkeleton( // faster. 
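
// A standalone sketch (illustrative, not part of this patch) of the resume-value
// arithmetic behind emitTransformedIndex above, for the simple integer add-induction
// case: the value the scalar loop resumes from is Start + TripCount * Step, and the
// additional bypass edge uses the main vector loop's trip count instead.
#include <cstdio>

int main() {
  long Start = 5, Step = 3;
  long MainVectorTripCount = 96;      // iterations covered by the main vector loop
  long VectorTripCount = 100;         // iterations covered after the epilogue loop
  long FromAdditionalBypass = Start + MainVectorTripCount * Step;
  long FromVectorLoop = Start + VectorTripCount * Step;
  std::printf("%ld %ld\n", FromAdditionalBypass, FromVectorLoop); // 293 305
}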
emitMemRuntimeChecks(LoopScalarPreHeader); - // Emit phis for the new starting index of the scalar loop. - createInductionResumeVPValues(ExpandedSCEVs); - return LoopVectorPreHeader; } @@ -3048,22 +2985,6 @@ void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State) { PSE.getSE()->forgetLoop(OrigLoop); PSE.getSE()->forgetBlockAndLoopDispositions(); - // When dealing with uncountable early exits we create middle.split blocks - // between the vector loop region and the exit block. These blocks need - // adding to any outer loop. - VPRegionBlock *VectorRegion = State.Plan->getVectorLoopRegion(); - Loop *OuterLoop = OrigLoop->getParentLoop(); - if (Legal->hasUncountableEarlyExit() && OuterLoop) { - VPBasicBlock *MiddleVPBB = State.Plan->getMiddleBlock(); - VPBlockBase *PredVPBB = MiddleVPBB->getSinglePredecessor(); - while (PredVPBB && PredVPBB != VectorRegion) { - BasicBlock *MiddleSplitBB = - State.CFG.VPBB2IRBB[cast<VPBasicBlock>(PredVPBB)]; - OuterLoop->addBasicBlockToLoop(MiddleSplitBB, *LI); - PredVPBB = PredVPBB->getSinglePredecessor(); - } - } - // After vectorization, the exit blocks of the original loop will have // additional predecessors. Invalidate SCEVs for the exit phis in case SE // looked through single-entry phis. @@ -3091,9 +3012,15 @@ void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State) { getOrCreateVectorTripCount(nullptr), LoopMiddleBlock, State); } + // Don't apply optimizations below when no vector region remains, as they all + // require a vector loop at the moment. + if (!State.Plan->getVectorLoopRegion()) + return; + for (Instruction *PI : PredicatedInstructions) sinkScalarOperands(&*PI); + VPRegionBlock *VectorRegion = State.Plan->getVectorLoopRegion(); VPBasicBlock *HeaderVPBB = VectorRegion->getEntryBasicBlock(); BasicBlock *HeaderBB = State.CFG.VPBB2IRBB[HeaderVPBB]; @@ -3576,10 +3503,10 @@ bool LoopVectorizationCostModel::interleavedAccessCanBeWidened( if (hasIrregularType(ScalarTy, DL)) return false; - // For scalable vectors, the only interleave factor currently supported - // must be power of 2 since we require the (de)interleave2 intrinsics - // instead of shufflevectors. - if (VF.isScalable() && !isPowerOf2_32(InterleaveFactor)) + // We currently only know how to emit interleave/deinterleave with + // Factor=2 for scalable vectors. This is purely an implementation + // limit. + if (VF.isScalable() && InterleaveFactor != 2) return false; // If the group involves a non-integral pointer, we may not be able to @@ -4768,7 +4695,6 @@ VectorizationFactor LoopVectorizationPlanner::selectVectorizationFactor() { !isMoreProfitable(ChosenFactor, ScalarCost)) dbgs() << "LV: Vectorization seems to be not beneficial, " << "but was forced by a user.\n"); - LLVM_DEBUG(dbgs() << "LV: Selecting VF: " << ChosenFactor.Width << ".\n"); return ChosenFactor; } #endif @@ -7697,6 +7623,7 @@ VectorizationFactor LoopVectorizationPlanner::computeBestVF() { "when vectorizing, the scalar cost must be computed."); #endif + LLVM_DEBUG(dbgs() << "LV: Selecting VF: " << BestFactor.Width << ".\n"); return BestFactor; } @@ -7802,7 +7729,8 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan( // Perform the actual loop transformation. 
VPTransformState State(&TTI, BestVF, BestUF, LI, DT, ILV.Builder, &ILV, - &BestVPlan, Legal->getWidestInductionType()); + &BestVPlan, OrigLoop->getParentLoop(), + Legal->getWidestInductionType()); #ifdef EXPENSIVE_CHECKS assert(DT->verify(DominatorTree::VerificationLevel::Fast)); @@ -7810,11 +7738,9 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan( // 0. Generate SCEV-dependent code in the entry, including TripCount, before // making any changes to the CFG. - if (!BestVPlan.getEntry()->empty()) { - State.CFG.PrevBB = OrigLoop->getLoopPreheader(); - State.Builder.SetInsertPoint(OrigLoop->getLoopPreheader()->getTerminator()); + if (!BestVPlan.getEntry()->empty()) BestVPlan.getEntry()->execute(&State); - } + if (!ILV.getTripCount()) ILV.setTripCount(State.get(BestVPlan.getTripCount(), VPLane(0))); else @@ -7823,6 +7749,8 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan( // 1. Set up the skeleton for vectorization, including vector pre-header and // middle block. The vector loop is created during VPlan execution. + VPBasicBlock *VectorPH = + cast<VPBasicBlock>(BestVPlan.getEntry()->getSingleSuccessor()); State.CFG.PrevBB = ILV.createVectorizedLoopSkeleton( ExpandedSCEVs ? *ExpandedSCEVs : State.ExpandedSCEVs); if (VectorizingEpilogue) @@ -7860,19 +7788,20 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan( BestVPlan.prepareToExecute( ILV.getTripCount(), ILV.getOrCreateVectorTripCount(ILV.LoopVectorPreHeader), State); + replaceVPBBWithIRVPBB(VectorPH, State.CFG.PrevBB); BestVPlan.execute(&State); - auto *ExitVPBB = BestVPlan.getMiddleBlock(); + auto *MiddleVPBB = BestVPlan.getMiddleBlock(); // 2.5 When vectorizing the epilogue, fix reduction and induction resume // values from the additional bypass block. if (VectorizingEpilogue) { assert(!ILV.Legal->hasUncountableEarlyExit() && "Epilogue vectorisation not yet supported with early exits"); BasicBlock *BypassBlock = ILV.getAdditionalBypassBlock(); - for (VPRecipeBase &R : *ExitVPBB) { + for (VPRecipeBase &R : *MiddleVPBB) { fixReductionScalarResumeWhenVectorizingEpilog( - &R, State, State.CFG.VPBB2IRBB[ExitVPBB], BypassBlock); + &R, State, State.CFG.VPBB2IRBB[MiddleVPBB], BypassBlock); } BasicBlock *PH = OrigLoop->getLoopPreheader(); for (const auto &[IVPhi, _] : Legal->getInductionVars()) { @@ -7885,30 +7814,31 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan( // 2.6. Maintain Loop Hints // Keep all loop hints from the original loop on the vector loop (we'll // replace the vectorizer-specific hints below). - MDNode *OrigLoopID = OrigLoop->getLoopID(); + if (auto *LoopRegion = BestVPlan.getVectorLoopRegion()) { + MDNode *OrigLoopID = OrigLoop->getLoopID(); - std::optional<MDNode *> VectorizedLoopID = - makeFollowupLoopID(OrigLoopID, {LLVMLoopVectorizeFollowupAll, - LLVMLoopVectorizeFollowupVectorized}); - - VPBasicBlock *HeaderVPBB = - BestVPlan.getVectorLoopRegion()->getEntryBasicBlock(); - Loop *L = LI->getLoopFor(State.CFG.VPBB2IRBB[HeaderVPBB]); - if (VectorizedLoopID) - L->setLoopID(*VectorizedLoopID); - else { - // Keep all loop hints from the original loop on the vector loop (we'll - // replace the vectorizer-specific hints below). 
- if (MDNode *LID = OrigLoop->getLoopID()) - L->setLoopID(LID); - - LoopVectorizeHints Hints(L, true, *ORE); - Hints.setAlreadyVectorized(); + std::optional<MDNode *> VectorizedLoopID = + makeFollowupLoopID(OrigLoopID, {LLVMLoopVectorizeFollowupAll, + LLVMLoopVectorizeFollowupVectorized}); + + VPBasicBlock *HeaderVPBB = LoopRegion->getEntryBasicBlock(); + Loop *L = LI->getLoopFor(State.CFG.VPBB2IRBB[HeaderVPBB]); + if (VectorizedLoopID) { + L->setLoopID(*VectorizedLoopID); + } else { + // Keep all loop hints from the original loop on the vector loop (we'll + // replace the vectorizer-specific hints below). + if (MDNode *LID = OrigLoop->getLoopID()) + L->setLoopID(LID); + + LoopVectorizeHints Hints(L, true, *ORE); + Hints.setAlreadyVectorized(); + } + TargetTransformInfo::UnrollingPreferences UP; + TTI.getUnrollingPreferences(L, *PSE.getSE(), UP, ORE); + if (!UP.UnrollVectorizedLoop || VectorizingEpilogue) + addRuntimeUnrollDisableMetaData(L); } - TargetTransformInfo::UnrollingPreferences UP; - TTI.getUnrollingPreferences(L, *PSE.getSE(), UP, ORE); - if (!UP.UnrollVectorizedLoop || VectorizingEpilogue) - addRuntimeUnrollDisableMetaData(L); // 3. Fix the vectorized code: take care of header phi's, live-outs, // predication, updating analyses. @@ -7917,15 +7847,18 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan( ILV.printDebugTracesAtEnd(); // 4. Adjust branch weight of the branch in the middle block. - auto *MiddleTerm = - cast<BranchInst>(State.CFG.VPBB2IRBB[ExitVPBB]->getTerminator()); - if (MiddleTerm->isConditional() && - hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator())) { - // Assume that `Count % VectorTripCount` is equally distributed. - unsigned TripCount = BestVPlan.getUF() * State.VF.getKnownMinValue(); - assert(TripCount > 0 && "trip count should not be zero"); - const uint32_t Weights[] = {1, TripCount - 1}; - setBranchWeights(*MiddleTerm, Weights, /*IsExpected=*/false); + if (BestVPlan.getVectorLoopRegion()) { + auto *MiddleVPBB = BestVPlan.getMiddleBlock(); + auto *MiddleTerm = + cast<BranchInst>(State.CFG.VPBB2IRBB[MiddleVPBB]->getTerminator()); + if (MiddleTerm->isConditional() && + hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator())) { + // Assume that `Count % VectorTripCount` is equally distributed. + unsigned TripCount = BestVPlan.getUF() * State.VF.getKnownMinValue(); + assert(TripCount > 0 && "trip count should not be zero"); + const uint32_t Weights[] = {1, TripCount - 1}; + setBranchWeights(*MiddleTerm, Weights, /*IsExpected=*/false); + } } return State.ExpandedSCEVs; @@ -7968,17 +7901,6 @@ BasicBlock *EpilogueVectorizerMainLoop::createEpilogueVectorizedLoopSkeleton( // Generate the induction variable. EPI.VectorTripCount = getOrCreateVectorTripCount(LoopVectorPreHeader); - // Generate VPValues and ResumePhi recipes for wide inductions in the epilogue - // plan only. Other inductions only need a resume value for the canonical - // induction, which will get created during epilogue skeleton construction. 
- SmallPtrSet<PHINode *, 4> WideIVs; - for (VPRecipeBase &H : - EPI.EpiloguePlan.getVectorLoopRegion()->getEntryBasicBlock()->phis()) { - if (auto *WideIV = dyn_cast<VPWidenInductionRecipe>(&H)) - WideIVs.insert(WideIV->getPHINode()); - } - createInductionResumeVPValues(ExpandedSCEVs, nullptr, &WideIVs); - return LoopVectorPreHeader; } @@ -8048,7 +7970,7 @@ EpilogueVectorizerMainLoop::emitIterationCountCheck(BasicBlock *Bypass, setBranchWeights(BI, MinItersBypassWeights, /*IsExpected=*/false); ReplaceInstWithInst(TCCheckBlock->getTerminator(), &BI); - introduceCheckBlockInVPlan(Plan, TCCheckBlock); + introduceCheckBlockInVPlan(TCCheckBlock); return TCCheckBlock; } @@ -8128,14 +8050,11 @@ EpilogueVectorizerEpilogueLoop::createEpilogueVectorizedLoopSkeleton( Phi->removeIncomingValue(EPI.MemSafetyCheck); } - // Generate induction resume values. These variables save the new starting - // indexes for the scalar loop. They are used to test if there are any tail - // iterations left once the vector loop has completed. - // Note that when the vectorized epilogue is skipped due to iteration count - // check, then the resume value for the induction variable comes from - // the trip count of the main vector loop, passed as the second argument. - createInductionResumeVPValues(ExpandedSCEVs, EPI.VectorTripCount); - + // Generate bypass values from the additional bypass block. Note that when the + // vectorized epilogue is skipped due to iteration count check, then the + // resume value for the induction variable comes from the trip count of the + // main vector loop, passed as the second argument. + createInductionAdditionalBypassValues(ExpandedSCEVs, EPI.VectorTripCount); return LoopVectorPreHeader; } @@ -8185,13 +8104,13 @@ EpilogueVectorizerEpilogueLoop::emitMinimumVectorEpilogueIterCountCheck( // A new entry block has been created for the epilogue VPlan. Hook it in, as // otherwise we would try to modify the entry to the main vector loop. - VPIRBasicBlock *NewEntry = VPIRBasicBlock::fromBasicBlock(Insert); + VPIRBasicBlock *NewEntry = Plan.createVPIRBasicBlock(Insert); VPBasicBlock *OldEntry = Plan.getEntry(); VPBlockUtils::reassociateBlocks(OldEntry, NewEntry); Plan.setEntry(NewEntry); - delete OldEntry; + // OldEntry is now dead and will be cleaned up when the plan gets destroyed. - introduceCheckBlockInVPlan(Plan, Insert); + introduceCheckBlockInVPlan(Insert); return Insert; } @@ -8435,17 +8354,22 @@ VPRecipeBuilder::tryToWidenMemory(Instruction *I, ArrayRef<VPValue *> Operands, auto *GEP = dyn_cast<GetElementPtrInst>( Ptr->getUnderlyingValue()->stripPointerCasts()); VPSingleDefRecipe *VectorPtr; - if (Reverse) + if (Reverse) { + // When folding the tail, we may compute an address that we don't in the + // original scalar loop and it may not be inbounds. Drop Inbounds in that + // case. + GEPNoWrapFlags Flags = + (CM.foldTailByMasking() || !GEP || !GEP->isInBounds()) + ? GEPNoWrapFlags::none() + : GEPNoWrapFlags::inBounds(); VectorPtr = new VPReverseVectorPointerRecipe( - Ptr, &Plan.getVF(), getLoadStoreType(I), - GEP && GEP->isInBounds() ? GEPNoWrapFlags::inBounds() - : GEPNoWrapFlags::none(), - I->getDebugLoc()); - else + Ptr, &Plan.getVF(), getLoadStoreType(I), Flags, I->getDebugLoc()); + } else { VectorPtr = new VPVectorPointerRecipe(Ptr, getLoadStoreType(I), GEP ? 
GEP->getNoWrapFlags() : GEPNoWrapFlags::none(), I->getDebugLoc()); + } Builder.getInsertBlock()->appendRecipe(VectorPtr); Ptr = VectorPtr; } @@ -8955,14 +8879,56 @@ static void addCanonicalIVRecipes(VPlan &Plan, Type *IdxTy, bool HasNUW, {CanonicalIVIncrement, &Plan.getVectorTripCount()}, DL); } -/// Create resume phis in the scalar preheader for first-order recurrences and -/// reductions and update the VPIRInstructions wrapping the original phis in the -/// scalar header. +/// Create and return a ResumePhi for \p WideIV, unless it is truncated. If the +/// induction recipe is not canonical, creates a VPDerivedIVRecipe to compute +/// the end value of the induction. +static VPValue *addResumePhiRecipeForInduction(VPWidenInductionRecipe *WideIV, + VPBuilder &VectorPHBuilder, + VPBuilder &ScalarPHBuilder, + VPTypeAnalysis &TypeInfo, + VPValue *VectorTC) { + auto *WideIntOrFp = dyn_cast<VPWidenIntOrFpInductionRecipe>(WideIV); + // Truncated wide inductions resume from the last lane of their vector value + // in the last vector iteration which is handled elsewhere. + if (WideIntOrFp && WideIntOrFp->getTruncInst()) + return nullptr; + + VPValue *Start = WideIV->getStartValue(); + VPValue *Step = WideIV->getStepValue(); + const InductionDescriptor &ID = WideIV->getInductionDescriptor(); + VPValue *EndValue = VectorTC; + if (!WideIntOrFp || !WideIntOrFp->isCanonical()) { + EndValue = VectorPHBuilder.createDerivedIV( + ID.getKind(), dyn_cast_or_null<FPMathOperator>(ID.getInductionBinOp()), + Start, VectorTC, Step); + } + + // EndValue is derived from the vector trip count (which has the same type as + // the widest induction) and thus may be wider than the induction here. + Type *ScalarTypeOfWideIV = TypeInfo.inferScalarType(WideIV); + if (ScalarTypeOfWideIV != TypeInfo.inferScalarType(EndValue)) { + EndValue = VectorPHBuilder.createScalarCast(Instruction::Trunc, EndValue, + ScalarTypeOfWideIV, + WideIV->getDebugLoc()); + } + + auto *ResumePhiRecipe = + ScalarPHBuilder.createNaryOp(VPInstruction::ResumePhi, {EndValue, Start}, + WideIV->getDebugLoc(), "bc.resume.val"); + return ResumePhiRecipe; +} + +/// Create resume phis in the scalar preheader for first-order recurrences, +/// reductions and inductions, and update the VPIRInstructions wrapping the +/// original phis in the scalar header. 
static void addScalarResumePhis(VPRecipeBuilder &Builder, VPlan &Plan) { + VPTypeAnalysis TypeInfo(Plan.getCanonicalIV()->getScalarType()); auto *ScalarPH = Plan.getScalarPreheader(); auto *MiddleVPBB = cast<VPBasicBlock>(ScalarPH->getSinglePredecessor()); - VPBuilder ScalarPHBuilder(ScalarPH); + VPBuilder VectorPHBuilder( + cast<VPBasicBlock>(Plan.getVectorLoopRegion()->getSinglePredecessor())); VPBuilder MiddleBuilder(MiddleVPBB, MiddleVPBB->getFirstNonPhi()); + VPBuilder ScalarPHBuilder(ScalarPH); VPValue *OneVPV = Plan.getOrAddLiveIn( ConstantInt::get(Plan.getCanonicalIV()->getScalarType(), 1)); for (VPRecipeBase &ScalarPhiR : *Plan.getScalarHeader()) { @@ -8970,9 +8936,23 @@ static void addScalarResumePhis(VPRecipeBuilder &Builder, VPlan &Plan) { auto *ScalarPhiI = dyn_cast<PHINode>(&ScalarPhiIRI->getInstruction()); if (!ScalarPhiI) break; + auto *VectorPhiR = cast<VPHeaderPHIRecipe>(Builder.getRecipe(ScalarPhiI)); - if (!isa<VPFirstOrderRecurrencePHIRecipe, VPReductionPHIRecipe>(VectorPhiR)) + if (auto *WideIVR = dyn_cast<VPWidenInductionRecipe>(VectorPhiR)) { + if (VPValue *ResumePhi = addResumePhiRecipeForInduction( + WideIVR, VectorPHBuilder, ScalarPHBuilder, TypeInfo, + &Plan.getVectorTripCount())) { + ScalarPhiIRI->addOperand(ResumePhi); + continue; + } + // TODO: Also handle truncated inductions here. Computing end-values + // separately should be done as VPlan-to-VPlan optimization, after + // legalizing all resume values to use the last lane from the loop. + assert(cast<VPWidenIntOrFpInductionRecipe>(VectorPhiR)->getTruncInst() && + "should only skip truncated wide inductions"); continue; + } + // The backedge value provides the value to resume coming out of a loop, // which for FORs is a vector whose last element needs to be extracted. The // start value provides the value if the loop is bypassed. @@ -8990,14 +8970,73 @@ static void addScalarResumePhis(VPRecipeBuilder &Builder, VPlan &Plan) { } } +/// Return true if \p VPV is an optimizable IV or IV use. That is, if \p VPV is +/// either an untruncated wide induction, or if it increments a wide induction +/// by its step. +static bool isOptimizableIVOrUse(VPValue *VPV) { + VPRecipeBase *Def = VPV->getDefiningRecipe(); + if (!Def) + return false; + auto *WideIV = dyn_cast<VPWidenInductionRecipe>(Def); + if (WideIV) { + // VPV itself is a wide induction, separately compute the end value for exit + // users if it is not a truncated IV. + return isa<VPWidenPointerInductionRecipe>(WideIV) || + !cast<VPWidenIntOrFpInductionRecipe>(WideIV)->getTruncInst(); + } + + // Check if VPV is an optimizable induction increment. + if (Def->getNumOperands() != 2) + return false; + WideIV = dyn_cast<VPWidenInductionRecipe>(Def->getOperand(0)); + if (!WideIV) + WideIV = dyn_cast<VPWidenInductionRecipe>(Def->getOperand(1)); + if (!WideIV) + return false; + + using namespace VPlanPatternMatch; + auto &ID = WideIV->getInductionDescriptor(); + + // Check if VPV increments the induction by the induction step. + VPValue *IVStep = WideIV->getStepValue(); + switch (ID.getInductionOpcode()) { + case Instruction::Add: + return match(VPV, m_c_Binary<Instruction::Add>(m_Specific(WideIV), + m_Specific(IVStep))); + case Instruction::FAdd: + return match(VPV, m_c_Binary<Instruction::FAdd>(m_Specific(WideIV), + m_Specific(IVStep))); + case Instruction::FSub: + return match(VPV, m_Binary<Instruction::FSub>(m_Specific(WideIV), + m_Specific(IVStep))); + case Instruction::Sub: { + // IVStep will be the negated step of the subtraction. 
Check if Step == -1 * + // IVStep. + VPValue *Step; + if (!match(VPV, m_Binary<Instruction::Sub>(m_VPValue(), m_VPValue(Step))) || + !Step->isLiveIn() || !IVStep->isLiveIn()) + return false; + auto *StepCI = dyn_cast<ConstantInt>(Step->getLiveInIRValue()); + auto *IVStepCI = dyn_cast<ConstantInt>(IVStep->getLiveInIRValue()); + return StepCI && IVStepCI && + StepCI->getValue() == (-1 * IVStepCI->getValue()); + } + default: + return ID.getKind() == InductionDescriptor::IK_PtrInduction && + match(VPV, m_GetElementPtr(m_Specific(WideIV), + m_Specific(WideIV->getStepValue()))); + } + llvm_unreachable("should have been covered by switch above"); +} + // Collect VPIRInstructions for phis in the exit blocks that are modeled // in VPlan and add the exiting VPValue as operand. Some exiting values are not // modeled explicitly yet and won't be included. Those are un-truncated // VPWidenIntOrFpInductionRecipe, VPWidenPointerInductionRecipe and induction // increments. -static SetVector<VPIRInstruction *> collectUsersInExitBlocks( - Loop *OrigLoop, VPRecipeBuilder &Builder, VPlan &Plan, - const MapVector<PHINode *, InductionDescriptor> &Inductions) { +static SetVector<VPIRInstruction *> +collectUsersInExitBlocks(Loop *OrigLoop, VPRecipeBuilder &Builder, + VPlan &Plan) { auto *MiddleVPBB = Plan.getMiddleBlock(); SetVector<VPIRInstruction *> ExitUsersToFix; for (VPIRBasicBlock *ExitVPBB : Plan.getExitBlocks()) { @@ -9022,18 +9061,9 @@ static SetVector<VPIRInstruction *> collectUsersInExitBlocks( // Exit values for inductions are computed and updated outside of VPlan // and independent of induction recipes. // TODO: Compute induction exit values in VPlan. - if ((isa<VPWidenIntOrFpInductionRecipe>(V) && - !cast<VPWidenIntOrFpInductionRecipe>(V)->getTruncInst()) || - isa<VPWidenPointerInductionRecipe>(V) || - (isa<Instruction>(IncomingValue) && - OrigLoop->contains(cast<Instruction>(IncomingValue)) && - any_of(IncomingValue->users(), [&Inductions](User *U) { - auto *P = dyn_cast<PHINode>(U); - return P && Inductions.contains(P); - }))) { - if (ExitVPBB->getSinglePredecessor() == MiddleVPBB) - continue; - } + if (isOptimizableIVOrUse(V) && + ExitVPBB->getSinglePredecessor() == MiddleVPBB) + continue; ExitUsersToFix.insert(ExitIRI); ExitIRI->addOperand(V); } @@ -9239,9 +9269,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) { CM.getWideningDecision(IG->getInsertPos(), VF) == LoopVectorizationCostModel::CM_Interleave); // For scalable vectors, the only interleave factor currently supported - // must be power of 2 since we require the (de)interleave2 intrinsics - // instead of shufflevectors. - assert((!Result || !VF.isScalable() || isPowerOf2_32(IG->getFactor())) && + // is 2 since we require the (de)interleave2 intrinsics instead of + // shufflevectors. + assert((!Result || !VF.isScalable() || IG->getFactor() == 2) && "Unsupported interleave factor for scalable vectors"); return Result; }; @@ -9335,7 +9365,7 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) { VPBB->appendRecipe(Recipe); } - VPBlockUtils::insertBlockAfter(new VPBasicBlock(), VPBB); + VPBlockUtils::insertBlockAfter(Plan->createVPBasicBlock(""), VPBB); VPBB = cast<VPBasicBlock>(VPBB->getSingleSuccessor()); } @@ -9348,14 +9378,28 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) { "VPBasicBlock"); RecipeBuilder.fixHeaderPhis(); + // Update wide induction increments to use the same step as the corresponding + // wide induction. 
This enables detecting induction increments directly in + // VPlan and removes redundant splats. + for (const auto &[Phi, ID] : Legal->getInductionVars()) { + auto *IVInc = cast<Instruction>( + Phi->getIncomingValueForBlock(OrigLoop->getLoopLatch())); + if (IVInc->getOperand(0) != Phi || IVInc->getOpcode() != Instruction::Add) + continue; + VPWidenInductionRecipe *WideIV = + cast<VPWidenInductionRecipe>(RecipeBuilder.getRecipe(Phi)); + VPRecipeBase *R = RecipeBuilder.getRecipe(IVInc); + R->setOperand(1, WideIV->getStepValue()); + } + if (auto *UncountableExitingBlock = Legal->getUncountableEarlyExitingBlock()) { VPlanTransforms::handleUncountableEarlyExit( *Plan, *PSE.getSE(), OrigLoop, UncountableExitingBlock, RecipeBuilder); } addScalarResumePhis(RecipeBuilder, *Plan); - SetVector<VPIRInstruction *> ExitUsersToFix = collectUsersInExitBlocks( - OrigLoop, RecipeBuilder, *Plan, Legal->getInductionVars()); + SetVector<VPIRInstruction *> ExitUsersToFix = + collectUsersInExitBlocks(OrigLoop, RecipeBuilder, *Plan); addExitUsersForFirstOrderRecurrences(*Plan, ExitUsersToFix); if (!addUsersInExitBlocks(*Plan, ExitUsersToFix)) { reportVectorizationFailure( @@ -9474,6 +9518,18 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) { bool HasNUW = true; addCanonicalIVRecipes(*Plan, Legal->getWidestInductionType(), HasNUW, DebugLoc()); + + // Collect mapping of IR header phis to header phi recipes, to be used in + // addScalarResumePhis. + VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, Legal, CM, PSE, Builder); + for (auto &R : Plan->getVectorLoopRegion()->getEntryBasicBlock()->phis()) { + if (isa<VPCanonicalIVPHIRecipe>(&R)) + continue; + auto *HeaderR = cast<VPHeaderPHIRecipe>(&R); + RecipeBuilder.setRecipe(HeaderR->getUnderlyingInstr(), HeaderR); + } + addScalarResumePhis(RecipeBuilder, *Plan); + assert(verifyVPlanIsValid(*Plan) && "VPlan is invalid"); return Plan; } @@ -9762,13 +9818,18 @@ void VPDerivedIVRecipe::execute(VPTransformState &State) { State.Builder.setFastMathFlags(FPBinOp->getFastMathFlags()); Value *Step = State.get(getStepValue(), VPLane(0)); - Value *CanonicalIV = State.get(getOperand(1), VPLane(0)); + Value *Index = State.get(getOperand(1), VPLane(0)); Value *DerivedIV = emitTransformedIndex( - State.Builder, CanonicalIV, getStartValue()->getLiveInIRValue(), Step, - Kind, cast_if_present<BinaryOperator>(FPBinOp)); + State.Builder, Index, getStartValue()->getLiveInIRValue(), Step, Kind, + cast_if_present<BinaryOperator>(FPBinOp)); DerivedIV->setName(Name); - assert(DerivedIV != CanonicalIV && "IV didn't need transforming?"); - + // If index is the vector trip count, the concrete value will only be set in + // prepareToExecute, leading to missed simplifications, e.g. if it is 0. + // TODO: Remove the special case for the vector trip count once it is computed + // in VPlan and can be used during VPlan simplification. + assert((DerivedIV != Index || + getOperand(1) == &getParent()->getPlan()->getVectorTripCount()) && + "IV didn't need transforming?"); State.set(this, DerivedIV, VPLane(0)); } @@ -10078,6 +10139,57 @@ LoopVectorizePass::LoopVectorizePass(LoopVectorizeOptions Opts) VectorizeOnlyWhenForced(Opts.VectorizeOnlyWhenForced || !EnableLoopVectorization) {} +/// Prepare \p MainPlan for vectorizing the main vector loop during epilogue +/// vectorization. Remove ResumePhis from \p MainPlan for inductions that +/// don't have a corresponding wide induction in \p EpiPlan. 
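Illustrative aside (not part of the patch): the ResumePhi created here is meant to hand the epilogue loop the iteration count the main vector loop stopped at, namely the main loop's vector trip count. A small sketch of the resulting trip-count split, with hypothetical helper and field names:

#include <cassert>
#include <cstdint>

// Hypothetical arithmetic mirroring main-loop / epilogue-loop / scalar-tail
// splitting: each vector loop only executes whole vector iterations.
struct Split { uint64_t MainIters, EpilogueIters, ScalarIters; };

static Split splitTripCount(uint64_t TC, uint64_t MainVF, uint64_t EpiVF) {
  uint64_t MainDone = (TC / MainVF) * MainVF;
  uint64_t EpiDone = ((TC - MainDone) / EpiVF) * EpiVF;
  return {MainDone, EpiDone, TC - MainDone - EpiDone};
}

int main() {
  // TC = 21, main VF = 8, epilogue VF = 4: the main loop covers 16
  // iterations, the epilogue resumes at 16 and covers 4 more, and the scalar
  // loop finishes the last one. The resume value for the epilogue's canonical
  // induction is exactly the main loop's vector trip count (16).
  Split S = splitTripCount(21, 8, 4);
  assert(S.MainIters == 16 && S.EpilogueIters == 4 && S.ScalarIters == 1);
  return 0;
}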
+static void preparePlanForMainVectorLoop(VPlan &MainPlan, VPlan &EpiPlan) { + // Collect PHI nodes of widened phis in the VPlan for the epilogue. Those + // will need their resume-values computed in the main vector loop. Others + // can be removed from the main VPlan. + SmallPtrSet<PHINode *, 2> EpiWidenedPhis; + for (VPRecipeBase &R : + EpiPlan.getVectorLoopRegion()->getEntryBasicBlock()->phis()) { + if (isa<VPCanonicalIVPHIRecipe>(&R)) + continue; + EpiWidenedPhis.insert( + cast<PHINode>(R.getVPSingleValue()->getUnderlyingValue())); + } + for (VPRecipeBase &R : make_early_inc_range( + *cast<VPIRBasicBlock>(MainPlan.getScalarHeader()))) { + auto *VPIRInst = cast<VPIRInstruction>(&R); + auto *IRI = dyn_cast<PHINode>(&VPIRInst->getInstruction()); + if (!IRI) + break; + if (EpiWidenedPhis.contains(IRI)) + continue; + // There is no corresponding wide induction in the epilogue plan that would + // need a resume value. Remove the VPIRInst wrapping the scalar header phi + // together with the corresponding ResumePhi. The resume values for the + // scalar loop will be created during execution of EpiPlan. + VPRecipeBase *ResumePhi = VPIRInst->getOperand(0)->getDefiningRecipe(); + VPIRInst->eraseFromParent(); + ResumePhi->eraseFromParent(); + } + VPlanTransforms::removeDeadRecipes(MainPlan); + + using namespace VPlanPatternMatch; + VPBasicBlock *MainScalarPH = MainPlan.getScalarPreheader(); + VPValue *VectorTC = &MainPlan.getVectorTripCount(); + // If there is a suitable resume value for the canonical induction in the + // scalar (which will become vector) epilogue loop we are done. Otherwise + // create it below. + if (any_of(*MainScalarPH, [VectorTC](VPRecipeBase &R) { + return match(&R, m_VPInstruction<VPInstruction::ResumePhi>( + m_Specific(VectorTC), m_SpecificInt(0))); + })) + return; + VPBuilder ScalarPHBuilder(MainScalarPH, MainScalarPH->begin()); + ScalarPHBuilder.createNaryOp( + VPInstruction::ResumePhi, + {VectorTC, MainPlan.getCanonicalIV()->getStartValue()}, {}, + "vec.epilog.resume.val"); +} + /// Prepare \p Plan for vectorizing the epilogue loop. That is, re-use expanded /// SCEVs from \p ExpandedSCEVs and set resume values for header recipes. static void @@ -10542,12 +10654,12 @@ bool LoopVectorizePass::processLoop(Loop *L) { // to be vectorized by executing the plan (potentially with a different // factor) again shortly afterwards. 
VPlan &BestEpiPlan = LVP.getPlanFor(EpilogueVF.Width); + preparePlanForMainVectorLoop(*BestMainPlan, BestEpiPlan); EpilogueLoopVectorizationInfo EPI(VF.Width, IC, EpilogueVF.Width, 1, BestEpiPlan); EpilogueVectorizerMainLoop MainILV(L, PSE, LI, DT, TLI, TTI, AC, ORE, EPI, &LVL, &CM, BFI, PSI, Checks, *BestMainPlan); - auto ExpandedSCEVs = LVP.executePlan(EPI.MainLoopVF, EPI.MainLoopUF, *BestMainPlan, MainILV, DT, false); ++LoopsVectorized; diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index f52ddfd..36fed89 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -104,6 +104,7 @@ using namespace llvm; using namespace llvm::PatternMatch; using namespace slpvectorizer; +using namespace std::placeholders; #define SV_NAME "slp-vectorizer" #define DEBUG_TYPE "SLP" @@ -816,27 +817,34 @@ class InstructionsState { Instruction *AltOp = nullptr; public: - Instruction *getMainOp() const { return MainOp; } + Instruction *getMainOp() const { + assert(valid() && "InstructionsState is invalid."); + return MainOp; + } - Instruction *getAltOp() const { return AltOp; } + Instruction *getAltOp() const { + assert(valid() && "InstructionsState is invalid."); + return AltOp; + } /// The main/alternate opcodes for the list of instructions. - unsigned getOpcode() const { - return MainOp ? MainOp->getOpcode() : 0; - } + unsigned getOpcode() const { return getMainOp()->getOpcode(); } - unsigned getAltOpcode() const { - return AltOp ? AltOp->getOpcode() : 0; - } + unsigned getAltOpcode() const { return getAltOp()->getOpcode(); } /// Some of the instructions in the list have alternate opcodes. - bool isAltShuffle() const { return AltOp != MainOp; } + bool isAltShuffle() const { return getMainOp() != getAltOp(); } bool isOpcodeOrAlt(Instruction *I) const { unsigned CheckedOpcode = I->getOpcode(); return getOpcode() == CheckedOpcode || getAltOpcode() == CheckedOpcode; } + /// Checks if the current state is valid, i.e. has non-null MainOp + bool valid() const { return MainOp && AltOp; } + + explicit operator bool() const { return valid(); } + InstructionsState() = delete; InstructionsState(Instruction *MainOp, Instruction *AltOp) : MainOp(MainOp), AltOp(AltOp) {} @@ -869,8 +877,8 @@ static bool areCompatibleCmpOps(Value *BaseOp0, Value *BaseOp1, Value *Op0, (!isa<Instruction>(BaseOp0) && !isa<Instruction>(Op0) && !isa<Instruction>(BaseOp1) && !isa<Instruction>(Op1)) || BaseOp0 == Op0 || BaseOp1 == Op1 || - getSameOpcode({BaseOp0, Op0}, TLI).getOpcode() || - getSameOpcode({BaseOp1, Op1}, TLI).getOpcode(); + getSameOpcode({BaseOp0, Op0}, TLI) || + getSameOpcode({BaseOp1, Op1}, TLI); } /// \returns true if a compare instruction \p CI has similar "look" and @@ -1847,7 +1855,7 @@ public: InstructionsState S = getSameOpcode(Ops, TLI); // Note: Only consider instructions with <= 2 operands to avoid // complexity explosion. - if (S.getOpcode() && + if (S && (S.getMainOp()->getNumOperands() <= 2 || !MainAltOps.empty() || !S.isAltShuffle()) && all_of(Ops, [&S](Value *V) { @@ -2382,7 +2390,7 @@ public: // Use Boyer-Moore majority voting for finding the majority opcode and // the number of times it occurs. if (auto *I = dyn_cast<Instruction>(OpData.V)) { - if (!OpcodeI || !getSameOpcode({OpcodeI, I}, TLI).getOpcode() || + if (!OpcodeI || !getSameOpcode({OpcodeI, I}, TLI) || I->getParent() != Parent) { if (NumOpsWithSameOpcodeParent == 0) { NumOpsWithSameOpcodeParent = 1; @@ -2501,8 +2509,7 @@ public: // 2.1. 
If we have only 2 lanes, need to check that value in the // next lane does not build same opcode sequence. (Lns == 2 && - !getSameOpcode({Op, getValue((OpI + 1) % OpE, Ln)}, TLI) - .getOpcode() && + !getSameOpcode({Op, getValue((OpI + 1) % OpE, Ln)}, TLI) && isa<Constant>(Data.V)))) || // 3. The operand in the current lane is loop invariant (can be // hoisted out) and another operand is also a loop invariant @@ -2511,7 +2518,7 @@ public: // FIXME: need to teach the cost model about this case for better // estimation. (IsInvariant && !isa<Constant>(Data.V) && - !getSameOpcode({Op, Data.V}, TLI).getOpcode() && + !getSameOpcode({Op, Data.V}, TLI) && L->isLoopInvariant(Data.V))) { FoundCandidate = true; Data.IsUsed = Data.V == Op; @@ -2541,7 +2548,7 @@ public: return true; Value *OpILn = getValue(OpI, Ln); return (L && L->isLoopInvariant(OpILn)) || - (getSameOpcode({Op, OpILn}, TLI).getOpcode() && + (getSameOpcode({Op, OpILn}, TLI) && allSameBlock({Op, OpILn})); })) return true; @@ -2698,7 +2705,7 @@ public: OperandData &AltOp = getData(OpIdx, Lane); InstructionsState OpS = getSameOpcode({MainAltOps[OpIdx].front(), AltOp.V}, TLI); - if (OpS.getOpcode() && OpS.isAltShuffle()) + if (OpS && OpS.isAltShuffle()) MainAltOps[OpIdx].push_back(AltOp.V); } } @@ -3400,6 +3407,7 @@ private: } void setOperations(const InstructionsState &S) { + assert(S && "InstructionsState is invalid."); MainOp = S.getMainOp(); AltOp = S.getAltOp(); } @@ -3600,7 +3608,7 @@ private: "Need to vectorize gather entry?"); // Gathered loads still gathered? Do not create entry, use the original one. if (GatheredLoadsEntriesFirst.has_value() && - EntryState == TreeEntry::NeedToGather && + EntryState == TreeEntry::NeedToGather && S && S.getOpcode() == Instruction::Load && UserTreeIdx.EdgeIdx == UINT_MAX && !UserTreeIdx.UserTE) return nullptr; @@ -3618,7 +3626,8 @@ private: ReuseShuffleIndices.end()); if (ReorderIndices.empty()) { Last->Scalars.assign(VL.begin(), VL.end()); - Last->setOperations(S); + if (S) + Last->setOperations(S); } else { // Reorder scalars and build final mask. Last->Scalars.assign(VL.size(), nullptr); @@ -3629,7 +3638,8 @@ private: return VL[Idx]; }); InstructionsState S = getSameOpcode(Last->Scalars, *TLI); - Last->setOperations(S); + if (S) + Last->setOperations(S); Last->ReorderIndices.append(ReorderIndices.begin(), ReorderIndices.end()); } if (!Last->isGather()) { @@ -4774,8 +4784,7 @@ static bool arePointersCompatible(Value *Ptr1, Value *Ptr2, (!GEP2 || isConstant(GEP2->getOperand(1)))) || !CompareOpcodes || (GEP1 && GEP2 && - getSameOpcode({GEP1->getOperand(1), GEP2->getOperand(1)}, TLI) - .getOpcode())); + getSameOpcode({GEP1->getOperand(1), GEP2->getOperand(1)}, TLI))); } /// Calculates minimal alignment as a common alignment. @@ -4947,6 +4956,37 @@ getShuffleCost(const TargetTransformInfo &TTI, TTI::ShuffleKind Kind, return TTI.getShuffleCost(Kind, Tp, Mask, CostKind, Index, SubTp, Args); } +/// Correctly creates insert_subvector, checking that the index is multiple of +/// the subvectors length. Otherwise, generates shuffle using \p Generator or +/// using default shuffle. 
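Illustrative aside (not part of the patch): when the insertion index is not a multiple of the subvector length, the helper below falls back to a two-input shuffle. A standalone sketch of that mask construction, using std::vector instead of LLVM's SmallVector and a hypothetical function name:

#include <cassert>
#include <numeric>
#include <vector>

constexpr int PoisonElem = -1;

// Sketch of the mask used when the insert position is not a multiple of the
// subvector length: keep the first Index lanes of the wide vector and pull
// the SubVecVF inserted lanes from the (resized) subvector, which occupies
// indices [VecVF, VecVF + SubVecVF) of the two-input shuffle.
static std::vector<int> insertSubvectorMask(unsigned VecVF, unsigned SubVecVF,
                                            unsigned Index) {
  std::vector<int> Mask(VecVF, PoisonElem);
  std::iota(Mask.begin(), Mask.begin() + Index, 0);
  for (unsigned I = 0; I < SubVecVF; ++I)
    Mask[Index + I] = VecVF + I;
  return Mask;
}

int main() {
  // Insert a 2-lane subvector at lane 3 of an 8-lane vector.
  std::vector<int> M = insertSubvectorMask(/*VecVF=*/8, /*SubVecVF=*/2, /*Index=*/3);
  std::vector<int> Expected = {0, 1, 2, 8, 9, PoisonElem, PoisonElem, PoisonElem};
  assert(M == Expected);
  return 0;
}

Lanes past the inserted subvector stay poison in this sketch, mirroring the new helper.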
+static Value *createInsertVector( + IRBuilderBase &Builder, Value *Vec, Value *V, unsigned Index, + function_ref<Value *(Value *, Value *, ArrayRef<int>)> Generator = {}) { + const unsigned SubVecVF = getNumElements(V->getType()); + if (Index % SubVecVF == 0) { + Vec = Builder.CreateInsertVector(Vec->getType(), Vec, V, + Builder.getInt64(Index)); + } else { + // Create shuffle, insertvector requires that index is multiple of + // the subvector length. + const unsigned VecVF = getNumElements(Vec->getType()); + SmallVector<int> Mask(VecVF, PoisonMaskElem); + std::iota(Mask.begin(), std::next(Mask.begin(), Index), 0); + for (unsigned I : seq<unsigned>(SubVecVF)) + Mask[I + Index] = I + VecVF; + if (Generator) { + Vec = Generator(Vec, V, Mask); + } else { + // 1. Resize V to the size of Vec. + SmallVector<int> ResizeMask(VecVF, PoisonMaskElem); + std::iota(ResizeMask.begin(), std::next(ResizeMask.begin(), SubVecVF), 0); + V = Builder.CreateShuffleVector(V, ResizeMask); + Vec = Builder.CreateShuffleVector(Vec, V, Mask); + } + } + return Vec; +} + BoUpSLP::LoadsState BoUpSLP::canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0, SmallVectorImpl<unsigned> &Order, @@ -5347,11 +5387,10 @@ static bool clusterSortPtrAccesses(ArrayRef<Value *> VL, SmallPtrSet<Value *, 13> SecondPointers; Value *P1 = Ptr1; Value *P2 = Ptr2; - if (P1 == P2) - return false; unsigned Depth = 0; - while (!FirstPointers.contains(P2) && !SecondPointers.contains(P1) && - Depth <= RecursionMaxDepth) { + while (!FirstPointers.contains(P2) && !SecondPointers.contains(P1)) { + if (P1 == P2 || Depth > RecursionMaxDepth) + return false; FirstPointers.insert(P1); SecondPointers.insert(P2); P1 = getUnderlyingObject(P1, /*MaxLookup=*/1); @@ -7500,7 +7539,7 @@ bool BoUpSLP::areAltOperandsProfitable(const InstructionsState &S, [&](ArrayRef<Value *> Op) { if (allConstant(Op) || (!isSplat(Op) && allSameBlock(Op) && allSameType(Op) && - getSameOpcode(Op, *TLI).getMainOp())) + getSameOpcode(Op, *TLI))) return false; DenseMap<Value *, unsigned> Uniques; for (Value *V : Op) { @@ -8071,15 +8110,14 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // Don't go into catchswitch blocks, which can happen with PHIs. // Such blocks can only have PHIs and the catchswitch. There is no // place to insert a shuffle if we need to, so just avoid that issue. - if (S.getMainOp() && - isa<CatchSwitchInst>(S.getMainOp()->getParent()->getTerminator())) { + if (S && isa<CatchSwitchInst>(S.getMainOp()->getParent()->getTerminator())) { LLVM_DEBUG(dbgs() << "SLP: bundle in catchswitch block.\n"); newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx); return; } // Check if this is a duplicate of another entry. - if (S.getOpcode()) { + if (S) { if (TreeEntry *E = getTreeEntry(S.getMainOp())) { LLVM_DEBUG(dbgs() << "SLP: \tChecking bundle: " << *S.getMainOp() << ".\n"); @@ -8140,13 +8178,12 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // a load), in which case peek through to include it in the tree, without // ballooning over-budget. 
if (Depth >= RecursionMaxDepth && - !(S.getMainOp() && !S.isAltShuffle() && VL.size() >= 4 && + !(S && !S.isAltShuffle() && VL.size() >= 4 && (match(S.getMainOp(), m_Load(m_Value())) || all_of(VL, [&S](const Value *I) { return match(I, m_OneUse(m_ZExtOrSExt(m_OneUse(m_Load(m_Value()))))) && - cast<Instruction>(I)->getOpcode() == - S.getMainOp()->getOpcode(); + cast<Instruction>(I)->getOpcode() == S.getOpcode(); })))) { LLVM_DEBUG(dbgs() << "SLP: Gathering due to max recursion depth.\n"); if (TryToFindDuplicates(S)) @@ -8156,7 +8193,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, } // Don't handle scalable vectors - if (S.getOpcode() == Instruction::ExtractElement && + if (S && S.getOpcode() == Instruction::ExtractElement && isa<ScalableVectorType>( cast<ExtractElementInst>(S.getMainOp())->getVectorOperandType())) { LLVM_DEBUG(dbgs() << "SLP: Gathering due to scalable vector type.\n"); @@ -8180,7 +8217,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // vectorize. auto &&NotProfitableForVectorization = [&S, this, Depth](ArrayRef<Value *> VL) { - if (!S.getOpcode() || !S.isAltShuffle() || VL.size() > 2) + if (!S || !S.isAltShuffle() || VL.size() > 2) return false; if (VectorizableTree.size() < MinTreeSize) return false; @@ -8235,7 +8272,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, bool IsScatterVectorizeUserTE = UserTreeIdx.UserTE && UserTreeIdx.UserTE->State == TreeEntry::ScatterVectorize; - bool AreAllSameBlock = S.getOpcode() && allSameBlock(VL); + bool AreAllSameBlock = S && allSameBlock(VL); bool AreScatterAllGEPSameBlock = (IsScatterVectorizeUserTE && VL.front()->getType()->isPointerTy() && VL.size() > 2 && @@ -8252,8 +8289,9 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, sortPtrAccesses(VL, UserTreeIdx.UserTE->getMainOp()->getType(), *DL, *SE, SortedIndices)); bool AreAllSameInsts = AreAllSameBlock || AreScatterAllGEPSameBlock; - if (!AreAllSameInsts || (!S.getOpcode() && allConstant(VL)) || isSplat(VL) || - (isa_and_present<InsertElementInst, ExtractValueInst, ExtractElementInst>( + if (!AreAllSameInsts || (!S && allConstant(VL)) || isSplat(VL) || + (S && + isa<InsertElementInst, ExtractValueInst, ExtractElementInst>( S.getMainOp()) && !all_of(VL, isVectorLikeInstWithConstOps)) || NotProfitableForVectorization(VL)) { @@ -8265,7 +8303,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, } // Don't vectorize ephemeral values. - if (S.getOpcode() && !EphValues.empty()) { + if (S && !EphValues.empty()) { for (Value *V : VL) { if (EphValues.count(V)) { LLVM_DEBUG(dbgs() << "SLP: The instruction (" << *V @@ -8324,7 +8362,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, Instruction *VL0 = S.getMainOp(); BB = VL0->getParent(); - if (S.getMainOp() && + if (S && (BB->isEHPad() || isa_and_nonnull<UnreachableInst>(BB->getTerminator()) || !DT->isReachableFromEntry(BB))) { // Don't go into unreachable blocks. They may contain instructions with @@ -8378,8 +8416,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, } LLVM_DEBUG(dbgs() << "SLP: We are able to schedule this bundle.\n"); - unsigned ShuffleOrOp = S.isAltShuffle() ? - (unsigned) Instruction::ShuffleVector : S.getOpcode(); + unsigned ShuffleOrOp = + S.isAltShuffle() ? 
(unsigned)Instruction::ShuffleVector : S.getOpcode(); auto CreateOperandNodes = [&](TreeEntry *TE, const auto &Operands) { // Postpone PHI nodes creation SmallVector<unsigned> PHIOps; @@ -8388,7 +8426,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, if (Op.empty()) continue; InstructionsState S = getSameOpcode(Op, *TLI); - if (S.getOpcode() != Instruction::PHI || S.isAltShuffle()) + if ((!S || S.getOpcode() != Instruction::PHI) || S.isAltShuffle()) buildTree_rec(Op, Depth + 1, {TE, I}); else PHIOps.push_back(I); @@ -9771,7 +9809,7 @@ void BoUpSLP::transformNodes() { if (IsSplat) continue; InstructionsState S = getSameOpcode(Slice, *TLI); - if (!S.getOpcode() || S.isAltShuffle() || !allSameBlock(Slice) || + if (!S || S.isAltShuffle() || !allSameBlock(Slice) || (S.getOpcode() == Instruction::Load && areKnownNonVectorizableLoads(Slice)) || (S.getOpcode() != Instruction::Load && !has_single_bit(VF))) @@ -11086,7 +11124,8 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, if (const TreeEntry *OpTE = getTreeEntry(V)) return getCastContextHint(*OpTE); InstructionsState SrcState = getSameOpcode(E->getOperand(0), *TLI); - if (SrcState.getOpcode() == Instruction::Load && !SrcState.isAltShuffle()) + if (SrcState && SrcState.getOpcode() == Instruction::Load && + !SrcState.isAltShuffle()) return TTI::CastContextHint::GatherScatter; return TTI::CastContextHint::None; }; @@ -13265,7 +13304,7 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry( Value *In1 = PHI1->getIncomingValue(I); if (isConstant(In) && isConstant(In1)) continue; - if (!getSameOpcode({In, In1}, *TLI).getOpcode()) + if (!getSameOpcode({In, In1}, *TLI)) return false; if (cast<Instruction>(In)->getParent() != cast<Instruction>(In1)->getParent()) @@ -13293,7 +13332,7 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry( if (It != UsedValuesEntry.end()) UsedInSameVTE = It->second == UsedValuesEntry.find(V)->second; return V != V1 && MightBeIgnored(V1) && !UsedInSameVTE && - getSameOpcode({V, V1}, *TLI).getOpcode() && + getSameOpcode({V, V1}, *TLI) && cast<Instruction>(V)->getParent() == cast<Instruction>(V1)->getParent() && (!isa<PHINode>(V1) || AreCompatiblePHIs(V, V1)); @@ -13876,9 +13915,8 @@ Value *BoUpSLP::gather( Instruction *InsElt; if (auto *VecTy = dyn_cast<FixedVectorType>(Scalar->getType())) { assert(SLPReVec && "FixedVectorType is not expected."); - Vec = InsElt = Builder.CreateInsertVector( - Vec->getType(), Vec, Scalar, - Builder.getInt64(Pos * VecTy->getNumElements())); + Vec = InsElt = cast<Instruction>(createInsertVector( + Builder, Vec, Scalar, Pos * getNumElements(VecTy))); auto *II = dyn_cast<IntrinsicInst>(InsElt); if (!II || II->getIntrinsicID() != Intrinsic::vector_insert) return Vec; @@ -14478,23 +14516,10 @@ public: V, SimplifyQuery(*R.DL)); })); unsigned InsertionIndex = Idx * ScalarTyNumElements; - const unsigned SubVecVF = - cast<FixedVectorType>(V->getType())->getNumElements(); - if (InsertionIndex % SubVecVF == 0) { - Vec = Builder.CreateInsertVector(Vec->getType(), Vec, V, - Builder.getInt64(InsertionIndex)); - } else { - // Create shuffle, insertvector requires that index is multiple of - // the subvectors length. 
- const unsigned VecVF = - cast<FixedVectorType>(Vec->getType())->getNumElements(); - SmallVector<int> Mask(VecVF, PoisonMaskElem); - std::iota(Mask.begin(), Mask.end(), 0); - for (unsigned I : seq<unsigned>( - InsertionIndex, (Idx + SubVecVF) * ScalarTyNumElements)) - Mask[I] = I - Idx + VecVF; - Vec = createShuffle(Vec, V, Mask); - } + Vec = createInsertVector( + Builder, Vec, V, InsertionIndex, + std::bind(&ShuffleInstructionBuilder::createShuffle, this, _1, _2, + _3)); if (!CommonMask.empty()) { std::iota( std::next(CommonMask.begin(), InsertionIndex), @@ -14560,12 +14585,12 @@ BoUpSLP::TreeEntry *BoUpSLP::getMatchedVectorizedOperand(const TreeEntry *E, ArrayRef<Value *> VL = E->getOperand(NodeIdx); InstructionsState S = getSameOpcode(VL, *TLI); // Special processing for GEPs bundle, which may include non-gep values. - if (!S.getOpcode() && VL.front()->getType()->isPointerTy()) { + if (!S && VL.front()->getType()->isPointerTy()) { const auto *It = find_if(VL, IsaPred<GetElementPtrInst>); if (It != VL.end()) S = getSameOpcode(*It, *TLI); } - if (!S.getOpcode()) + if (!S) return nullptr; auto CheckSameVE = [&](const TreeEntry *VE) { return VE->isSame(VL) && @@ -17740,7 +17765,6 @@ bool BoUpSLP::collectValuesToDemote( BitWidth = std::max(BitWidth, BitWidth1); return BitWidth > 0 && OrigBitWidth >= (BitWidth * 2); }; - using namespace std::placeholders; auto FinalAnalysis = [&]() { if (!IsProfitableToDemote) return false; @@ -18546,8 +18570,7 @@ SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, BoUpSLP &R, hasFullVectorsOrPowerOf2(*TTI, ValOps.front()->getType(), ValOps.size()) || (VectorizeNonPowerOf2 && has_single_bit(ValOps.size() + 1)); - if ((!IsAllowedSize && S.getOpcode() && - S.getOpcode() != Instruction::Load && + if ((!IsAllowedSize && S && S.getOpcode() != Instruction::Load && (!S.getMainOp()->isSafeToRemove() || any_of(ValOps.getArrayRef(), [&](Value *V) { @@ -18557,8 +18580,8 @@ SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, BoUpSLP &R, return !Stores.contains(U); })); }))) || - (ValOps.size() > Chain.size() / 2 && !S.getOpcode())) { - Size = (!IsAllowedSize && S.getOpcode()) ? 1 : 2; + (ValOps.size() > Chain.size() / 2 && !S)) { + Size = (!IsAllowedSize && S) ? 1 : 2; return false; } } @@ -18581,7 +18604,7 @@ SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, BoUpSLP &R, R.computeMinimumValueSizes(); Size = R.getCanonicalGraphSize(); - if (S.getOpcode() == Instruction::Load) + if (S && S.getOpcode() == Instruction::Load) Size = 2; // cut off masked gather small trees InstructionCost Cost = R.getTreeCost(); @@ -19082,7 +19105,7 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, // Check that all of the parts are instructions of the same type, // we permit an alternate opcode via InstructionsState. InstructionsState S = getSameOpcode(VL, *TLI); - if (!S.getOpcode()) + if (!S) return false; Instruction *I0 = S.getMainOp(); @@ -19906,16 +19929,16 @@ public: // Also check if the instruction was folded to constant/other value. auto *Inst = dyn_cast<Instruction>(RdxVal); if ((Inst && isVectorLikeInstWithConstOps(Inst) && - (!S.getOpcode() || !S.isOpcodeOrAlt(Inst))) || - (S.getOpcode() && !Inst)) + (!S || !S.isOpcodeOrAlt(Inst))) || + (S && !Inst)) continue; Candidates.push_back(RdxVal); TrackedToOrig.try_emplace(RdxVal, OrigReducedVals[Cnt]); } bool ShuffledExtracts = false; // Try to handle shuffled extractelements. 
- if (S.getOpcode() == Instruction::ExtractElement && !S.isAltShuffle() && - I + 1 < E) { + if (S && S.getOpcode() == Instruction::ExtractElement && + !S.isAltShuffle() && I + 1 < E) { SmallVector<Value *> CommonCandidates(Candidates); for (Value *RV : ReducedVals[I + 1]) { Value *RdxVal = TrackedVals.at(RV); @@ -21310,7 +21333,7 @@ static bool compareCmp(Value *V, Value *V2, TargetLibraryInfo &TLI, return NodeI1->getDFSNumIn() < NodeI2->getDFSNumIn(); } InstructionsState S = getSameOpcode({I1, I2}, TLI); - if (S.getOpcode() && (IsCompatibility || !S.isAltShuffle())) + if (S && (IsCompatibility || !S.isAltShuffle())) continue; if (IsCompatibility) return false; @@ -21468,7 +21491,7 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { if (NodeI1 != NodeI2) return NodeI1->getDFSNumIn() < NodeI2->getDFSNumIn(); InstructionsState S = getSameOpcode({I1, I2}, *TLI); - if (S.getOpcode() && !S.isAltShuffle()) + if (S && !S.isAltShuffle()) continue; return I1->getOpcode() < I2->getOpcode(); } @@ -21531,8 +21554,7 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { return false; if (I1->getParent() != I2->getParent()) return false; - InstructionsState S = getSameOpcode({I1, I2}, *TLI); - if (S.getOpcode()) + if (getSameOpcode({I1, I2}, *TLI)) continue; return false; } @@ -21904,8 +21926,7 @@ bool SLPVectorizerPass::vectorizeStoreChains(BoUpSLP &R) { if (auto *I2 = dyn_cast<Instruction>(V2->getValueOperand())) { if (I1->getParent() != I2->getParent()) return false; - InstructionsState S = getSameOpcode({I1, I2}, *TLI); - return S.getOpcode() > 0; + return getSameOpcode({I1, I2}, *TLI).valid(); } if (isa<Constant>(V1->getValueOperand()) && isa<Constant>(V2->getValueOperand())) diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp index 9a08292..e804f81 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp @@ -205,11 +205,6 @@ VPBlockBase *VPBlockBase::getEnclosingBlockWithPredecessors() { return Parent->getEnclosingBlockWithPredecessors(); } -void VPBlockBase::deleteCFG(VPBlockBase *Entry) { - for (VPBlockBase *Block : to_vector(vp_depth_first_shallow(Entry))) - delete Block; -} - VPBasicBlock::iterator VPBasicBlock::getFirstNonPhi() { iterator It = begin(); while (It != end() && It->isPhi()) @@ -221,9 +216,10 @@ VPTransformState::VPTransformState(const TargetTransformInfo *TTI, ElementCount VF, unsigned UF, LoopInfo *LI, DominatorTree *DT, IRBuilderBase &Builder, InnerLoopVectorizer *ILV, VPlan *Plan, - Type *CanonicalIVTy) + Loop *CurrentParentLoop, Type *CanonicalIVTy) : TTI(TTI), VF(VF), CFG(DT), LI(LI), Builder(Builder), ILV(ILV), Plan(Plan), - LVer(nullptr), TypeAnalysis(CanonicalIVTy) {} + CurrentParentLoop(CurrentParentLoop), LVer(nullptr), + TypeAnalysis(CanonicalIVTy) {} Value *VPTransformState::get(VPValue *Def, const VPLane &Lane) { if (Def->isLiveIn()) @@ -474,6 +470,13 @@ void VPIRBasicBlock::execute(VPTransformState *State) { connectToPredecessors(State->CFG); } +VPIRBasicBlock *VPIRBasicBlock::clone() { + auto *NewBlock = getPlan()->createEmptyVPIRBasicBlock(IRBB); + for (VPRecipeBase &R : Recipes) + NewBlock->appendRecipe(R.clone()); + return NewBlock; +} + void VPBasicBlock::execute(VPTransformState *State) { bool Replica = bool(State->Lane); BasicBlock *NewBB = State->CFG.PrevBB; // Reuse it if possible. @@ -484,11 +487,9 @@ void VPBasicBlock::execute(VPTransformState *State) { }; // 1. Create an IR basic block. 
- if (this == getPlan()->getVectorPreheader() || - (Replica && this == getParent()->getEntry()) || + if ((Replica && this == getParent()->getEntry()) || IsReplicateRegion(getSingleHierarchicalPredecessor())) { // Reuse the previous basic block if the current VPBB is either - // * the vector preheader, // * the entry to a replicate region, or // * the exit of a replicate region. State->CFG.VPBB2IRBB[this] = NewBB; @@ -500,8 +501,8 @@ void VPBasicBlock::execute(VPTransformState *State) { UnreachableInst *Terminator = State->Builder.CreateUnreachable(); // Register NewBB in its loop. In innermost loops its the same for all // BB's. - if (State->CurrentVectorLoop) - State->CurrentVectorLoop->addBasicBlockToLoop(NewBB, *State->LI); + if (State->CurrentParentLoop) + State->CurrentParentLoop->addBasicBlockToLoop(NewBB, *State->LI); State->Builder.SetInsertPoint(Terminator); State->CFG.PrevBB = NewBB; @@ -513,14 +514,11 @@ void VPBasicBlock::execute(VPTransformState *State) { executeRecipes(State, NewBB); } -void VPBasicBlock::dropAllReferences(VPValue *NewValue) { - for (VPRecipeBase &R : Recipes) { - for (auto *Def : R.definedValues()) - Def->replaceAllUsesWith(NewValue); - - for (unsigned I = 0, E = R.getNumOperands(); I != E; I++) - R.setOperand(I, NewValue); - } +VPBasicBlock *VPBasicBlock::clone() { + auto *NewBlock = getPlan()->createVPBasicBlock(getName()); + for (VPRecipeBase &R : *this) + NewBlock->appendRecipe(R.clone()); + return NewBlock; } void VPBasicBlock::executeRecipes(VPTransformState *State, BasicBlock *BB) { @@ -541,7 +539,7 @@ VPBasicBlock *VPBasicBlock::splitAt(iterator SplitAt) { SmallVector<VPBlockBase *, 2> Succs(successors()); // Create new empty block after the block to split. - auto *SplitBlock = new VPBasicBlock(getName() + ".split"); + auto *SplitBlock = getPlan()->createVPBasicBlock(getName() + ".split"); VPBlockUtils::insertBlockAfter(SplitBlock, this); // Finally, move the recipes starting at SplitAt to new block. @@ -557,7 +555,9 @@ VPBasicBlock *VPBasicBlock::splitAt(iterator SplitAt) { template <typename T> static T *getEnclosingLoopRegionForRegion(T *P) { if (P && P->isReplicator()) { P = P->getParent(); - assert(!cast<VPRegionBlock>(P)->isReplicator() && + // Multiple loop regions can be nested, but replicate regions can only be + // nested inside a loop region or must be outside any other region. + assert((!P || !cast<VPRegionBlock>(P)->isReplicator()) && "unexpected nested replicate regions"); } return P; @@ -701,37 +701,30 @@ static std::pair<VPBlockBase *, VPBlockBase *> cloneFrom(VPBlockBase *Entry) { VPRegionBlock *VPRegionBlock::clone() { const auto &[NewEntry, NewExiting] = cloneFrom(getEntry()); - auto *NewRegion = - new VPRegionBlock(NewEntry, NewExiting, getName(), isReplicator()); + auto *NewRegion = getPlan()->createVPRegionBlock(NewEntry, NewExiting, + getName(), isReplicator()); for (VPBlockBase *Block : vp_depth_first_shallow(NewEntry)) Block->setParent(NewRegion); return NewRegion; } -void VPRegionBlock::dropAllReferences(VPValue *NewValue) { - for (VPBlockBase *Block : vp_depth_first_shallow(Entry)) - // Drop all references in VPBasicBlocks and replace all uses with - // DummyValue. - Block->dropAllReferences(NewValue); -} - void VPRegionBlock::execute(VPTransformState *State) { ReversePostOrderTraversal<VPBlockShallowTraversalWrapper<VPBlockBase *>> RPOT(Entry); if (!isReplicator()) { // Create and register the new vector loop. 
- Loop *PrevLoop = State->CurrentVectorLoop; - State->CurrentVectorLoop = State->LI->AllocateLoop(); + Loop *PrevLoop = State->CurrentParentLoop; + State->CurrentParentLoop = State->LI->AllocateLoop(); BasicBlock *VectorPH = State->CFG.VPBB2IRBB[getPreheaderVPBB()]; Loop *ParentLoop = State->LI->getLoopFor(VectorPH); // Insert the new loop into the loop nest and register the new basic blocks // before calling any utilities such as SCEV that require valid LoopInfo. if (ParentLoop) - ParentLoop->addChildLoop(State->CurrentVectorLoop); + ParentLoop->addChildLoop(State->CurrentParentLoop); else - State->LI->addTopLevelLoop(State->CurrentVectorLoop); + State->LI->addTopLevelLoop(State->CurrentParentLoop); // Visit the VPBlocks connected to "this", starting from it. for (VPBlockBase *Block : RPOT) { @@ -739,7 +732,7 @@ void VPRegionBlock::execute(VPTransformState *State) { Block->execute(State); } - State->CurrentVectorLoop = PrevLoop; + State->CurrentParentLoop = PrevLoop; return; } @@ -822,17 +815,26 @@ void VPRegionBlock::print(raw_ostream &O, const Twine &Indent, #endif VPlan::VPlan(Loop *L) { - setEntry(VPIRBasicBlock::fromBasicBlock(L->getLoopPreheader())); - ScalarHeader = VPIRBasicBlock::fromBasicBlock(L->getHeader()); + setEntry(createVPIRBasicBlock(L->getLoopPreheader())); + ScalarHeader = createVPIRBasicBlock(L->getHeader()); } VPlan::~VPlan() { - if (Entry) { - VPValue DummyValue; - for (VPBlockBase *Block : vp_depth_first_shallow(Entry)) - Block->dropAllReferences(&DummyValue); - - VPBlockBase::deleteCFG(Entry); + VPValue DummyValue; + + for (auto *VPB : CreatedBlocks) { + if (auto *VPBB = dyn_cast<VPBasicBlock>(VPB)) { + // Replace all operands of recipes and all VPValues defined in VPBB with + // DummyValue so the block can be deleted. + for (VPRecipeBase &R : *VPBB) { + for (auto *Def : R.definedValues()) + Def->replaceAllUsesWith(&DummyValue); + + for (unsigned I = 0, E = R.getNumOperands(); I != E; I++) + R.setOperand(I, &DummyValue); + } + } + delete VPB; } for (VPValue *VPV : VPLiveInsToFree) delete VPV; @@ -840,14 +842,6 @@ VPlan::~VPlan() { delete BackedgeTakenCount; } -VPIRBasicBlock *VPIRBasicBlock::fromBasicBlock(BasicBlock *IRBB) { - auto *VPIRBB = new VPIRBasicBlock(IRBB); - for (Instruction &I : - make_range(IRBB->begin(), IRBB->getTerminator()->getIterator())) - VPIRBB->appendRecipe(new VPIRInstruction(I)); - return VPIRBB; -} - VPlanPtr VPlan::createInitialVPlan(Type *InductionTy, PredicatedScalarEvolution &PSE, bool RequiresScalarEpilogueCheck, @@ -861,7 +855,7 @@ VPlanPtr VPlan::createInitialVPlan(Type *InductionTy, // an epilogue vector loop, the original entry block here will be replaced by // a new VPIRBasicBlock wrapping the entry to the epilogue vector loop after // generating code for the main vector loop. - VPBasicBlock *VecPreheader = new VPBasicBlock("vector.ph"); + VPBasicBlock *VecPreheader = Plan->createVPBasicBlock("vector.ph"); VPBlockUtils::connectBlocks(Plan->getEntry(), VecPreheader); // Create SCEV and VPValue for the trip count. @@ -878,17 +872,17 @@ VPlanPtr VPlan::createInitialVPlan(Type *InductionTy, // Create VPRegionBlock, with empty header and latch blocks, to be filled // during processing later. 
- VPBasicBlock *HeaderVPBB = new VPBasicBlock("vector.body"); - VPBasicBlock *LatchVPBB = new VPBasicBlock("vector.latch"); + VPBasicBlock *HeaderVPBB = Plan->createVPBasicBlock("vector.body"); + VPBasicBlock *LatchVPBB = Plan->createVPBasicBlock("vector.latch"); VPBlockUtils::insertBlockAfter(LatchVPBB, HeaderVPBB); - auto *TopRegion = new VPRegionBlock(HeaderVPBB, LatchVPBB, "vector loop", - false /*isReplicator*/); + auto *TopRegion = Plan->createVPRegionBlock( + HeaderVPBB, LatchVPBB, "vector loop", false /*isReplicator*/); VPBlockUtils::insertBlockAfter(TopRegion, VecPreheader); - VPBasicBlock *MiddleVPBB = new VPBasicBlock("middle.block"); + VPBasicBlock *MiddleVPBB = Plan->createVPBasicBlock("middle.block"); VPBlockUtils::insertBlockAfter(MiddleVPBB, TopRegion); - VPBasicBlock *ScalarPH = new VPBasicBlock("scalar.ph"); + VPBasicBlock *ScalarPH = Plan->createVPBasicBlock("scalar.ph"); VPBlockUtils::connectBlocks(ScalarPH, ScalarHeader); if (!RequiresScalarEpilogueCheck) { VPBlockUtils::connectBlocks(MiddleVPBB, ScalarPH); @@ -904,7 +898,7 @@ VPlanPtr VPlan::createInitialVPlan(Type *InductionTy, // we unconditionally branch to the scalar preheader. Do nothing. // 3) Otherwise, construct a runtime check. BasicBlock *IRExitBlock = TheLoop->getUniqueLatchExitBlock(); - auto *VPExitBlock = VPIRBasicBlock::fromBasicBlock(IRExitBlock); + auto *VPExitBlock = Plan->createVPIRBasicBlock(IRExitBlock); // The connection order corresponds to the operands of the conditional branch. VPBlockUtils::insertBlockAfter(VPExitBlock, MiddleVPBB); VPBlockUtils::connectBlocks(MiddleVPBB, ScalarPH); @@ -942,7 +936,8 @@ void VPlan::prepareToExecute(Value *TripCountV, Value *VectorTripCountV, IRBuilder<> Builder(State.CFG.PrevBB->getTerminator()); // FIXME: Model VF * UF computation completely in VPlan. - assert(VFxUF.getNumUsers() && "VFxUF expected to always have users"); + assert((!getVectorLoopRegion() || VFxUF.getNumUsers()) && + "VFxUF expected to always have users"); unsigned UF = getUF(); if (VF.getNumUsers()) { Value *RuntimeVF = getRuntimeVF(Builder, TCTy, State.VF); @@ -955,22 +950,6 @@ void VPlan::prepareToExecute(Value *TripCountV, Value *VectorTripCountV, } } -/// Replace \p VPBB with a VPIRBasicBlock wrapping \p IRBB. All recipes from \p -/// VPBB are moved to the end of the newly created VPIRBasicBlock. VPBB must -/// have a single predecessor, which is rewired to the new VPIRBasicBlock. All -/// successors of VPBB, if any, are rewired to the new VPIRBasicBlock. -static void replaceVPBBWithIRVPBB(VPBasicBlock *VPBB, BasicBlock *IRBB) { - VPIRBasicBlock *IRVPBB = VPIRBasicBlock::fromBasicBlock(IRBB); - for (auto &R : make_early_inc_range(*VPBB)) { - assert(!R.isPhi() && "Tried to move phi recipe to end of block"); - R.moveBefore(*IRVPBB, IRVPBB->end()); - } - - VPBlockUtils::reassociateBlocks(VPBB, IRVPBB); - - delete VPBB; -} - /// Generate the code inside the preheader and body of the vectorized loop. /// Assumes a single pre-header basic-block was created for this. Introduce /// additional basic-blocks as needed, and fill them all. @@ -978,25 +957,13 @@ void VPlan::execute(VPTransformState *State) { // Initialize CFG state. State->CFG.PrevVPBB = nullptr; State->CFG.ExitBB = State->CFG.PrevBB->getSingleSuccessor(); - BasicBlock *VectorPreHeader = State->CFG.PrevBB; - State->Builder.SetInsertPoint(VectorPreHeader->getTerminator()); // Disconnect VectorPreHeader from ExitBB in both the CFG and DT. 
+ BasicBlock *VectorPreHeader = State->CFG.PrevBB; cast<BranchInst>(VectorPreHeader->getTerminator())->setSuccessor(0, nullptr); State->CFG.DTU.applyUpdates( {{DominatorTree::Delete, VectorPreHeader, State->CFG.ExitBB}}); - // Replace regular VPBB's for the vector preheader, middle and scalar - // preheader blocks with VPIRBasicBlocks wrapping their IR blocks. The IR - // blocks are created during skeleton creation, so we can only create the - // VPIRBasicBlocks now during VPlan execution rather than earlier during VPlan - // construction. - BasicBlock *MiddleBB = State->CFG.ExitBB; - BasicBlock *ScalarPh = MiddleBB->getSingleSuccessor(); - replaceVPBBWithIRVPBB(getVectorPreheader(), VectorPreHeader); - replaceVPBBWithIRVPBB(getMiddleBlock(), MiddleBB); - replaceVPBBWithIRVPBB(getScalarPreheader(), ScalarPh); - LLVM_DEBUG(dbgs() << "Executing best plan with VF=" << State->VF << ", UF=" << getUF() << '\n'); setName("Final VPlan"); @@ -1005,6 +972,8 @@ void VPlan::execute(VPTransformState *State) { // Disconnect the middle block from its single successor (the scalar loop // header) in both the CFG and DT. The branch will be recreated during VPlan // execution. + BasicBlock *MiddleBB = State->CFG.ExitBB; + BasicBlock *ScalarPh = MiddleBB->getSingleSuccessor(); auto *BrInst = new UnreachableInst(MiddleBB->getContext()); BrInst->insertBefore(MiddleBB->getTerminator()); MiddleBB->getTerminator()->eraseFromParent(); @@ -1022,12 +991,18 @@ void VPlan::execute(VPTransformState *State) { for (VPBlockBase *Block : RPOT) Block->execute(State); - VPBasicBlock *LatchVPBB = getVectorLoopRegion()->getExitingBasicBlock(); + State->CFG.DTU.flush(); + + auto *LoopRegion = getVectorLoopRegion(); + if (!LoopRegion) + return; + + VPBasicBlock *LatchVPBB = LoopRegion->getExitingBasicBlock(); BasicBlock *VectorLatchBB = State->CFG.VPBB2IRBB[LatchVPBB]; // Fix the latch value of canonical, reduction and first-order recurrences // phis in the vector loop. - VPBasicBlock *Header = getVectorLoopRegion()->getEntryBasicBlock(); + VPBasicBlock *Header = LoopRegion->getEntryBasicBlock(); for (VPRecipeBase &R : Header->phis()) { // Skip phi-like recipes that generate their backedege values themselves. if (isa<VPWidenPHIRecipe>(&R)) @@ -1066,8 +1041,6 @@ void VPlan::execute(VPTransformState *State) { Value *Val = State->get(PhiR->getBackedgeValue(), NeedsScalar); cast<PHINode>(Phi)->addIncoming(Val, VectorLatchBB); } - - State->CFG.DTU.flush(); } InstructionCost VPlan::cost(ElementCount VF, VPCostContext &Ctx) { @@ -1080,14 +1053,14 @@ VPRegionBlock *VPlan::getVectorLoopRegion() { // TODO: Cache if possible. for (VPBlockBase *B : vp_depth_first_shallow(getEntry())) if (auto *R = dyn_cast<VPRegionBlock>(B)) - return R; + return R->isReplicator() ? nullptr : R; return nullptr; } const VPRegionBlock *VPlan::getVectorLoopRegion() const { for (const VPBlockBase *B : vp_depth_first_shallow(getEntry())) if (auto *R = dyn_cast<VPRegionBlock>(B)) - return R; + return R->isReplicator() ? nullptr : R; return nullptr; } @@ -1217,6 +1190,7 @@ static void remapOperands(VPBlockBase *Entry, VPBlockBase *NewEntry, } VPlan *VPlan::duplicate() { + unsigned NumBlocksBeforeCloning = CreatedBlocks.size(); // Clone blocks. 
const auto &[NewEntry, __] = cloneFrom(Entry); @@ -1257,9 +1231,32 @@ VPlan *VPlan::duplicate() { assert(Old2NewVPValues.contains(TripCount) && "TripCount must have been added to Old2NewVPValues"); NewPlan->TripCount = Old2NewVPValues[TripCount]; + + // Transfer all cloned blocks (the second half of all current blocks) from + // current to new VPlan. + unsigned NumBlocksAfterCloning = CreatedBlocks.size(); + for (unsigned I : + seq<unsigned>(NumBlocksBeforeCloning, NumBlocksAfterCloning)) + NewPlan->CreatedBlocks.push_back(this->CreatedBlocks[I]); + CreatedBlocks.truncate(NumBlocksBeforeCloning); + return NewPlan; } +VPIRBasicBlock *VPlan::createEmptyVPIRBasicBlock(BasicBlock *IRBB) { + auto *VPIRBB = new VPIRBasicBlock(IRBB); + CreatedBlocks.push_back(VPIRBB); + return VPIRBB; +} + +VPIRBasicBlock *VPlan::createVPIRBasicBlock(BasicBlock *IRBB) { + auto *VPIRBB = createEmptyVPIRBasicBlock(IRBB); + for (Instruction &I : + make_range(IRBB->begin(), IRBB->getTerminator()->getIterator())) + VPIRBB->appendRecipe(new VPIRInstruction(I)); + return VPIRBB; +} + #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) Twine VPlanPrinter::getUID(const VPBlockBase *Block) { @@ -1409,11 +1406,17 @@ void VPlanIngredient::print(raw_ostream &O) const { #endif -bool VPValue::isDefinedOutsideLoopRegions() const { - return !hasDefiningRecipe() || - !getDefiningRecipe()->getParent()->getEnclosingLoopRegion(); +/// Returns true if there is a vector loop region and \p VPV is defined in a +/// loop region. +static bool isDefinedInsideLoopRegions(const VPValue *VPV) { + const VPRecipeBase *DefR = VPV->getDefiningRecipe(); + return DefR && (!DefR->getParent()->getPlan()->getVectorLoopRegion() || + DefR->getParent()->getEnclosingLoopRegion()); } +bool VPValue::isDefinedOutsideLoopRegions() const { + return !isDefinedInsideLoopRegions(this); +} void VPValue::replaceAllUsesWith(VPValue *New) { replaceUsesWithIf(New, [](VPUser &, unsigned) { return true; }); } diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h index 404202b..cfbb4ad 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/llvm/lib/Transforms/Vectorize/VPlan.h @@ -236,7 +236,8 @@ public: struct VPTransformState { VPTransformState(const TargetTransformInfo *TTI, ElementCount VF, unsigned UF, LoopInfo *LI, DominatorTree *DT, IRBuilderBase &Builder, - InnerLoopVectorizer *ILV, VPlan *Plan, Type *CanonicalIVTy); + InnerLoopVectorizer *ILV, VPlan *Plan, + Loop *CurrentParentLoop, Type *CanonicalIVTy); /// Target Transform Info. const TargetTransformInfo *TTI; @@ -373,8 +374,8 @@ struct VPTransformState { /// Pointer to the VPlan code is generated for. VPlan *Plan; - /// The loop object for the current parent region, or nullptr. - Loop *CurrentVectorLoop = nullptr; + /// The parent loop object for the current scope, or nullptr. + Loop *CurrentParentLoop = nullptr; /// LoopVersioning. It's only set up (non-null) if memchecks were /// used. @@ -636,9 +637,6 @@ public: /// Return the cost of the block. virtual InstructionCost cost(ElementCount VF, VPCostContext &Ctx) = 0; - /// Delete all blocks reachable from a given VPBlockBase, inclusive. - static void deleteCFG(VPBlockBase *Entry); - /// Return true if it is legal to hoist instructions into this block. 
bool isLegalToHoistInto() { // There are currently no constraints that prevent an instruction to be @@ -646,10 +644,6 @@ public: return true; } - /// Replace all operands of VPUsers in the block with \p NewValue and also - /// replaces all uses of VPValues defined in the block with NewValue. - virtual void dropAllReferences(VPValue *NewValue) = 0; - #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void printAsOperand(raw_ostream &OS, bool PrintType = false) const { OS << getName(); @@ -1357,6 +1351,9 @@ public: } } + /// Returns true if the underlying opcode may read from or write to memory. + bool opcodeMayReadOrWriteFromMemory() const; + /// Returns true if the recipe only uses the first lane of operand \p Op. bool onlyFirstLaneUsed(const VPValue *Op) const override; @@ -1586,14 +1583,16 @@ class VPScalarCastRecipe : public VPSingleDefRecipe { Value *generate(VPTransformState &State); public: - VPScalarCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy) - : VPSingleDefRecipe(VPDef::VPScalarCastSC, {Op}), Opcode(Opcode), + VPScalarCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy, + DebugLoc DL) + : VPSingleDefRecipe(VPDef::VPScalarCastSC, {Op}, DL), Opcode(Opcode), ResultTy(ResultTy) {} ~VPScalarCastRecipe() override = default; VPScalarCastRecipe *clone() override { - return new VPScalarCastRecipe(Opcode, getOperand(0), ResultTy); + return new VPScalarCastRecipe(Opcode, getOperand(0), ResultTy, + getDebugLoc()); } VP_CLASSOF_IMPL(VPDef::VPScalarCastSC) @@ -2101,6 +2100,15 @@ public: R->getVPDefID() == VPDef::VPWidenPointerInductionSC; } + static inline bool classof(const VPValue *V) { + auto *R = V->getDefiningRecipe(); + return R && classof(R); + } + + static inline bool classof(const VPHeaderPHIRecipe *R) { + return classof(static_cast<const VPRecipeBase *>(R)); + } + virtual void execute(VPTransformState &State) override = 0; /// Returns the step value of the induction. @@ -3556,8 +3564,6 @@ public: return make_range(begin(), getFirstNonPhi()); } - void dropAllReferences(VPValue *NewValue) override; - /// Split current block at \p SplitAt by inserting a new block between the /// current block and its successors and moving all recipes starting at /// SplitAt to the new block. Returns the new block. @@ -3587,12 +3593,7 @@ public: /// Clone the current block and it's recipes, without updating the operands of /// the cloned recipes. - VPBasicBlock *clone() override { - auto *NewBlock = new VPBasicBlock(getName()); - for (VPRecipeBase &R : *this) - NewBlock->appendRecipe(R.clone()); - return NewBlock; - } + VPBasicBlock *clone() override; protected: /// Execute the recipes in the IR basic block \p BB. @@ -3628,20 +3629,11 @@ public: return V->getVPBlockID() == VPBlockBase::VPIRBasicBlockSC; } - /// Create a VPIRBasicBlock from \p IRBB containing VPIRInstructions for all - /// instructions in \p IRBB, except its terminator which is managed in VPlan. - static VPIRBasicBlock *fromBasicBlock(BasicBlock *IRBB); - /// The method which generates the output IR instructions that correspond to /// this VPBasicBlock, thereby "executing" the VPlan. 
void execute(VPTransformState *State) override; - VPIRBasicBlock *clone() override { - auto *NewBlock = new VPIRBasicBlock(IRBB); - for (VPRecipeBase &R : Recipes) - NewBlock->appendRecipe(R.clone()); - return NewBlock; - } + VPIRBasicBlock *clone() override; BasicBlock *getIRBasicBlock() const { return IRBB; } }; @@ -3680,13 +3672,7 @@ public: : VPBlockBase(VPRegionBlockSC, Name), Entry(nullptr), Exiting(nullptr), IsReplicator(IsReplicator) {} - ~VPRegionBlock() override { - if (Entry) { - VPValue DummyValue; - Entry->dropAllReferences(&DummyValue); - deleteCFG(Entry); - } - } + ~VPRegionBlock() override {} /// Method to support type inquiry through isa, cast, and dyn_cast. static inline bool classof(const VPBlockBase *V) { @@ -3734,8 +3720,6 @@ public: // Return the cost of this region. InstructionCost cost(ElementCount VF, VPCostContext &Ctx) override; - void dropAllReferences(VPValue *NewValue) override; - #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) /// Print this VPRegionBlock to \p O (recursively), prefixing all lines with /// \p Indent. \p SlotTracker is used to print unnamed VPValue's using @@ -3812,6 +3796,10 @@ class VPlan { /// been modeled in VPlan directly. DenseMap<const SCEV *, VPValue *> SCEVToExpansion; + /// Blocks allocated and owned by the VPlan. They will be deleted once the + /// VPlan is destroyed. + SmallVector<VPBlockBase *> CreatedBlocks; + /// Construct a VPlan with \p Entry to the plan and with \p ScalarHeader /// wrapping the original header of the scalar loop. VPlan(VPBasicBlock *Entry, VPIRBasicBlock *ScalarHeader) @@ -3830,8 +3818,8 @@ public: /// Construct a VPlan with a new VPBasicBlock as entry, a VPIRBasicBlock /// wrapping \p ScalarHeaderBB and a trip count of \p TC. VPlan(BasicBlock *ScalarHeaderBB, VPValue *TC) { - setEntry(new VPBasicBlock("preheader")); - ScalarHeader = VPIRBasicBlock::fromBasicBlock(ScalarHeaderBB); + setEntry(createVPBasicBlock("preheader")); + ScalarHeader = createVPIRBasicBlock(ScalarHeaderBB); TripCount = TC; } @@ -3870,9 +3858,13 @@ public: VPBasicBlock *getEntry() { return Entry; } const VPBasicBlock *getEntry() const { return Entry; } - /// Returns the preheader of the vector loop region. + /// Returns the preheader of the vector loop region, if one exists, or null + /// otherwise. VPBasicBlock *getVectorPreheader() { - return cast<VPBasicBlock>(getVectorLoopRegion()->getSinglePredecessor()); + VPRegionBlock *VectorRegion = getVectorLoopRegion(); + return VectorRegion + ? cast<VPBasicBlock>(VectorRegion->getSinglePredecessor()) + : nullptr; } /// Returns the VPRegionBlock of the vector loop. @@ -4029,6 +4021,49 @@ public: /// Clone the current VPlan, update all VPValues of the new VPlan and cloned /// recipes to refer to the clones, and return it. VPlan *duplicate(); + + /// Create a new VPBasicBlock with \p Name and containing \p Recipe if + /// present. The returned block is owned by the VPlan and deleted once the + /// VPlan is destroyed. + VPBasicBlock *createVPBasicBlock(const Twine &Name, + VPRecipeBase *Recipe = nullptr) { + auto *VPB = new VPBasicBlock(Name, Recipe); + CreatedBlocks.push_back(VPB); + return VPB; + } + + /// Create a new VPRegionBlock with \p Entry, \p Exiting and \p Name. If \p + /// IsReplicator is true, the region is a replicate region. The returned block + /// is owned by the VPlan and deleted once the VPlan is destroyed. 
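Illustrative aside (not part of the patch): these factory methods replace bare `new VPBasicBlock` / `new VPRegionBlock` calls so that every created block is tracked and released when the VPlan is destroyed, instead of being torn down via recursive CFG deletion. A self-contained sketch of the ownership pattern with simplified, made-up types:

#include <string>
#include <vector>

// The container records every block it creates and deletes them all in its
// destructor, so callers never delete blocks themselves.
struct Block {
  std::string Name;
  explicit Block(std::string N) : Name(std::move(N)) {}
};

class Plan {
  std::vector<Block *> CreatedBlocks;

public:
  Block *createBlock(const std::string &Name) {
    Block *B = new Block(Name);
    CreatedBlocks.push_back(B); // ownership stays with the plan
    return B;
  }
  ~Plan() {
    for (Block *B : CreatedBlocks)
      delete B;
  }
};

int main() {
  Plan P;
  P.createBlock("vector.ph");
  P.createBlock("vector.body");
  // Both blocks are released when P goes out of scope.
  return 0;
}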
+ VPRegionBlock *createVPRegionBlock(VPBlockBase *Entry, VPBlockBase *Exiting, + const std::string &Name = "", + bool IsReplicator = false) { + auto *VPB = new VPRegionBlock(Entry, Exiting, Name, IsReplicator); + CreatedBlocks.push_back(VPB); + return VPB; + } + + /// Create a new VPRegionBlock with \p Name and entry and exiting blocks set + /// to nullptr. If \p IsReplicator is true, the region is a replicate region. + /// The returned block is owned by the VPlan and deleted once the VPlan is + /// destroyed. + VPRegionBlock *createVPRegionBlock(const std::string &Name = "", + bool IsReplicator = false) { + auto *VPB = new VPRegionBlock(Name, IsReplicator); + CreatedBlocks.push_back(VPB); + return VPB; + } + + /// Create a VPIRBasicBlock wrapping \p IRBB, but do not create + /// VPIRInstructions wrapping the instructions in t\p IRBB. The returned + /// block is owned by the VPlan and deleted once the VPlan is destroyed. + VPIRBasicBlock *createEmptyVPIRBasicBlock(BasicBlock *IRBB); + + /// Create a VPIRBasicBlock from \p IRBB containing VPIRInstructions for all + /// instructions in \p IRBB, except its terminator which is managed by the + /// successors of the block in VPlan. The returned block is owned by the VPlan + /// and deleted once the VPlan is destroyed. + VPIRBasicBlock *createVPIRBasicBlock(BasicBlock *IRBB); }; #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) diff --git a/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp b/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp index 6e63373..76ed578 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp @@ -182,7 +182,7 @@ VPBasicBlock *PlainCFGBuilder::getOrCreateVPBB(BasicBlock *BB) { // Create new VPBB. StringRef Name = isHeaderBB(BB, TheLoop) ? "vector.body" : BB->getName(); LLVM_DEBUG(dbgs() << "Creating VPBasicBlock for " << Name << "\n"); - VPBasicBlock *VPBB = new VPBasicBlock(Name); + VPBasicBlock *VPBB = Plan.createVPBasicBlock(Name); BB2VPBB[BB] = VPBB; // Get or create a region for the loop containing BB. @@ -204,7 +204,7 @@ VPBasicBlock *PlainCFGBuilder::getOrCreateVPBB(BasicBlock *BB) { if (LoopOfBB == TheLoop) { RegionOfVPBB = Plan.getVectorLoopRegion(); } else { - RegionOfVPBB = new VPRegionBlock(Name.str(), false /*isReplicator*/); + RegionOfVPBB = Plan.createVPRegionBlock(Name.str(), false /*isReplicator*/); RegionOfVPBB->setParent(Loop2Region[LoopOfBB->getParentLoop()]); } RegionOfVPBB->setEntry(VPBB); @@ -357,12 +357,10 @@ void PlainCFGBuilder::buildPlainCFG() { BB2VPBB[TheLoop->getHeader()] = VectorHeaderVPBB; VectorHeaderVPBB->clearSuccessors(); VectorLatchVPBB->clearPredecessors(); - if (TheLoop->getHeader() != TheLoop->getLoopLatch()) { + if (TheLoop->getHeader() != TheLoop->getLoopLatch()) BB2VPBB[TheLoop->getLoopLatch()] = VectorLatchVPBB; - } else { + else TheRegion->setExiting(VectorHeaderVPBB); - delete VectorLatchVPBB; - } // 1. Scan the body of the loop in a topological order to visit each basic // block after having visited its predecessor basic blocks. 
Create a VPBB for diff --git a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h index ec3c203..4866426 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h +++ b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h @@ -139,7 +139,8 @@ struct MatchRecipeAndOpcode<Opcode, RecipeTy> { if constexpr (std::is_same<RecipeTy, VPScalarIVStepsRecipe>::value || std::is_same<RecipeTy, VPCanonicalIVPHIRecipe>::value || std::is_same<RecipeTy, VPWidenSelectRecipe>::value || - std::is_same<RecipeTy, VPDerivedIVRecipe>::value) + std::is_same<RecipeTy, VPDerivedIVRecipe>::value || + std::is_same<RecipeTy, VPWidenGEPRecipe>::value) return DefR; else return DefR && DefR->getOpcode() == Opcode; @@ -309,6 +310,12 @@ m_Binary(const Op0_t &Op0, const Op1_t &Op1) { return AllBinaryRecipe_match<Op0_t, Op1_t, Opcode, Commutative>(Op0, Op1); } +template <unsigned Opcode, typename Op0_t, typename Op1_t> +inline AllBinaryRecipe_match<Op0_t, Op1_t, Opcode, true> +m_c_Binary(const Op0_t &Op0, const Op1_t &Op1) { + return AllBinaryRecipe_match<Op0_t, Op1_t, Opcode, true>(Op0, Op1); +} + template <typename Op0_t, typename Op1_t> inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Mul> m_Mul(const Op0_t &Op0, const Op1_t &Op1) { @@ -339,6 +346,18 @@ m_c_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) { return m_BinaryOr<Op0_t, Op1_t, /*Commutative*/ true>(Op0, Op1); } +template <typename Op0_t, typename Op1_t> +using GEPLikeRecipe_match = + BinaryRecipe_match<Op0_t, Op1_t, Instruction::GetElementPtr, false, + VPWidenRecipe, VPReplicateRecipe, VPWidenGEPRecipe, + VPInstruction>; + +template <typename Op0_t, typename Op1_t> +inline GEPLikeRecipe_match<Op0_t, Op1_t> m_GetElementPtr(const Op0_t &Op0, + const Op1_t &Op1) { + return GEPLikeRecipe_match<Op0_t, Op1_t>(Op0, Op1); +} + template <typename Op0_t, typename Op1_t, typename Op2_t, unsigned Opcode> using AllTernaryRecipe_match = Recipe_match<std::tuple<Op0_t, Op1_t, Op2_t>, Opcode, false, diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index 7038e52..e54df8bd 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -51,24 +51,7 @@ extern cl::opt<unsigned> ForceTargetInstructionCost; bool VPRecipeBase::mayWriteToMemory() const { switch (getVPDefID()) { case VPInstructionSC: - if (Instruction::isBinaryOp(cast<VPInstruction>(this)->getOpcode())) - return false; - switch (cast<VPInstruction>(this)->getOpcode()) { - case Instruction::Or: - case Instruction::ICmp: - case Instruction::Select: - case VPInstruction::AnyOf: - case VPInstruction::Not: - case VPInstruction::CalculateTripCountMinusVF: - case VPInstruction::CanonicalIVIncrementForPart: - case VPInstruction::ExtractFromEnd: - case VPInstruction::FirstOrderRecurrenceSplice: - case VPInstruction::LogicalAnd: - case VPInstruction::PtrAdd: - return false; - default: - return true; - } + return cast<VPInstruction>(this)->opcodeMayReadOrWriteFromMemory(); case VPInterleaveSC: return cast<VPInterleaveRecipe>(this)->getNumStoreOperands() > 0; case VPWidenStoreEVLSC: @@ -115,6 +98,8 @@ bool VPRecipeBase::mayWriteToMemory() const { bool VPRecipeBase::mayReadFromMemory() const { switch (getVPDefID()) { + case VPInstructionSC: + return cast<VPInstruction>(this)->opcodeMayReadOrWriteFromMemory(); case VPWidenLoadEVLSC: case VPWidenLoadSC: return true; @@ -707,6 +692,26 @@ void VPInstruction::execute(VPTransformState &State) { /*IsScalar*/ 
GeneratesPerFirstLaneOnly); } +bool VPInstruction::opcodeMayReadOrWriteFromMemory() const { + if (Instruction::isBinaryOp(getOpcode())) + return false; + switch (getOpcode()) { + case Instruction::ICmp: + case Instruction::Select: + case VPInstruction::AnyOf: + case VPInstruction::CalculateTripCountMinusVF: + case VPInstruction::CanonicalIVIncrementForPart: + case VPInstruction::ExtractFromEnd: + case VPInstruction::FirstOrderRecurrenceSplice: + case VPInstruction::LogicalAnd: + case VPInstruction::Not: + case VPInstruction::PtrAdd: + return false; + default: + return true; + } +} + bool VPInstruction::onlyFirstLaneUsed(const VPValue *Op) const { assert(is_contained(operands(), Op) && "Op must be an operand of the recipe"); if (Instruction::isBinaryOp(getOpcode())) @@ -1352,10 +1357,9 @@ void VPWidenRecipe::execute(VPTransformState &State) { Value *C = nullptr; if (FCmp) { // Propagate fast math flags. - IRBuilder<>::FastMathFlagGuard FMFG(Builder); - if (auto *I = dyn_cast_or_null<Instruction>(getUnderlyingValue())) - Builder.setFastMathFlags(I->getFastMathFlags()); - C = Builder.CreateFCmp(getPredicate(), A, B); + C = Builder.CreateFCmpFMF( + getPredicate(), A, B, + dyn_cast_or_null<Instruction>(getUnderlyingValue())); } else { C = Builder.CreateICmp(getPredicate(), A, B); } @@ -2328,6 +2332,7 @@ void VPReplicateRecipe::print(raw_ostream &O, const Twine &Indent, #endif Value *VPScalarCastRecipe ::generate(VPTransformState &State) { + State.setDebugLocFrom(getDebugLoc()); assert(vputils::onlyFirstLaneUsed(this) && "Codegen only implemented for first lane."); switch (Opcode) { @@ -2789,21 +2794,10 @@ static Value *interleaveVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vals, // Scalable vectors cannot use arbitrary shufflevectors (only splats), so // must use intrinsics to interleave. if (VecTy->isScalableTy()) { - assert(isPowerOf2_32(Factor) && "Unsupported interleave factor for " - "scalable vectors, must be power of 2"); - SmallVector<Value *> InterleavingValues(Vals); - // When interleaving, the number of values will be shrunk until we have the - // single final interleaved value. - auto *InterleaveTy = cast<VectorType>(InterleavingValues[0]->getType()); - for (unsigned Midpoint = Factor / 2; Midpoint > 0; Midpoint /= 2) { - InterleaveTy = VectorType::getDoubleElementsVectorType(InterleaveTy); - for (unsigned I = 0; I < Midpoint; ++I) - InterleavingValues[I] = Builder.CreateIntrinsic( - InterleaveTy, Intrinsic::vector_interleave2, - {InterleavingValues[I], InterleavingValues[Midpoint + I]}, - /*FMFSource=*/nullptr, Name); - } - return InterleavingValues[0]; + VectorType *WideVecTy = VectorType::getDoubleElementsVectorType(VecTy); + return Builder.CreateIntrinsic(WideVecTy, Intrinsic::vector_interleave2, + Vals, + /*FMFSource=*/nullptr, Name); } // Fixed length. Start by concatenating all vectors into a wide vector. 
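With the scalable path in interleaveVectors reduced to the factor-2 case, the store side becomes a single llvm.vector.interleave2 call producing a vector with twice the element count. Below is a distilled sketch of that emission using only the IRBuilder calls that appear in the hunk; the helper name is illustrative, not part of the patch:

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Intrinsics.h"

using namespace llvm;

// Interleave exactly two scalable vectors of the same type into one vector
// with twice the element count, e.g. two <vscale x 4 x i32> values into a
// <vscale x 8 x i32>, by calling llvm.vector.interleave2. Arbitrary
// shufflevectors are not available for scalable types, hence the intrinsic.
static Value *interleaveTwoScalable(IRBuilderBase &Builder, Value *V0,
                                    Value *V1, const Twine &Name) {
  auto *VecTy = cast<VectorType>(V0->getType());
  assert(VecTy->isScalableTy() && V1->getType() == VecTy &&
         "expects two scalable vectors of identical type");
  VectorType *WideTy = VectorType::getDoubleElementsVectorType(VecTy);
  return Builder.CreateIntrinsic(WideTy, Intrinsic::vector_interleave2,
                                 {V0, V1}, /*FMFSource=*/nullptr, Name);
}

For fixed-length vectors the existing concatenate-then-shufflevector path is left unchanged.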
@@ -2889,11 +2883,15 @@ void VPInterleaveRecipe::execute(VPTransformState &State) { &InterleaveFactor](Value *MaskForGaps) -> Value * { if (State.VF.isScalable()) { assert(!MaskForGaps && "Interleaved groups with gaps are not supported."); - assert(isPowerOf2_32(InterleaveFactor) && + assert(InterleaveFactor == 2 && "Unsupported deinterleave factor for scalable vectors"); auto *ResBlockInMask = State.get(BlockInMask); - SmallVector<Value *> Ops(InterleaveFactor, ResBlockInMask); - return interleaveVectors(State.Builder, Ops, "interleaved.mask"); + SmallVector<Value *, 2> Ops = {ResBlockInMask, ResBlockInMask}; + auto *MaskTy = VectorType::get(State.Builder.getInt1Ty(), + State.VF.getKnownMinValue() * 2, true); + return State.Builder.CreateIntrinsic( + MaskTy, Intrinsic::vector_interleave2, Ops, + /*FMFSource=*/nullptr, "interleaved.mask"); } if (!BlockInMask) @@ -2933,48 +2931,22 @@ void VPInterleaveRecipe::execute(VPTransformState &State) { ArrayRef<VPValue *> VPDefs = definedValues(); const DataLayout &DL = State.CFG.PrevBB->getDataLayout(); if (VecTy->isScalableTy()) { - assert(isPowerOf2_32(InterleaveFactor) && + assert(InterleaveFactor == 2 && "Unsupported deinterleave factor for scalable vectors"); - // Scalable vectors cannot use arbitrary shufflevectors (only splats), - // so must use intrinsics to deinterleave. - SmallVector<Value *> DeinterleavedValues(InterleaveFactor); - DeinterleavedValues[0] = NewLoad; - // For the case of InterleaveFactor > 2, we will have to do recursive - // deinterleaving, because the current available deinterleave intrinsic - // supports only Factor of 2, otherwise it will bailout after first - // iteration. - // When deinterleaving, the number of values will double until we - // have "InterleaveFactor". - for (unsigned NumVectors = 1; NumVectors < InterleaveFactor; - NumVectors *= 2) { - // Deinterleave the elements within the vector - SmallVector<Value *> TempDeinterleavedValues(NumVectors); - for (unsigned I = 0; I < NumVectors; ++I) { - auto *DiTy = DeinterleavedValues[I]->getType(); - TempDeinterleavedValues[I] = State.Builder.CreateIntrinsic( - Intrinsic::vector_deinterleave2, DiTy, DeinterleavedValues[I], - /*FMFSource=*/nullptr, "strided.vec"); - } - // Extract the deinterleaved values: - for (unsigned I = 0; I < 2; ++I) - for (unsigned J = 0; J < NumVectors; ++J) - DeinterleavedValues[NumVectors * I + J] = - State.Builder.CreateExtractValue(TempDeinterleavedValues[J], I); - } - -#ifndef NDEBUG - for (Value *Val : DeinterleavedValues) - assert(Val && "NULL Deinterleaved Value"); -#endif - for (unsigned I = 0, J = 0; I < InterleaveFactor; ++I) { + // Scalable vectors cannot use arbitrary shufflevectors (only splats), + // so must use intrinsics to deinterleave. + Value *DI = State.Builder.CreateIntrinsic( + Intrinsic::vector_deinterleave2, VecTy, NewLoad, + /*FMFSource=*/nullptr, "strided.vec"); + unsigned J = 0; + for (unsigned I = 0; I < InterleaveFactor; ++I) { Instruction *Member = Group->getMember(I); - Value *StridedVec = DeinterleavedValues[I]; - if (!Member) { - // This value is not needed as it's not used - static_cast<Instruction *>(StridedVec)->eraseFromParent(); + + if (!Member) continue; - } + + Value *StridedVec = State.Builder.CreateExtractValue(DI, I); // If this member has different type, cast the result type. 
if (Member->getType() != ScalarTy) { VectorType *OtherVTy = VectorType::get(Member->getType(), State.VF); @@ -3398,7 +3370,7 @@ void VPReductionPHIRecipe::execute(VPTransformState &State) { : VectorType::get(StartV->getType(), State.VF); BasicBlock *HeaderBB = State.CFG.PrevBB; - assert(State.CurrentVectorLoop->getHeader() == HeaderBB && + assert(State.CurrentParentLoop->getHeader() == HeaderBB && "recipe must be in the vector loop header"); auto *Phi = PHINode::Create(VecTy, 2, "vec.phi"); Phi->insertBefore(HeaderBB->getFirstInsertionPt()); diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp index 0b809c2..3e3f5ad 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -217,7 +217,7 @@ static VPBasicBlock *getPredicatedThenBlock(VPRegionBlock *R) { // is connected to a successor replicate region with the same predicate by a // single, empty VPBasicBlock. static bool mergeReplicateRegionsIntoSuccessors(VPlan &Plan) { - SetVector<VPRegionBlock *> DeletedRegions; + SmallPtrSet<VPRegionBlock *, 4> TransformedRegions; // Collect replicate regions followed by an empty block, followed by another // replicate region with matching masks to process front. This is to avoid @@ -248,7 +248,7 @@ static bool mergeReplicateRegionsIntoSuccessors(VPlan &Plan) { // Move recipes from Region1 to its successor region, if both are triangles. for (VPRegionBlock *Region1 : WorkList) { - if (DeletedRegions.contains(Region1)) + if (TransformedRegions.contains(Region1)) continue; auto *MiddleBasicBlock = cast<VPBasicBlock>(Region1->getSingleSuccessor()); auto *Region2 = cast<VPRegionBlock>(MiddleBasicBlock->getSingleSuccessor()); @@ -294,12 +294,10 @@ static bool mergeReplicateRegionsIntoSuccessors(VPlan &Plan) { VPBlockUtils::connectBlocks(Pred, MiddleBasicBlock); } VPBlockUtils::disconnectBlocks(Region1, MiddleBasicBlock); - DeletedRegions.insert(Region1); + TransformedRegions.insert(Region1); } - for (VPRegionBlock *ToDelete : DeletedRegions) - delete ToDelete; - return !DeletedRegions.empty(); + return !TransformedRegions.empty(); } static VPRegionBlock *createReplicateRegion(VPReplicateRecipe *PredRecipe, @@ -310,7 +308,8 @@ static VPRegionBlock *createReplicateRegion(VPReplicateRecipe *PredRecipe, assert(Instr->getParent() && "Predicated instruction not in any basic block"); auto *BlockInMask = PredRecipe->getMask(); auto *BOMRecipe = new VPBranchOnMaskRecipe(BlockInMask); - auto *Entry = new VPBasicBlock(Twine(RegionName) + ".entry", BOMRecipe); + auto *Entry = + Plan.createVPBasicBlock(Twine(RegionName) + ".entry", BOMRecipe); // Replace predicated replicate recipe with a replicate recipe without a // mask but in the replicate region. 
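The load side in the VPInterleaveRecipe::execute hunk further up is the mirror image: one llvm.vector.deinterleave2 call yields a two-element aggregate whose halves are read back with extractvalue, and group members that are unused are simply skipped rather than erased. A sketch under the same factor-2 assumption, again with an illustrative helper name:

#include <utility>

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Intrinsics.h"

using namespace llvm;

// Split a wide scalable vector (e.g. the result of an interleaved factor-2
// load) back into its two strided halves. llvm.vector.deinterleave2 returns a
// two-element aggregate; each half is read out with extractvalue.
static std::pair<Value *, Value *>
deinterleaveTwoScalable(IRBuilderBase &Builder, Value *WideVec) {
  auto *WideTy = cast<VectorType>(WideVec->getType());
  assert(WideTy->isScalableTy() && "intended for scalable vectors");
  Value *DI = Builder.CreateIntrinsic(Intrinsic::vector_deinterleave2, WideTy,
                                      WideVec, /*FMFSource=*/nullptr,
                                      "strided.vec");
  return {Builder.CreateExtractValue(DI, 0), Builder.CreateExtractValue(DI, 1)};
}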
@@ -318,7 +317,8 @@ static VPRegionBlock *createReplicateRegion(VPReplicateRecipe *PredRecipe, PredRecipe->getUnderlyingInstr(), make_range(PredRecipe->op_begin(), std::prev(PredRecipe->op_end())), PredRecipe->isUniform()); - auto *Pred = new VPBasicBlock(Twine(RegionName) + ".if", RecipeWithoutMask); + auto *Pred = + Plan.createVPBasicBlock(Twine(RegionName) + ".if", RecipeWithoutMask); VPPredInstPHIRecipe *PHIRecipe = nullptr; if (PredRecipe->getNumUsers() != 0) { @@ -328,8 +328,10 @@ static VPRegionBlock *createReplicateRegion(VPReplicateRecipe *PredRecipe, PHIRecipe->setOperand(0, RecipeWithoutMask); } PredRecipe->eraseFromParent(); - auto *Exiting = new VPBasicBlock(Twine(RegionName) + ".continue", PHIRecipe); - VPRegionBlock *Region = new VPRegionBlock(Entry, Exiting, RegionName, true); + auto *Exiting = + Plan.createVPBasicBlock(Twine(RegionName) + ".continue", PHIRecipe); + VPRegionBlock *Region = + Plan.createVPRegionBlock(Entry, Exiting, RegionName, true); // Note: first set Entry as region entry and then connect successors starting // from it in order, to propagate the "parent" of each VPBasicBlock. @@ -396,7 +398,7 @@ static bool mergeBlocksIntoPredecessors(VPlan &Plan) { VPBlockUtils::disconnectBlocks(VPBB, Succ); VPBlockUtils::connectBlocks(PredVPBB, Succ); } - delete VPBB; + // VPBB is now dead and will be cleaned up when the plan gets destroyed. } return !WorkList.empty(); } @@ -525,7 +527,8 @@ static VPScalarIVStepsRecipe * createScalarIVSteps(VPlan &Plan, InductionDescriptor::InductionKind Kind, Instruction::BinaryOps InductionOpcode, FPMathOperator *FPBinOp, Instruction *TruncI, - VPValue *StartV, VPValue *Step, VPBuilder &Builder) { + VPValue *StartV, VPValue *Step, DebugLoc DL, + VPBuilder &Builder) { VPBasicBlock *HeaderVPBB = Plan.getVectorLoopRegion()->getEntryBasicBlock(); VPCanonicalIVPHIRecipe *CanonicalIV = Plan.getCanonicalIV(); VPSingleDefRecipe *BaseIV = Builder.createDerivedIV( @@ -540,7 +543,7 @@ createScalarIVSteps(VPlan &Plan, InductionDescriptor::InductionKind Kind, assert(ResultTy->getScalarSizeInBits() > TruncTy->getScalarSizeInBits() && "Not truncating."); assert(ResultTy->isIntegerTy() && "Truncation requires an integer type"); - BaseIV = Builder.createScalarCast(Instruction::Trunc, BaseIV, TruncTy); + BaseIV = Builder.createScalarCast(Instruction::Trunc, BaseIV, TruncTy, DL); ResultTy = TruncTy; } @@ -554,26 +557,68 @@ createScalarIVSteps(VPlan &Plan, InductionDescriptor::InductionKind Kind, cast<VPBasicBlock>(HeaderVPBB->getSingleHierarchicalPredecessor()); VPBuilder::InsertPointGuard Guard(Builder); Builder.setInsertPoint(VecPreheader); - Step = Builder.createScalarCast(Instruction::Trunc, Step, ResultTy); + Step = Builder.createScalarCast(Instruction::Trunc, Step, ResultTy, DL); } return Builder.createScalarIVSteps(InductionOpcode, FPBinOp, BaseIV, Step); } +static SmallVector<VPUser *> collectUsersRecursively(VPValue *V) { + SetVector<VPUser *> Users(V->user_begin(), V->user_end()); + for (unsigned I = 0; I != Users.size(); ++I) { + VPRecipeBase *Cur = cast<VPRecipeBase>(Users[I]); + if (isa<VPHeaderPHIRecipe>(Cur)) + continue; + for (VPValue *V : Cur->definedValues()) + Users.insert(V->user_begin(), V->user_end()); + } + return Users.takeVector(); +} + /// Legalize VPWidenPointerInductionRecipe, by replacing it with a PtrAdd /// (IndStart, ScalarIVSteps (0, Step)) if only its scalar values are used, as /// VPWidenPointerInductionRecipe will generate vectors only. 
If some users /// require vectors while other require scalars, the scalar uses need to extract /// the scalars from the generated vectors (Note that this is different to how -/// int/fp inductions are handled). Also optimize VPWidenIntOrFpInductionRecipe, -/// if any of its users needs scalar values, by providing them scalar steps -/// built on the canonical scalar IV and update the original IV's users. This is -/// an optional optimization to reduce the needs of vector extracts. +/// int/fp inductions are handled). Legalize extract-from-ends using uniform +/// VPReplicateRecipe of wide inductions to use regular VPReplicateRecipe, so +/// the correct end value is available. Also optimize +/// VPWidenIntOrFpInductionRecipe, if any of its users needs scalar values, by +/// providing them scalar steps built on the canonical scalar IV and update the +/// original IV's users. This is an optional optimization to reduce the needs of +/// vector extracts. static void legalizeAndOptimizeInductions(VPlan &Plan) { + using namespace llvm::VPlanPatternMatch; SmallVector<VPRecipeBase *> ToRemove; VPBasicBlock *HeaderVPBB = Plan.getVectorLoopRegion()->getEntryBasicBlock(); bool HasOnlyVectorVFs = !Plan.hasVF(ElementCount::getFixed(1)); VPBuilder Builder(HeaderVPBB, HeaderVPBB->getFirstNonPhi()); for (VPRecipeBase &Phi : HeaderVPBB->phis()) { + auto *PhiR = dyn_cast<VPHeaderPHIRecipe>(&Phi); + if (!PhiR) + break; + + // Check if any uniform VPReplicateRecipes using the phi recipe are used by + // ExtractFromEnd. Those must be replaced by a regular VPReplicateRecipe to + // ensure the final value is available. + // TODO: Remove once uniformity analysis is done on VPlan. + for (VPUser *U : collectUsersRecursively(PhiR)) { + auto *ExitIRI = dyn_cast<VPIRInstruction>(U); + VPValue *Op; + if (!ExitIRI || !match(ExitIRI->getOperand(0), + m_VPInstruction<VPInstruction::ExtractFromEnd>( + m_VPValue(Op), m_VPValue()))) + continue; + auto *RepR = dyn_cast<VPReplicateRecipe>(Op); + if (!RepR || !RepR->isUniform()) + continue; + assert(!RepR->isPredicated() && "RepR must not be predicated"); + Instruction *I = RepR->getUnderlyingInstr(); + auto *Clone = + new VPReplicateRecipe(I, RepR->operands(), /*IsUniform*/ false); + Clone->insertAfter(RepR); + RepR->replaceAllUsesWith(Clone); + } + // Replace wide pointer inductions which have only their scalars used by // PtrAdd(IndStart, ScalarIVSteps (0, Step)). if (auto *PtrIV = dyn_cast<VPWidenPointerInductionRecipe>(&Phi)) { @@ -586,7 +631,7 @@ static void legalizeAndOptimizeInductions(VPlan &Plan) { VPValue *StepV = PtrIV->getOperand(1); VPScalarIVStepsRecipe *Steps = createScalarIVSteps( Plan, InductionDescriptor::IK_IntInduction, Instruction::Add, nullptr, - nullptr, StartV, StepV, Builder); + nullptr, StartV, StepV, PtrIV->getDebugLoc(), Builder); VPValue *PtrAdd = Builder.createPtrAdd(PtrIV->getStartValue(), Steps, PtrIV->getDebugLoc(), "next.gep"); @@ -610,7 +655,7 @@ static void legalizeAndOptimizeInductions(VPlan &Plan) { Plan, ID.getKind(), ID.getInductionOpcode(), dyn_cast_or_null<FPMathOperator>(ID.getInductionBinOp()), WideIV->getTruncInst(), WideIV->getStartValue(), WideIV->getStepValue(), - Builder); + WideIV->getDebugLoc(), Builder); // Update scalar users of IV to use Step instead. if (!HasOnlyVectorVFs) @@ -660,13 +705,158 @@ static void recursivelyDeleteDeadRecipes(VPValue *V) { } } +/// Try to simplify recipe \p R. 
+static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) { + using namespace llvm::VPlanPatternMatch; + + if (auto *Blend = dyn_cast<VPBlendRecipe>(&R)) { + // Try to remove redundant blend recipes. + SmallPtrSet<VPValue *, 4> UniqueValues; + if (Blend->isNormalized() || !match(Blend->getMask(0), m_False())) + UniqueValues.insert(Blend->getIncomingValue(0)); + for (unsigned I = 1; I != Blend->getNumIncomingValues(); ++I) + if (!match(Blend->getMask(I), m_False())) + UniqueValues.insert(Blend->getIncomingValue(I)); + + if (UniqueValues.size() == 1) { + Blend->replaceAllUsesWith(*UniqueValues.begin()); + Blend->eraseFromParent(); + return; + } + + if (Blend->isNormalized()) + return; + + // Normalize the blend so its first incoming value is used as the initial + // value with the others blended into it. + + unsigned StartIndex = 0; + for (unsigned I = 0; I != Blend->getNumIncomingValues(); ++I) { + // If a value's mask is used only by the blend then is can be deadcoded. + // TODO: Find the most expensive mask that can be deadcoded, or a mask + // that's used by multiple blends where it can be removed from them all. + VPValue *Mask = Blend->getMask(I); + if (Mask->getNumUsers() == 1 && !match(Mask, m_False())) { + StartIndex = I; + break; + } + } + + SmallVector<VPValue *, 4> OperandsWithMask; + OperandsWithMask.push_back(Blend->getIncomingValue(StartIndex)); + + for (unsigned I = 0; I != Blend->getNumIncomingValues(); ++I) { + if (I == StartIndex) + continue; + OperandsWithMask.push_back(Blend->getIncomingValue(I)); + OperandsWithMask.push_back(Blend->getMask(I)); + } + + auto *NewBlend = new VPBlendRecipe( + cast<PHINode>(Blend->getUnderlyingValue()), OperandsWithMask); + NewBlend->insertBefore(&R); + + VPValue *DeadMask = Blend->getMask(StartIndex); + Blend->replaceAllUsesWith(NewBlend); + Blend->eraseFromParent(); + recursivelyDeleteDeadRecipes(DeadMask); + return; + } + + VPValue *A; + if (match(&R, m_Trunc(m_ZExtOrSExt(m_VPValue(A))))) { + VPValue *Trunc = R.getVPSingleValue(); + Type *TruncTy = TypeInfo.inferScalarType(Trunc); + Type *ATy = TypeInfo.inferScalarType(A); + if (TruncTy == ATy) { + Trunc->replaceAllUsesWith(A); + } else { + // Don't replace a scalarizing recipe with a widened cast. + if (isa<VPReplicateRecipe>(&R)) + return; + if (ATy->getScalarSizeInBits() < TruncTy->getScalarSizeInBits()) { + + unsigned ExtOpcode = match(R.getOperand(0), m_SExt(m_VPValue())) + ? Instruction::SExt + : Instruction::ZExt; + auto *VPC = + new VPWidenCastRecipe(Instruction::CastOps(ExtOpcode), A, TruncTy); + if (auto *UnderlyingExt = R.getOperand(0)->getUnderlyingValue()) { + // UnderlyingExt has distinct return type, used to retain legacy cost. + VPC->setUnderlyingValue(UnderlyingExt); + } + VPC->insertBefore(&R); + Trunc->replaceAllUsesWith(VPC); + } else if (ATy->getScalarSizeInBits() > TruncTy->getScalarSizeInBits()) { + auto *VPC = new VPWidenCastRecipe(Instruction::Trunc, A, TruncTy); + VPC->insertBefore(&R); + Trunc->replaceAllUsesWith(VPC); + } + } +#ifndef NDEBUG + // Verify that the cached type info is for both A and its users is still + // accurate by comparing it to freshly computed types. 
+ VPTypeAnalysis TypeInfo2( + R.getParent()->getPlan()->getCanonicalIV()->getScalarType()); + assert(TypeInfo.inferScalarType(A) == TypeInfo2.inferScalarType(A)); + for (VPUser *U : A->users()) { + auto *R = cast<VPRecipeBase>(U); + for (VPValue *VPV : R->definedValues()) + assert(TypeInfo.inferScalarType(VPV) == TypeInfo2.inferScalarType(VPV)); + } +#endif + } + + // Simplify (X && Y) || (X && !Y) -> X. + // TODO: Split up into simpler, modular combines: (X && Y) || (X && Z) into X + // && (Y || Z) and (X || !X) into true. This requires queuing newly created + // recipes to be visited during simplification. + VPValue *X, *Y, *X1, *Y1; + if (match(&R, + m_c_BinaryOr(m_LogicalAnd(m_VPValue(X), m_VPValue(Y)), + m_LogicalAnd(m_VPValue(X1), m_Not(m_VPValue(Y1))))) && + X == X1 && Y == Y1) { + R.getVPSingleValue()->replaceAllUsesWith(X); + R.eraseFromParent(); + return; + } + + if (match(&R, m_c_Mul(m_VPValue(A), m_SpecificInt(1)))) + return R.getVPSingleValue()->replaceAllUsesWith(A); + + if (match(&R, m_Not(m_Not(m_VPValue(A))))) + return R.getVPSingleValue()->replaceAllUsesWith(A); + + // Remove redundant DerviedIVs, that is 0 + A * 1 -> A and 0 + 0 * x -> 0. + if ((match(&R, + m_DerivedIV(m_SpecificInt(0), m_VPValue(A), m_SpecificInt(1))) || + match(&R, + m_DerivedIV(m_SpecificInt(0), m_SpecificInt(0), m_VPValue()))) && + TypeInfo.inferScalarType(R.getOperand(1)) == + TypeInfo.inferScalarType(R.getVPSingleValue())) + return R.getVPSingleValue()->replaceAllUsesWith(R.getOperand(1)); +} + +/// Try to simplify the recipes in \p Plan. Use \p CanonicalIVTy as type for all +/// un-typed live-ins in VPTypeAnalysis. +static void simplifyRecipes(VPlan &Plan, Type *CanonicalIVTy) { + ReversePostOrderTraversal<VPBlockDeepTraversalWrapper<VPBlockBase *>> RPOT( + Plan.getEntry()); + VPTypeAnalysis TypeInfo(CanonicalIVTy); + for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(RPOT)) { + for (VPRecipeBase &R : make_early_inc_range(*VPBB)) { + simplifyRecipe(R, TypeInfo); + } + } +} + void VPlanTransforms::optimizeForVFAndUF(VPlan &Plan, ElementCount BestVF, unsigned BestUF, PredicatedScalarEvolution &PSE) { assert(Plan.hasVF(BestVF) && "BestVF is not available in Plan"); assert(Plan.hasUF(BestUF) && "BestUF is not available in Plan"); - VPBasicBlock *ExitingVPBB = - Plan.getVectorLoopRegion()->getExitingBasicBlock(); + VPRegionBlock *VectorRegion = Plan.getVectorLoopRegion(); + VPBasicBlock *ExitingVPBB = VectorRegion->getExitingBasicBlock(); auto *Term = &ExitingVPBB->back(); // Try to simplify the branch condition if TC <= VF * UF when preparing to // execute the plan for the main vector loop. We only do this if the @@ -690,16 +880,44 @@ void VPlanTransforms::optimizeForVFAndUF(VPlan &Plan, ElementCount BestVF, !SE.isKnownPredicate(CmpInst::ICMP_ULE, TripCount, C)) return; - LLVMContext &Ctx = SE.getContext(); - auto *BOC = new VPInstruction( - VPInstruction::BranchOnCond, - {Plan.getOrAddLiveIn(ConstantInt::getTrue(Ctx))}, Term->getDebugLoc()); + // The vector loop region only executes once. If possible, completely remove + // the region, otherwise replace the terminator controlling the latch with + // (BranchOnCond true). 
+ auto *Header = cast<VPBasicBlock>(VectorRegion->getEntry()); + auto *CanIVTy = Plan.getCanonicalIV()->getScalarType(); + if (all_of( + Header->phis(), + IsaPred<VPCanonicalIVPHIRecipe, VPFirstOrderRecurrencePHIRecipe>)) { + for (VPRecipeBase &HeaderR : make_early_inc_range(Header->phis())) { + auto *HeaderPhiR = cast<VPHeaderPHIRecipe>(&HeaderR); + HeaderPhiR->replaceAllUsesWith(HeaderPhiR->getStartValue()); + HeaderPhiR->eraseFromParent(); + } + + VPBlockBase *Preheader = VectorRegion->getSinglePredecessor(); + VPBlockBase *Exit = VectorRegion->getSingleSuccessor(); + VPBlockUtils::disconnectBlocks(Preheader, VectorRegion); + VPBlockUtils::disconnectBlocks(VectorRegion, Exit); + + for (VPBlockBase *B : vp_depth_first_shallow(VectorRegion->getEntry())) + B->setParent(nullptr); + + VPBlockUtils::connectBlocks(Preheader, Header); + VPBlockUtils::connectBlocks(ExitingVPBB, Exit); + simplifyRecipes(Plan, CanIVTy); + } else { + // The vector region contains header phis for which we cannot remove the + // loop region yet. + LLVMContext &Ctx = SE.getContext(); + auto *BOC = new VPInstruction( + VPInstruction::BranchOnCond, + {Plan.getOrAddLiveIn(ConstantInt::getTrue(Ctx))}, Term->getDebugLoc()); + ExitingVPBB->appendRecipe(BOC); + } - SmallVector<VPValue *> PossiblyDead(Term->operands()); Term->eraseFromParent(); - for (VPValue *Op : PossiblyDead) - recursivelyDeleteDeadRecipes(Op); - ExitingVPBB->appendRecipe(BOC); + VPlanTransforms::removeDeadRecipes(Plan); + Plan.setVF(BestVF); Plan.setUF(BestUF); // TODO: Further simplifications are possible @@ -910,18 +1128,6 @@ bool VPlanTransforms::adjustFixedOrderRecurrences(VPlan &Plan, return true; } -static SmallVector<VPUser *> collectUsersRecursively(VPValue *V) { - SetVector<VPUser *> Users(V->user_begin(), V->user_end()); - for (unsigned I = 0; I != Users.size(); ++I) { - VPRecipeBase *Cur = cast<VPRecipeBase>(Users[I]); - if (isa<VPHeaderPHIRecipe>(Cur)) - continue; - for (VPValue *V : Cur->definedValues()) - Users.insert(V->user_begin(), V->user_end()); - } - return Users.takeVector(); -} - void VPlanTransforms::clearReductionWrapFlags(VPlan &Plan) { for (VPRecipeBase &R : Plan.getVectorLoopRegion()->getEntryBasicBlock()->phis()) { @@ -940,138 +1146,6 @@ void VPlanTransforms::clearReductionWrapFlags(VPlan &Plan) { } } -/// Try to simplify recipe \p R. -static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) { - using namespace llvm::VPlanPatternMatch; - - if (auto *Blend = dyn_cast<VPBlendRecipe>(&R)) { - // Try to remove redundant blend recipes. - SmallPtrSet<VPValue *, 4> UniqueValues; - if (Blend->isNormalized() || !match(Blend->getMask(0), m_False())) - UniqueValues.insert(Blend->getIncomingValue(0)); - for (unsigned I = 1; I != Blend->getNumIncomingValues(); ++I) - if (!match(Blend->getMask(I), m_False())) - UniqueValues.insert(Blend->getIncomingValue(I)); - - if (UniqueValues.size() == 1) { - Blend->replaceAllUsesWith(*UniqueValues.begin()); - Blend->eraseFromParent(); - return; - } - - if (Blend->isNormalized()) - return; - - // Normalize the blend so its first incoming value is used as the initial - // value with the others blended into it. - - unsigned StartIndex = 0; - for (unsigned I = 0; I != Blend->getNumIncomingValues(); ++I) { - // If a value's mask is used only by the blend then is can be deadcoded. - // TODO: Find the most expensive mask that can be deadcoded, or a mask - // that's used by multiple blends where it can be removed from them all. 
- VPValue *Mask = Blend->getMask(I); - if (Mask->getNumUsers() == 1 && !match(Mask, m_False())) { - StartIndex = I; - break; - } - } - - SmallVector<VPValue *, 4> OperandsWithMask; - OperandsWithMask.push_back(Blend->getIncomingValue(StartIndex)); - - for (unsigned I = 0; I != Blend->getNumIncomingValues(); ++I) { - if (I == StartIndex) - continue; - OperandsWithMask.push_back(Blend->getIncomingValue(I)); - OperandsWithMask.push_back(Blend->getMask(I)); - } - - auto *NewBlend = new VPBlendRecipe( - cast<PHINode>(Blend->getUnderlyingValue()), OperandsWithMask); - NewBlend->insertBefore(&R); - - VPValue *DeadMask = Blend->getMask(StartIndex); - Blend->replaceAllUsesWith(NewBlend); - Blend->eraseFromParent(); - recursivelyDeleteDeadRecipes(DeadMask); - return; - } - - VPValue *A; - if (match(&R, m_Trunc(m_ZExtOrSExt(m_VPValue(A))))) { - VPValue *Trunc = R.getVPSingleValue(); - Type *TruncTy = TypeInfo.inferScalarType(Trunc); - Type *ATy = TypeInfo.inferScalarType(A); - if (TruncTy == ATy) { - Trunc->replaceAllUsesWith(A); - } else { - // Don't replace a scalarizing recipe with a widened cast. - if (isa<VPReplicateRecipe>(&R)) - return; - if (ATy->getScalarSizeInBits() < TruncTy->getScalarSizeInBits()) { - - unsigned ExtOpcode = match(R.getOperand(0), m_SExt(m_VPValue())) - ? Instruction::SExt - : Instruction::ZExt; - auto *VPC = - new VPWidenCastRecipe(Instruction::CastOps(ExtOpcode), A, TruncTy); - if (auto *UnderlyingExt = R.getOperand(0)->getUnderlyingValue()) { - // UnderlyingExt has distinct return type, used to retain legacy cost. - VPC->setUnderlyingValue(UnderlyingExt); - } - VPC->insertBefore(&R); - Trunc->replaceAllUsesWith(VPC); - } else if (ATy->getScalarSizeInBits() > TruncTy->getScalarSizeInBits()) { - auto *VPC = new VPWidenCastRecipe(Instruction::Trunc, A, TruncTy); - VPC->insertBefore(&R); - Trunc->replaceAllUsesWith(VPC); - } - } -#ifndef NDEBUG - // Verify that the cached type info is for both A and its users is still - // accurate by comparing it to freshly computed types. - VPTypeAnalysis TypeInfo2( - R.getParent()->getPlan()->getCanonicalIV()->getScalarType()); - assert(TypeInfo.inferScalarType(A) == TypeInfo2.inferScalarType(A)); - for (VPUser *U : A->users()) { - auto *R = cast<VPRecipeBase>(U); - for (VPValue *VPV : R->definedValues()) - assert(TypeInfo.inferScalarType(VPV) == TypeInfo2.inferScalarType(VPV)); - } -#endif - } - - // Simplify (X && Y) || (X && !Y) -> X. - // TODO: Split up into simpler, modular combines: (X && Y) || (X && Z) into X - // && (Y || Z) and (X || !X) into true. This requires queuing newly created - // recipes to be visited during simplification. - VPValue *X, *Y, *X1, *Y1; - if (match(&R, - m_c_BinaryOr(m_LogicalAnd(m_VPValue(X), m_VPValue(Y)), - m_LogicalAnd(m_VPValue(X1), m_Not(m_VPValue(Y1))))) && - X == X1 && Y == Y1) { - R.getVPSingleValue()->replaceAllUsesWith(X); - R.eraseFromParent(); - return; - } - - if (match(&R, m_c_Mul(m_VPValue(A), m_SpecificInt(1)))) - return R.getVPSingleValue()->replaceAllUsesWith(A); - - if (match(&R, m_Not(m_Not(m_VPValue(A))))) - return R.getVPSingleValue()->replaceAllUsesWith(A); - - // Remove redundant DerviedIVs, that is 0 + A * 1 -> A and 0 + 0 * x -> 0. 
- if ((match(&R, - m_DerivedIV(m_SpecificInt(0), m_VPValue(A), m_SpecificInt(1))) || - match(&R, - m_DerivedIV(m_SpecificInt(0), m_SpecificInt(0), m_VPValue()))) && - TypeInfo.inferScalarType(R.getOperand(1)) == - TypeInfo.inferScalarType(R.getVPSingleValue())) - return R.getVPSingleValue()->replaceAllUsesWith(R.getOperand(1)); -} - /// Move loop-invariant recipes out of the vector loop region in \p Plan. static void licm(VPlan &Plan) { VPBasicBlock *Preheader = Plan.getVectorPreheader(); @@ -1106,19 +1180,6 @@ static void licm(VPlan &Plan) { } } -/// Try to simplify the recipes in \p Plan. -static void simplifyRecipes(VPlan &Plan) { - ReversePostOrderTraversal<VPBlockDeepTraversalWrapper<VPBlockBase *>> RPOT( - Plan.getEntry()); - Type *CanonicalIVType = Plan.getCanonicalIV()->getScalarType(); - VPTypeAnalysis TypeInfo(CanonicalIVType); - for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(RPOT)) { - for (VPRecipeBase &R : make_early_inc_range(*VPBB)) { - simplifyRecipe(R, TypeInfo); - } - } -} - void VPlanTransforms::truncateToMinimalBitwidths( VPlan &Plan, const MapVector<Instruction *, uint64_t> &MinBWs) { #ifndef NDEBUG @@ -1256,10 +1317,10 @@ void VPlanTransforms::optimize(VPlan &Plan) { removeRedundantCanonicalIVs(Plan); removeRedundantInductionCasts(Plan); - simplifyRecipes(Plan); + simplifyRecipes(Plan, Plan.getCanonicalIV()->getScalarType()); legalizeAndOptimizeInductions(Plan); removeRedundantExpandSCEVRecipes(Plan); - simplifyRecipes(Plan); + simplifyRecipes(Plan, Plan.getCanonicalIV()->getScalarType()); removeDeadRecipes(Plan); createAndOptimizeReplicateRegions(Plan); @@ -1496,10 +1557,13 @@ static VPRecipeBase *createEVLRecipe(VPValue *HeaderMask, auto *CastR = cast<VPWidenCastRecipe>(CR); VPID = VPIntrinsic::getForOpcode(CastR->getOpcode()); } - assert(VPID != Intrinsic::not_intrinsic && "Expected VP intrinsic"); + + // Not all intrinsics have a corresponding VP intrinsic. + if (VPID == Intrinsic::not_intrinsic) + return nullptr; assert(VPIntrinsic::getMaskParamPos(VPID) && VPIntrinsic::getVectorLengthParamPos(VPID) && - "Expected VP intrinsic"); + "Expected VP intrinsic to have mask and EVL"); SmallVector<VPValue *> Ops(CR->operands()); Ops.push_back(&AllOneMask); @@ -1656,9 +1720,9 @@ bool VPlanTransforms::tryAddExplicitVectorLength( VPSingleDefRecipe *OpVPEVL = VPEVL; if (unsigned IVSize = CanonicalIVPHI->getScalarType()->getScalarSizeInBits(); IVSize != 32) { - OpVPEVL = new VPScalarCastRecipe(IVSize < 32 ? Instruction::Trunc - : Instruction::ZExt, - OpVPEVL, CanonicalIVPHI->getScalarType()); + OpVPEVL = new VPScalarCastRecipe( + IVSize < 32 ? Instruction::Trunc : Instruction::ZExt, OpVPEVL, + CanonicalIVPHI->getScalarType(), CanonicalIVIncrement->getDebugLoc()); OpVPEVL->insertBefore(CanonicalIVIncrement); } auto *NextEVLIV = @@ -1898,7 +1962,7 @@ void VPlanTransforms::handleUncountableEarlyExit( if (OrigLoop->getUniqueExitBlock()) { VPEarlyExitBlock = cast<VPIRBasicBlock>(MiddleVPBB->getSuccessors()[0]); } else { - VPEarlyExitBlock = VPIRBasicBlock::fromBasicBlock( + VPEarlyExitBlock = Plan.createVPIRBasicBlock( !OrigLoop->contains(TrueSucc) ? 
TrueSucc : FalseSucc); } @@ -1908,7 +1972,7 @@ void VPlanTransforms::handleUncountableEarlyExit( IsEarlyExitTaken = Builder.createNaryOp(VPInstruction::AnyOf, {EarlyExitTakenCond}); - VPBasicBlock *NewMiddle = new VPBasicBlock("middle.split"); + VPBasicBlock *NewMiddle = Plan.createVPBasicBlock("middle.split"); VPBlockUtils::insertOnEdge(LoopRegion, MiddleVPBB, NewMiddle); VPBlockUtils::connectBlocks(NewMiddle, VPEarlyExitBlock); NewMiddle->swapSuccessors(); diff --git a/llvm/lib/Transforms/Vectorize/VPlanUtils.h b/llvm/lib/Transforms/Vectorize/VPlanUtils.h index 9657770..7779442 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanUtils.h +++ b/llvm/lib/Transforms/Vectorize/VPlanUtils.h @@ -49,7 +49,8 @@ inline bool isUniformAfterVectorization(const VPValue *VPV) { return all_of(GEP->operands(), isUniformAfterVectorization); if (auto *VPI = dyn_cast<VPInstruction>(Def)) return VPI->isSingleScalar() || VPI->isVectorToScalar(); - return false; + // VPExpandSCEVRecipes must be placed in the entry and are alway uniform. + return isa<VPExpandSCEVRecipe>(Def); } /// Return true if \p V is a header mask in \p Plan. diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp index ecbc13d..1a669b5 100644 --- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp +++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp @@ -128,6 +128,8 @@ private: bool shrinkType(Instruction &I); void replaceValue(Value &Old, Value &New) { + LLVM_DEBUG(dbgs() << "VC: Replacing: " << Old << '\n'); + LLVM_DEBUG(dbgs() << " With: " << New << '\n'); Old.replaceAllUsesWith(&New); if (auto *NewI = dyn_cast<Instruction>(&New)) { New.takeName(&Old); @@ -139,10 +141,17 @@ private: void eraseInstruction(Instruction &I) { LLVM_DEBUG(dbgs() << "VC: Erasing: " << I << '\n'); - for (Value *Op : I.operands()) - Worklist.pushValue(Op); + SmallVector<Value *> Ops(I.operands()); Worklist.remove(&I); I.eraseFromParent(); + + // Push remaining users of the operands and then the operand itself - allows + // further folds that were hindered by OneUse limits. + for (Value *Op : Ops) + if (auto *OpI = dyn_cast<Instruction>(Op)) { + Worklist.pushUsersToWorkList(*OpI); + Worklist.pushValue(OpI); + } } }; } // namespace @@ -696,7 +705,8 @@ bool VectorCombine::foldInsExtFNeg(Instruction &I) { InstructionCost NewCost = TTI.getArithmeticInstrCost(Instruction::FNeg, VecTy, CostKind) + - TTI.getShuffleCost(TargetTransformInfo::SK_Select, VecTy, Mask, CostKind); + TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, VecTy, Mask, + CostKind); bool NeedLenChg = SrcVecTy->getNumElements() != NumElts; // If the lengths of the two vectors are not equal, @@ -1335,6 +1345,10 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) { MemoryLocation::get(SI), AA)) return false; + // Ensure we add the load back to the worklist BEFORE its users so they can + // erased in the correct order. 
+ Worklist.push(Load); + if (ScalarizableIdx.isSafeWithFreeze()) ScalarizableIdx.freeze(Builder, *cast<Instruction>(Idx)); Value *GEP = Builder.CreateInBoundsGEP( @@ -1360,8 +1374,8 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) { if (!match(&I, m_Load(m_Value(Ptr)))) return false; - auto *VecTy = cast<VectorType>(I.getType()); auto *LI = cast<LoadInst>(&I); + auto *VecTy = cast<VectorType>(LI->getType()); if (LI->isVolatile() || !DL->typeSizeEqualsStoreSize(VecTy->getScalarType())) return false; @@ -1401,7 +1415,8 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) { LastCheckedInst = UI; } - auto ScalarIdx = canScalarizeAccess(VecTy, UI->getOperand(1), &I, AC, DT); + auto ScalarIdx = + canScalarizeAccess(VecTy, UI->getIndexOperand(), LI, AC, DT); if (ScalarIdx.isUnsafe()) return false; if (ScalarIdx.isSafeWithFreeze()) { @@ -1409,7 +1424,7 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) { ScalarIdx.discard(); } - auto *Index = dyn_cast<ConstantInt>(UI->getOperand(1)); + auto *Index = dyn_cast<ConstantInt>(UI->getIndexOperand()); OriginalCost += TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, CostKind, Index ? Index->getZExtValue() : -1); @@ -1422,10 +1437,14 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) { if (ScalarizedCost >= OriginalCost) return false; + // Ensure we add the load back to the worklist BEFORE its users so they can + // erased in the correct order. + Worklist.push(LI); + // Replace extracts with narrow scalar loads. for (User *U : LI->users()) { auto *EI = cast<ExtractElementInst>(U); - Value *Idx = EI->getOperand(1); + Value *Idx = EI->getIndexOperand(); // Insert 'freeze' for poison indexes. auto It = NeedFreeze.find(EI); @@ -1669,7 +1688,8 @@ bool VectorCombine::foldShuffleOfBinops(Instruction &I) { Value *X, *Y, *Z, *W; bool IsCommutative = false; - CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE; + CmpPredicate PredLHS = CmpInst::BAD_ICMP_PREDICATE; + CmpPredicate PredRHS = CmpInst::BAD_ICMP_PREDICATE; if (match(LHS, m_BinOp(m_Value(X), m_Value(Y))) && match(RHS, m_BinOp(m_Value(Z), m_Value(W)))) { auto *BO = cast<BinaryOperator>(LHS); @@ -1677,8 +1697,9 @@ bool VectorCombine::foldShuffleOfBinops(Instruction &I) { if (llvm::is_contained(OldMask, PoisonMaskElem) && BO->isIntDivRem()) return false; IsCommutative = BinaryOperator::isCommutative(BO->getOpcode()); - } else if (match(LHS, m_Cmp(Pred, m_Value(X), m_Value(Y))) && - match(RHS, m_SpecificCmp(Pred, m_Value(Z), m_Value(W)))) { + } else if (match(LHS, m_Cmp(PredLHS, m_Value(X), m_Value(Y))) && + match(RHS, m_Cmp(PredRHS, m_Value(Z), m_Value(W))) && + (CmpInst::Predicate)PredLHS == (CmpInst::Predicate)PredRHS) { IsCommutative = cast<CmpInst>(LHS)->isCommutative(); } else return false; @@ -1723,18 +1744,48 @@ bool VectorCombine::foldShuffleOfBinops(Instruction &I) { TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, BinResTy, OldMask, CostKind, 0, nullptr, {LHS, RHS}, &I); + // Handle shuffle(binop(shuffle(x),y),binop(z,shuffle(w))) style patterns + // where one use shuffles have gotten split across the binop/cmp. These + // often allow a major reduction in total cost that wouldn't happen as + // individual folds. 
+ auto MergeInner = [&](Value *&Op, int Offset, MutableArrayRef<int> Mask, + TTI::TargetCostKind CostKind) -> bool { + Value *InnerOp; + ArrayRef<int> InnerMask; + if (match(Op, m_OneUse(m_Shuffle(m_Value(InnerOp), m_Undef(), + m_Mask(InnerMask)))) && + InnerOp->getType() == Op->getType() && + all_of(InnerMask, + [NumSrcElts](int M) { return M < (int)NumSrcElts; })) { + for (int &M : Mask) + if (Offset <= M && M < (int)(Offset + NumSrcElts)) { + M = InnerMask[M - Offset]; + M = 0 <= M ? M + Offset : M; + } + OldCost += TTI.getInstructionCost(cast<Instruction>(Op), CostKind); + Op = InnerOp; + return true; + } + return false; + }; + bool ReducedInstCount = false; + ReducedInstCount |= MergeInner(X, 0, NewMask0, CostKind); + ReducedInstCount |= MergeInner(Y, 0, NewMask1, CostKind); + ReducedInstCount |= MergeInner(Z, NumSrcElts, NewMask0, CostKind); + ReducedInstCount |= MergeInner(W, NumSrcElts, NewMask1, CostKind); + InstructionCost NewCost = TTI.getShuffleCost(SK0, BinOpTy, NewMask0, CostKind, 0, nullptr, {X, Z}) + TTI.getShuffleCost(SK1, BinOpTy, NewMask1, CostKind, 0, nullptr, {Y, W}); - if (Pred == CmpInst::BAD_ICMP_PREDICATE) { + if (PredLHS == CmpInst::BAD_ICMP_PREDICATE) { NewCost += TTI.getArithmeticInstrCost(LHS->getOpcode(), ShuffleDstTy, CostKind); } else { auto *ShuffleCmpTy = FixedVectorType::get(BinOpTy->getElementType(), ShuffleDstTy); NewCost += TTI.getCmpSelInstrCost(LHS->getOpcode(), ShuffleCmpTy, - ShuffleDstTy, Pred, CostKind); + ShuffleDstTy, PredLHS, CostKind); } LLVM_DEBUG(dbgs() << "Found a shuffle feeding two binops: " << I @@ -1743,17 +1794,17 @@ bool VectorCombine::foldShuffleOfBinops(Instruction &I) { // If either shuffle will constant fold away, then fold for the same cost as // we will reduce the instruction count. - bool ReducedInstCount = (isa<Constant>(X) && isa<Constant>(Z)) || - (isa<Constant>(Y) && isa<Constant>(W)); + ReducedInstCount |= (isa<Constant>(X) && isa<Constant>(Z)) || + (isa<Constant>(Y) && isa<Constant>(W)); if (ReducedInstCount ? (NewCost > OldCost) : (NewCost >= OldCost)) return false; Value *Shuf0 = Builder.CreateShuffleVector(X, Z, NewMask0); Value *Shuf1 = Builder.CreateShuffleVector(Y, W, NewMask1); - Value *NewBO = Pred == CmpInst::BAD_ICMP_PREDICATE + Value *NewBO = PredLHS == CmpInst::BAD_ICMP_PREDICATE ? Builder.CreateBinOp( cast<BinaryOperator>(LHS)->getOpcode(), Shuf0, Shuf1) - : Builder.CreateCmp(Pred, Shuf0, Shuf1); + : Builder.CreateCmp(PredLHS, Shuf0, Shuf1); // Intersect flags from the old binops. if (auto *NewInst = dyn_cast<Instruction>(NewBO)) { @@ -1972,9 +2023,7 @@ bool VectorCombine::foldShuffleOfShuffles(Instruction &I) { if (Match1) InnerCost1 = TTI.getInstructionCost(cast<Instruction>(OuterV1), CostKind); - InstructionCost OuterCost = TTI.getShuffleCost( - TargetTransformInfo::SK_PermuteTwoSrc, ShuffleImmTy, OuterMask, CostKind, - 0, nullptr, {OuterV0, OuterV1}, &I); + InstructionCost OuterCost = TTI.getInstructionCost(&I, CostKind); InstructionCost OldCost = InnerCost0 + InnerCost1 + OuterCost; @@ -3047,12 +3096,16 @@ bool VectorCombine::foldInsExtVectorToShuffle(Instruction &I) { TTI.getVectorInstrCost(*Ext, VecTy, CostKind, ExtIdx); InstructionCost OldCost = ExtCost + InsCost; - InstructionCost NewCost = TTI.getShuffleCost(SK, VecTy, Mask, CostKind, 0, - nullptr, {DstVec, SrcVec}); + // Ignore 'free' identity insertion shuffle. + // TODO: getShuffleCost should return TCC_Free for Identity shuffles. 
+ InstructionCost NewCost = 0; + if (!ShuffleVectorInst::isIdentityMask(Mask, NumElts)) + NewCost += TTI.getShuffleCost(SK, VecTy, Mask, CostKind, 0, nullptr, + {DstVec, SrcVec}); if (!Ext->hasOneUse()) NewCost += ExtCost; - LLVM_DEBUG(dbgs() << "Found a insert/extract shuffle-like pair : " << I + LLVM_DEBUG(dbgs() << "Found a insert/extract shuffle-like pair: " << I << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost << "\n");
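The early-out added above hinges on ShuffleVectorInst::isIdentityMask: an identity shuffle passes one operand through untouched, so charging a shuffle cost for it would make the fold look more expensive than the code it produces. A small standalone illustration of what that predicate accepts and rejects (the helper and the concrete masks are examples, not taken from the patch):

#include <cassert>

#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Instructions.h"

using namespace llvm;

// A shuffle mask is an identity if every defined lane comes from the same
// source operand and stays in its lane, making the shuffle a no-op on that
// operand.
static void identityMaskExamples() {
  SmallVector<int> Keep0 = {0, 1, 2, 3}; // passes operand 0 through
  SmallVector<int> Keep1 = {4, 5, 6, 7}; // passes operand 1 through
  SmallVector<int> Mixed = {0, 5, 2, 3}; // blends lanes from both operands
  assert(ShuffleVectorInst::isIdentityMask(Keep0, /*NumSrcElts=*/4));
  assert(ShuffleVectorInst::isIdentityMask(Keep1, /*NumSrcElts=*/4));
  assert(!ShuffleVectorInst::isIdentityMask(Mixed, /*NumSrcElts=*/4));
}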