Diffstat (limited to 'llvm/lib')
58 files changed, 2073 insertions, 884 deletions
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp index 45c889c..a5ba197 100755 --- a/llvm/lib/Analysis/ConstantFolding.cpp +++ b/llvm/lib/Analysis/ConstantFolding.cpp @@ -2177,16 +2177,13 @@ Constant *constantFoldVectorReduce(Intrinsic::ID IID, Constant *Op) { return PoisonValue::get(VT->getElementType()); // TODO: Handle undef. - if (!isa<ConstantVector>(Op) && !isa<ConstantDataVector>(Op)) - return nullptr; - - auto *EltC = dyn_cast<ConstantInt>(Op->getAggregateElement(0U)); + auto *EltC = dyn_cast_or_null<ConstantInt>(Op->getAggregateElement(0U)); if (!EltC) return nullptr; APInt Acc = EltC->getValue(); for (unsigned I = 1, E = VT->getNumElements(); I != E; I++) { - if (!(EltC = dyn_cast<ConstantInt>(Op->getAggregateElement(I)))) + if (!(EltC = dyn_cast_or_null<ConstantInt>(Op->getAggregateElement(I)))) return nullptr; const APInt &X = EltC->getValue(); switch (IID) { @@ -3059,35 +3056,25 @@ static Constant *ConstantFoldScalarCall1(StringRef Name, Val = Val | Val << 1; return ConstantInt::get(Ty, Val); } - - default: - return nullptr; } } - switch (IntrinsicID) { - default: break; - case Intrinsic::vector_reduce_add: - case Intrinsic::vector_reduce_mul: - case Intrinsic::vector_reduce_and: - case Intrinsic::vector_reduce_or: - case Intrinsic::vector_reduce_xor: - case Intrinsic::vector_reduce_smin: - case Intrinsic::vector_reduce_smax: - case Intrinsic::vector_reduce_umin: - case Intrinsic::vector_reduce_umax: - if (Constant *C = constantFoldVectorReduce(IntrinsicID, Operands[0])) - return C; - break; - } - - // Support ConstantVector in case we have an Undef in the top. - if (isa<ConstantVector>(Operands[0]) || - isa<ConstantDataVector>(Operands[0]) || - isa<ConstantAggregateZero>(Operands[0])) { + if (Operands[0]->getType()->isVectorTy()) { auto *Op = cast<Constant>(Operands[0]); switch (IntrinsicID) { default: break; + case Intrinsic::vector_reduce_add: + case Intrinsic::vector_reduce_mul: + case Intrinsic::vector_reduce_and: + case Intrinsic::vector_reduce_or: + case Intrinsic::vector_reduce_xor: + case Intrinsic::vector_reduce_smin: + case Intrinsic::vector_reduce_smax: + case Intrinsic::vector_reduce_umin: + case Intrinsic::vector_reduce_umax: + if (Constant *C = constantFoldVectorReduce(IntrinsicID, Operands[0])) + return C; + break; case Intrinsic::x86_sse_cvtss2si: case Intrinsic::x86_sse_cvtss2si64: case Intrinsic::x86_sse2_cvtsd2si: @@ -3116,10 +3103,15 @@ static Constant *ConstantFoldScalarCall1(StringRef Name, case Intrinsic::wasm_alltrue: // Check each element individually unsigned E = cast<FixedVectorType>(Op->getType())->getNumElements(); - for (unsigned I = 0; I != E; ++I) - if (Constant *Elt = Op->getAggregateElement(I)) - if (Elt->isZeroValue()) - return ConstantInt::get(Ty, 0); + for (unsigned I = 0; I != E; ++I) { + Constant *Elt = Op->getAggregateElement(I); + // Return false as soon as we find a non-true element. + if (Elt && Elt->isZeroValue()) + return ConstantInt::get(Ty, 0); + // Bail as soon as we find an element we cannot prove to be true. 
+ if (!Elt || !isa<ConstantInt>(Elt)) + return nullptr; + } return ConstantInt::get(Ty, 1); } diff --git a/llvm/lib/Analysis/DXILResource.cpp b/llvm/lib/Analysis/DXILResource.cpp index b78cc03e..f9bf092 100644 --- a/llvm/lib/Analysis/DXILResource.cpp +++ b/llvm/lib/Analysis/DXILResource.cpp @@ -281,6 +281,38 @@ static StructType *getOrCreateElementStruct(Type *ElemType, StringRef Name) { return StructType::create(ElemType, Name); } +static Type *getTypeWithoutPadding(Type *Ty) { + // Recursively remove padding from structures. + if (auto *ST = dyn_cast<StructType>(Ty)) { + LLVMContext &Ctx = Ty->getContext(); + SmallVector<Type *> ElementTypes; + ElementTypes.reserve(ST->getNumElements()); + for (Type *ElTy : ST->elements()) { + if (isa<PaddingExtType>(ElTy)) + continue; + ElementTypes.push_back(getTypeWithoutPadding(ElTy)); + } + + // Handle explicitly padded cbuffer arrays like { [ n x paddedty ], ty } + if (ElementTypes.size() == 2) + if (auto *AT = dyn_cast<ArrayType>(ElementTypes[0])) + if (ElementTypes[1] == AT->getElementType()) + return ArrayType::get(ElementTypes[1], AT->getNumElements() + 1); + + // If we only have a single element, don't wrap it in a struct. + if (ElementTypes.size() == 1) + return ElementTypes[0]; + + return StructType::get(Ctx, ElementTypes, /*IsPacked=*/false); + } + // Arrays just need to have their element type adjusted. + if (auto *AT = dyn_cast<ArrayType>(Ty)) + return ArrayType::get(getTypeWithoutPadding(AT->getElementType()), + AT->getNumElements()); + // Anything else should be good as is. + return Ty; +} + StructType *ResourceTypeInfo::createElementStruct(StringRef CBufferName) { SmallString<64> TypeName; @@ -334,14 +366,21 @@ StructType *ResourceTypeInfo::createElementStruct(StringRef CBufferName) { } case ResourceKind::CBuffer: { auto *RTy = cast<CBufferExtType>(HandleTy); - LayoutExtType *LayoutType = cast<LayoutExtType>(RTy->getResourceType()); - StructType *Ty = cast<StructType>(LayoutType->getWrappedType()); SmallString<64> Name = getResourceKindName(Kind); if (!CBufferName.empty()) { Name.append("."); Name.append(CBufferName); } - return StructType::create(Ty->elements(), Name); + + // TODO: Remove this when we update the frontend to use explicit padding. + if (LayoutExtType *LayoutType = + dyn_cast<LayoutExtType>(RTy->getResourceType())) { + StructType *Ty = cast<StructType>(LayoutType->getWrappedType()); + return StructType::create(Ty->elements(), Name); + } + + return getOrCreateElementStruct( + getTypeWithoutPadding(RTy->getResourceType()), Name); } case ResourceKind::Sampler: { auto *RTy = cast<SamplerExtType>(HandleTy); @@ -454,10 +493,10 @@ uint32_t ResourceTypeInfo::getCBufferSize(const DataLayout &DL) const { Type *ElTy = cast<CBufferExtType>(HandleTy)->getResourceType(); + // TODO: Remove this when we update the frontend to use explicit padding. if (auto *LayoutTy = dyn_cast<LayoutExtType>(ElTy)) return LayoutTy->getSize(); - // TODO: What should we do with unannotated arrays? 
return DL.getTypeAllocSize(ElTy); } diff --git a/llvm/lib/Analysis/LazyValueInfo.cpp b/llvm/lib/Analysis/LazyValueInfo.cpp index 0e5bc48..df75999 100644 --- a/llvm/lib/Analysis/LazyValueInfo.cpp +++ b/llvm/lib/Analysis/LazyValueInfo.cpp @@ -947,9 +947,8 @@ LazyValueInfoImpl::solveBlockValueSelect(SelectInst *SI, BasicBlock *BB) { /*UseBlockValue*/ false)); } - ValueLatticeElement Result = TrueVal; - Result.mergeIn(FalseVal); - return Result; + TrueVal.mergeIn(FalseVal); + return TrueVal; } std::optional<ConstantRange> @@ -1778,9 +1777,8 @@ ValueLatticeElement LazyValueInfoImpl::getValueInBlock(Value *V, BasicBlock *BB, assert(OptResult && "Value not available after solving"); } - ValueLatticeElement Result = *OptResult; - LLVM_DEBUG(dbgs() << " Result = " << Result << "\n"); - return Result; + LLVM_DEBUG(dbgs() << " Result = " << *OptResult << "\n"); + return *OptResult; } ValueLatticeElement LazyValueInfoImpl::getValueAt(Value *V, Instruction *CxtI) { diff --git a/llvm/lib/Analysis/LoopInfo.cpp b/llvm/lib/Analysis/LoopInfo.cpp index a8c3173..d84721b 100644 --- a/llvm/lib/Analysis/LoopInfo.cpp +++ b/llvm/lib/Analysis/LoopInfo.cpp @@ -986,8 +986,8 @@ PreservedAnalyses LoopPrinterPass::run(Function &F, return PreservedAnalyses::all(); } -void llvm::printLoop(Loop &L, raw_ostream &OS, const std::string &Banner) { - +void llvm::printLoop(const Loop &L, raw_ostream &OS, + const std::string &Banner) { if (forcePrintModuleIR()) { // handling -print-module-scope OS << Banner << " (loop: "; diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index a64b93d..425420f 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -1840,19 +1840,19 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty, // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>. // - if (SM->getNumOperands() == 2) - if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0))) - if (MulLHS->getAPInt().isPowerOf2()) - if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) { - int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) - - MulLHS->getAPInt().logBase2(); - Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits); - return getMulExpr( - getZeroExtendExpr(MulLHS, Ty), - getZeroExtendExpr( - getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty), - SCEV::FlagNUW, Depth + 1); - } + const APInt *C; + const SCEV *TruncRHS; + if (match(SM, + m_scev_Mul(m_scev_APInt(C), m_scev_Trunc(m_SCEV(TruncRHS)))) && + C->isPowerOf2()) { + int NewTruncBits = + getTypeSizeInBits(SM->getOperand(1)->getType()) - C->logBase2(); + Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits); + return getMulExpr( + getZeroExtendExpr(SM->getOperand(0), Ty), + getZeroExtendExpr(getTruncateExpr(TruncRHS, NewTruncTy), Ty), + SCEV::FlagNUW, Depth + 1); + } } // zext(umin(x, y)) -> umin(zext(x), zext(y)) @@ -3144,20 +3144,19 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) { if (Ops.size() == 2) { // C1*(C2+V) -> C1*C2 + C1*V - if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) - // If any of Add's ops are Adds or Muls with a constant, apply this - // transformation as well. - // - // TODO: There are some cases where this transformation is not - // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of - // this transformation should be narrowed down. 
- if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add)) { - const SCEV *LHS = getMulExpr(LHSC, Add->getOperand(0), - SCEV::FlagAnyWrap, Depth + 1); - const SCEV *RHS = getMulExpr(LHSC, Add->getOperand(1), - SCEV::FlagAnyWrap, Depth + 1); - return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1); - } + // If any of Add's ops are Adds or Muls with a constant, apply this + // transformation as well. + // + // TODO: There are some cases where this transformation is not + // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of + // this transformation should be narrowed down. + const SCEV *Op0, *Op1; + if (match(Ops[1], m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))) && + containsConstantInAddMulChain(Ops[1])) { + const SCEV *LHS = getMulExpr(LHSC, Op0, SCEV::FlagAnyWrap, Depth + 1); + const SCEV *RHS = getMulExpr(LHSC, Op1, SCEV::FlagAnyWrap, Depth + 1); + return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1); + } if (Ops[0]->isAllOnesValue()) { // If we have a mul by -1 of an add, try distributing the -1 among the @@ -3578,20 +3577,12 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, } // ((-C + (C smax %x)) /u %x) evaluates to zero, for any positive constant C. - if (const auto *AE = dyn_cast<SCEVAddExpr>(LHS); - AE && AE->getNumOperands() == 2) { - if (const auto *VC = dyn_cast<SCEVConstant>(AE->getOperand(0))) { - const APInt &NegC = VC->getAPInt(); - if (NegC.isNegative() && !NegC.isMinSignedValue()) { - const auto *MME = dyn_cast<SCEVSMaxExpr>(AE->getOperand(1)); - if (MME && MME->getNumOperands() == 2 && - isa<SCEVConstant>(MME->getOperand(0)) && - cast<SCEVConstant>(MME->getOperand(0))->getAPInt() == -NegC && - MME->getOperand(1) == RHS) - return getZero(LHS->getType()); - } - } - } + const APInt *NegC, *C; + if (match(LHS, + m_scev_Add(m_scev_APInt(NegC), + m_scev_SMax(m_scev_APInt(C), m_scev_Specific(RHS)))) && + NegC->isNegative() && !NegC->isMinSignedValue() && *C == -*NegC) + return getZero(LHS->getType()); // TODO: Generalize to handle any common factors. 
// udiv (mul nuw a, vscale), (mul nuw b, vscale) --> udiv a, b @@ -4623,17 +4614,11 @@ const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V, /// If Expr computes ~A, return A else return nullptr static const SCEV *MatchNotExpr(const SCEV *Expr) { - const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr); - if (!Add || Add->getNumOperands() != 2 || - !Add->getOperand(0)->isAllOnesValue()) - return nullptr; - - const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1)); - if (!AddRHS || AddRHS->getNumOperands() != 2 || - !AddRHS->getOperand(0)->isAllOnesValue()) - return nullptr; - - return AddRHS->getOperand(1); + const SCEV *MulOp; + if (match(Expr, m_scev_Add(m_scev_AllOnes(), + m_scev_Mul(m_scev_AllOnes(), m_SCEV(MulOp))))) + return MulOp; + return nullptr; } /// Return a SCEV corresponding to ~V = -1-V @@ -10797,19 +10782,15 @@ static bool HasSameValue(const SCEV *A, const SCEV *B) { } static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS) { - const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S); - if (!Add || Add->getNumOperands() != 2) + const SCEV *Op0, *Op1; + if (!match(S, m_scev_Add(m_SCEV(Op0), m_SCEV(Op1)))) return false; - if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(0)); - ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) { - LHS = Add->getOperand(1); - RHS = ME->getOperand(1); + if (match(Op0, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) { + LHS = Op1; return true; } - if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(1)); - ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) { - LHS = Add->getOperand(0); - RHS = ME->getOperand(1); + if (match(Op1, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) { + LHS = Op0; return true; } return false; @@ -12172,13 +12153,10 @@ bool ScalarEvolution::isImpliedCondBalancedTypes( bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr, const SCEV *&L, const SCEV *&R, SCEV::NoWrapFlags &Flags) { - const auto *AE = dyn_cast<SCEVAddExpr>(Expr); - if (!AE || AE->getNumOperands() != 2) + if (!match(Expr, m_scev_Add(m_SCEV(L), m_SCEV(R)))) return false; - L = AE->getOperand(0); - R = AE->getOperand(1); - Flags = AE->getNoWrapFlags(); + Flags = cast<SCEVAddExpr>(Expr)->getNoWrapFlags(); return true; } @@ -12220,12 +12198,11 @@ ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) { // Try to match a common constant multiply. 
auto MatchConstMul = [](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> { - auto *M = dyn_cast<SCEVMulExpr>(S); - if (!M || M->getNumOperands() != 2 || - !isa<SCEVConstant>(M->getOperand(0))) - return std::nullopt; - return { - {M->getOperand(1), cast<SCEVConstant>(M->getOperand(0))->getAPInt()}}; + const APInt *C; + const SCEV *Op; + if (match(S, m_scev_Mul(m_scev_APInt(C), m_SCEV(Op)))) + return {{Op, *C}}; + return std::nullopt; }; if (auto MatchedMore = MatchConstMul(More)) { if (auto MatchedLess = MatchConstMul(Less)) { @@ -15557,19 +15534,10 @@ void ScalarEvolution::LoopGuards::collectFromBlock( auto IsMinMaxSCEVWithNonNegativeConstant = [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS, const SCEV *&RHS) { - if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) { - if (MinMax->getNumOperands() != 2) - return false; - if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) { - if (C->getAPInt().isNegative()) - return false; - SCTy = MinMax->getSCEVType(); - LHS = MinMax->getOperand(0); - RHS = MinMax->getOperand(1); - return true; - } - } - return false; + const APInt *C; + SCTy = Expr->getSCEVType(); + return match(Expr, m_scev_MinMax(m_SCEV(LHS), m_SCEV(RHS))) && + match(LHS, m_scev_APInt(C)) && C->isNonNegative(); }; // Return a new SCEV that modifies \p Expr to the closest number divides by @@ -15772,19 +15740,26 @@ void ScalarEvolution::LoopGuards::collectFromBlock( GetNextSCEVDividesByDivisor(One, DividesBy); To = SE.getUMaxExpr(FromRewritten, OneAlignedUp); } else { + // LHS != RHS can be rewritten as (LHS - RHS) = UMax(1, LHS - RHS), + // but creating the subtraction eagerly is expensive. Track the + // inequalities in a separate map, and materialize the rewrite lazily + // when encountering a suitable subtraction while re-writing. if (LHS->getType()->isPointerTy()) { LHS = SE.getLosslessPtrToIntExpr(LHS); RHS = SE.getLosslessPtrToIntExpr(RHS); if (isa<SCEVCouldNotCompute>(LHS) || isa<SCEVCouldNotCompute>(RHS)) break; } - auto AddSubRewrite = [&](const SCEV *A, const SCEV *B) { - const SCEV *Sub = SE.getMinusSCEV(A, B); - AddRewrite(Sub, Sub, - SE.getUMaxExpr(Sub, SE.getOne(From->getType()))); - }; - AddSubRewrite(LHS, RHS); - AddSubRewrite(RHS, LHS); + const SCEVConstant *C; + const SCEV *A, *B; + if (match(RHS, m_scev_Add(m_SCEVConstant(C), m_SCEV(A))) && + match(LHS, m_scev_Add(m_scev_Specific(C), m_SCEV(B)))) { + RHS = A; + LHS = B; + } + if (LHS > RHS) + std::swap(LHS, RHS); + Guards.NotEqual.insert({LHS, RHS}); continue; } break; @@ -15918,13 +15893,15 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const { class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> { const DenseMap<const SCEV *, const SCEV *> ⤅ + const SmallDenseSet<std::pair<const SCEV *, const SCEV *>> ≠ SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap; public: SCEVLoopGuardRewriter(ScalarEvolution &SE, const ScalarEvolution::LoopGuards &Guards) - : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap) { + : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap), + NotEqual(Guards.NotEqual) { if (Guards.PreserveNUW) FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW); if (Guards.PreserveNSW) @@ -15979,14 +15956,36 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const { } const SCEV *visitAddExpr(const SCEVAddExpr *Expr) { + // Helper to check if S is a subtraction (A - B) where A != B, and if so, + // return UMax(S, 1). 
+ auto RewriteSubtraction = [&](const SCEV *S) -> const SCEV * { + const SCEV *LHS, *RHS; + if (MatchBinarySub(S, LHS, RHS)) { + if (LHS > RHS) + std::swap(LHS, RHS); + if (NotEqual.contains({LHS, RHS})) + return SE.getUMaxExpr(S, SE.getOne(S->getType())); + } + return nullptr; + }; + + // Check if Expr itself is a subtraction pattern with guard info. + if (const SCEV *Rewritten = RewriteSubtraction(Expr)) + return Rewritten; + // Trip count expressions sometimes consist of adding 3 operands, i.e. // (Const + A + B). There may be guard info for A + B, and if so, apply // it. // TODO: Could more generally apply guards to Add sub-expressions. if (isa<SCEVConstant>(Expr->getOperand(0)) && Expr->getNumOperands() == 3) { - if (const SCEV *S = Map.lookup( - SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2)))) + const SCEV *Add = + SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2)); + if (const SCEV *Rewritten = RewriteSubtraction(Add)) + return SE.getAddExpr( + Expr->getOperand(0), Rewritten, + ScalarEvolution::maskFlags(Expr->getNoWrapFlags(), FlagMask)); + if (const SCEV *S = Map.lookup(Add)) return SE.getAddExpr(Expr->getOperand(0), S); } SmallVector<const SCEV *, 2> Operands; @@ -16021,7 +16020,7 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const { } }; - if (RewriteMap.empty()) + if (RewriteMap.empty() && NotEqual.empty()) return Expr; SCEVLoopGuardRewriter Rewriter(SE, *this); diff --git a/llvm/lib/CodeGen/AsmPrinter/DwarfDebug.cpp b/llvm/lib/CodeGen/AsmPrinter/DwarfDebug.cpp index 433877f..567acf7 100644 --- a/llvm/lib/CodeGen/AsmPrinter/DwarfDebug.cpp +++ b/llvm/lib/CodeGen/AsmPrinter/DwarfDebug.cpp @@ -1039,12 +1039,17 @@ void DwarfDebug::finishUnitAttributes(const DICompileUnit *DIUnit, } else NewCU.addString(Die, dwarf::DW_AT_producer, Producer); - if (auto Lang = DIUnit->getSourceLanguage(); Lang.hasVersionedName()) + if (auto Lang = DIUnit->getSourceLanguage(); Lang.hasVersionedName()) { NewCU.addUInt(Die, dwarf::DW_AT_language_name, dwarf::DW_FORM_data2, Lang.getName()); - else + + if (uint32_t LangVersion = Lang.getVersion(); LangVersion != 0) + NewCU.addUInt(Die, dwarf::DW_AT_language_version, /*Form=*/std::nullopt, + LangVersion); + } else { NewCU.addUInt(Die, dwarf::DW_AT_language, dwarf::DW_FORM_data2, Lang.getName()); + } NewCU.addString(Die, dwarf::DW_AT_name, FN); StringRef SysRoot = DIUnit->getSysRoot(); @@ -2066,11 +2071,36 @@ void DwarfDebug::beginInstruction(const MachineInstr *MI) { if (NoDebug) return; + auto RecordLineZero = [&]() { + // Preserve the file and column numbers, if we can, to save space in + // the encoded line table. + // Do not update PrevInstLoc, it remembers the last non-0 line. + const MDNode *Scope = nullptr; + unsigned Column = 0; + if (PrevInstLoc) { + Scope = PrevInstLoc.getScope(); + Column = PrevInstLoc.getCol(); + } + recordSourceLine(/*Line=*/0, Column, Scope, /*Flags=*/0); + }; + + // When we emit a line-0 record, we don't update PrevInstLoc; so look at + // the last line number actually emitted, to see if it was line 0. + unsigned LastAsmLine = + Asm->OutStreamer->getContext().getCurrentDwarfLoc().getLine(); + // Check if source location changes, but ignore DBG_VALUE and CFI locations. // If the instruction is part of the function frame setup code, do not emit // any line record, as there is no correspondence with any user code. 
- if (MI->isMetaInstruction() || MI->getFlag(MachineInstr::FrameSetup)) + if (MI->isMetaInstruction()) + return; + if (MI->getFlag(MachineInstr::FrameSetup)) { + // Prevent a loc from the previous block leaking into frame setup instrs. + if (LastAsmLine && PrevInstBB && PrevInstBB != MI->getParent()) + RecordLineZero(); return; + } + const DebugLoc &DL = MI->getDebugLoc(); unsigned Flags = 0; @@ -2093,11 +2123,6 @@ void DwarfDebug::beginInstruction(const MachineInstr *MI) { LocationString); }; - // When we emit a line-0 record, we don't update PrevInstLoc; so look at - // the last line number actually emitted, to see if it was line 0. - unsigned LastAsmLine = - Asm->OutStreamer->getContext().getCurrentDwarfLoc().getLine(); - // There may be a mixture of scopes using and not using Key Instructions. // Not-Key-Instructions functions inlined into Key Instructions functions // should use not-key is_stmt handling. Key Instructions functions inlined @@ -2163,18 +2188,8 @@ void DwarfDebug::beginInstruction(const MachineInstr *MI) { // - Instruction is at the top of a block; we don't want to inherit the // location from the physically previous (maybe unrelated) block. if (UnknownLocations == Enable || PrevLabel || - (PrevInstBB && PrevInstBB != MI->getParent())) { - // Preserve the file and column numbers, if we can, to save space in - // the encoded line table. - // Do not update PrevInstLoc, it remembers the last non-0 line. - const MDNode *Scope = nullptr; - unsigned Column = 0; - if (PrevInstLoc) { - Scope = PrevInstLoc.getScope(); - Column = PrevInstLoc.getCol(); - } - recordSourceLine(/*Line=*/0, Column, Scope, /*Flags=*/0); - } + (PrevInstBB && PrevInstBB != MI->getParent())) + RecordLineZero(); return; } diff --git a/llvm/lib/IR/AutoUpgrade.cpp b/llvm/lib/IR/AutoUpgrade.cpp index f28b989..d8374b6 100644 --- a/llvm/lib/IR/AutoUpgrade.cpp +++ b/llvm/lib/IR/AutoUpgrade.cpp @@ -6041,8 +6041,7 @@ std::string llvm::UpgradeDataLayoutString(StringRef DL, StringRef TT) { Triple T(TT); // The only data layout upgrades needed for pre-GCN, SPIR or SPIRV are setting // the address space of globals to 1. This does not apply to SPIRV Logical. - if (((T.isAMDGPU() && !T.isAMDGCN()) || - (T.isSPIR() || (T.isSPIRV() && !T.isSPIRVLogical()))) && + if ((T.isSPIR() || (T.isSPIRV() && !T.isSPIRVLogical())) && !DL.contains("-G") && !DL.starts_with("G")) { return DL.empty() ? std::string("G1") : (DL + "-G1").str(); } @@ -6055,35 +6054,43 @@ std::string llvm::UpgradeDataLayoutString(StringRef DL, StringRef TT) { return DL.str(); } + // AMDGPU data layout upgrades. std::string Res = DL.str(); - // AMDGCN data layout upgrades. - if (T.isAMDGCN()) { + if (T.isAMDGPU()) { // Define address spaces for constants. if (!DL.contains("-G") && !DL.starts_with("G")) Res.append(Res.empty() ? "G1" : "-G1"); - // Add missing non-integral declarations. - // This goes before adding new address spaces to prevent incoherent string - // values. - if (!DL.contains("-ni") && !DL.starts_with("ni")) - Res.append("-ni:7:8:9"); - // Update ni:7 to ni:7:8:9. - if (DL.ends_with("ni:7")) - Res.append(":8:9"); - if (DL.ends_with("ni:7:8")) - Res.append(":9"); - - // Add sizing for address spaces 7 and 8 (fat raw buffers and buffer - // resources) An empty data layout has already been upgraded to G1 by now. 
- if (!DL.contains("-p7") && !DL.starts_with("p7")) - Res.append("-p7:160:256:256:32"); - if (!DL.contains("-p8") && !DL.starts_with("p8")) - Res.append("-p8:128:128:128:48"); - constexpr StringRef OldP8("-p8:128:128-"); - if (DL.contains(OldP8)) - Res.replace(Res.find(OldP8), OldP8.size(), "-p8:128:128:128:48-"); - if (!DL.contains("-p9") && !DL.starts_with("p9")) - Res.append("-p9:192:256:256:32"); + // AMDGCN data layout upgrades. + if (T.isAMDGCN()) { + + // Add missing non-integral declarations. + // This goes before adding new address spaces to prevent incoherent string + // values. + if (!DL.contains("-ni") && !DL.starts_with("ni")) + Res.append("-ni:7:8:9"); + // Update ni:7 to ni:7:8:9. + if (DL.ends_with("ni:7")) + Res.append(":8:9"); + if (DL.ends_with("ni:7:8")) + Res.append(":9"); + + // Add sizing for address spaces 7 and 8 (fat raw buffers and buffer + // resources) An empty data layout has already been upgraded to G1 by now. + if (!DL.contains("-p7") && !DL.starts_with("p7")) + Res.append("-p7:160:256:256:32"); + if (!DL.contains("-p8") && !DL.starts_with("p8")) + Res.append("-p8:128:128:128:48"); + constexpr StringRef OldP8("-p8:128:128-"); + if (DL.contains(OldP8)) + Res.replace(Res.find(OldP8), OldP8.size(), "-p8:128:128:128:48-"); + if (!DL.contains("-p9") && !DL.starts_with("p9")) + Res.append("-p9:192:256:256:32"); + } + + // Upgrade the ELF mangling mode. + if (!DL.contains("m:e")) + Res = Res.empty() ? "m:e" : "m:e-" + Res; return Res; } diff --git a/llvm/lib/IR/DebugInfo.cpp b/llvm/lib/IR/DebugInfo.cpp index 9601a8a..5883606 100644 --- a/llvm/lib/IR/DebugInfo.cpp +++ b/llvm/lib/IR/DebugInfo.cpp @@ -294,9 +294,9 @@ void DebugInfoFinder::processSubprogram(DISubprogram *SP) { // just DISubprogram's, referenced from anywhere within the Function being // cloned prior to calling MapMetadata / RemapInstruction to avoid their // duplication later as DICompileUnit's are also directly referenced by - // llvm.dbg.cu list. Thefore we need to collect DICompileUnit's here as well. - // Also, DICompileUnit's may reference DISubprogram's too and therefore need - // to be at least looked through. + // llvm.dbg.cu list. Therefore we need to collect DICompileUnit's here as + // well. Also, DICompileUnit's may reference DISubprogram's too and therefore + // need to be at least looked through. processCompileUnit(SP->getUnit()); processType(SP->getType()); for (auto *Element : SP->getTemplateParams()) { @@ -377,7 +377,7 @@ bool DebugInfoFinder::addScope(DIScope *Scope) { /// Recursively handle DILocations in followup metadata etc. /// -/// TODO: If for example a followup loop metadata would refence itself this +/// TODO: If for example a followup loop metadata would reference itself this /// function would go into infinite recursion. We do not expect such cycles in /// the loop metadata (except for the self-referencing first element /// "LoopID"). However, we could at least handle such situations more gracefully @@ -679,7 +679,7 @@ private: auto Variables = nullptr; auto TemplateParams = nullptr; - // Make a distinct DISubprogram, for situations that warrent it. + // Make a distinct DISubprogram, for situations that warrant it. 
auto distinctMDSubprogram = [&]() { return DISubprogram::getDistinct( MDS->getContext(), FileAndScope, MDS->getName(), LinkageName, @@ -1095,6 +1095,35 @@ LLVMDIBuilderCreateFile(LLVMDIBuilderRef Builder, const char *Filename, StringRef(Directory, DirectoryLen))); } +static llvm::DIFile::ChecksumKind +map_from_llvmChecksumKind(LLVMChecksumKind CSKind) { + switch (CSKind) { + case LLVMChecksumKind::CSK_MD5: + return llvm::DIFile::CSK_MD5; + case LLVMChecksumKind::CSK_SHA1: + return llvm::DIFile::CSK_SHA1; + case LLVMChecksumKind::CSK_SHA256: + return llvm::DIFile::CSK_SHA256; + } + llvm_unreachable("Unhandled Checksum Kind"); +} + +LLVMMetadataRef LLVMDIBuilderCreateFileWithChecksum( + LLVMDIBuilderRef Builder, const char *Filename, size_t FilenameLen, + const char *Directory, size_t DirectoryLen, LLVMChecksumKind ChecksumKind, + const char *Checksum, size_t ChecksumLen, const char *Source, + size_t SourceLen) { + StringRef ChkSum = StringRef(Checksum, ChecksumLen); + auto CSK = map_from_llvmChecksumKind(ChecksumKind); + llvm::DIFile::ChecksumInfo<StringRef> CSInfo(CSK, ChkSum); + std::optional<StringRef> Src; + if (SourceLen > 0) + Src = StringRef(Source, SourceLen); + return wrap(unwrap(Builder)->createFile(StringRef(Filename, FilenameLen), + StringRef(Directory, DirectoryLen), + CSInfo, Src)); +} + LLVMMetadataRef LLVMDIBuilderCreateModule(LLVMDIBuilderRef Builder, LLVMMetadataRef ParentScope, const char *Name, size_t NameLen, @@ -2014,7 +2043,7 @@ void at::remapAssignID(DenseMap<DIAssignID *, DIAssignID *> &Map, I.setMetadata(LLVMContext::MD_DIAssignID, GetNewID(ID)); } -/// Collect constant properies (base, size, offset) of \p StoreDest. +/// Collect constant properties (base, size, offset) of \p StoreDest. /// Return std::nullopt if any properties are not constants or the /// offset from the base pointer is negative. static std::optional<AssignmentInfo> @@ -2300,7 +2329,7 @@ PreservedAnalyses AssignmentTrackingPass::run(Function &F, return PreservedAnalyses::all(); // Record that this module uses assignment tracking. It doesn't matter that - // some functons in the module may not use it - the debug info in those + // some functions in the module may not use it - the debug info in those // functions will still be handled properly. 
setAssignmentTrackingModuleFlag(*F.getParent()); diff --git a/llvm/lib/IR/Type.cpp b/llvm/lib/IR/Type.cpp index 9db48e8..0e9535d 100644 --- a/llvm/lib/IR/Type.cpp +++ b/llvm/lib/IR/Type.cpp @@ -1034,6 +1034,10 @@ static TargetTypeInfo getTargetTypeInfo(const TargetExtType *Ty) { } // DirectX resources + if (Name == "dx.Padding") + return TargetTypeInfo( + ArrayType::get(Type::getInt8Ty(C), Ty->getIntParameter(0)), + TargetExtType::CanBeGlobal); if (Name.starts_with("dx.")) return TargetTypeInfo(PointerType::get(C, 0), TargetExtType::CanBeGlobal, TargetExtType::CanBeLocal, diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp index c79a950..3572852 100644 --- a/llvm/lib/IR/Verifier.cpp +++ b/llvm/lib/IR/Verifier.cpp @@ -6479,9 +6479,12 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) { NumRows->getZExtValue() * NumColumns->getZExtValue(), "Result of a matrix operation does not fit in the returned vector!"); - if (Stride) + if (Stride) { + Check(Stride->getBitWidth() <= 64, "Stride bitwidth cannot exceed 64!", + IF); Check(Stride->getZExtValue() >= NumRows->getZExtValue(), "Stride must be greater or equal than the number of rows!", IF); + } break; } diff --git a/llvm/lib/LTO/LTO.cpp b/llvm/lib/LTO/LTO.cpp index e6544f3..aec8891 100644 --- a/llvm/lib/LTO/LTO.cpp +++ b/llvm/lib/LTO/LTO.cpp @@ -1257,38 +1257,6 @@ Error LTO::run(AddStreamFn AddStream, FileCache Cache) { return Result; } -void lto::updateMemProfAttributes(Module &Mod, - const ModuleSummaryIndex &Index) { - llvm::TimeTraceScope timeScope("LTO update memprof attributes"); - if (Index.withSupportsHotColdNew()) - return; - - // The profile matcher applies hotness attributes directly for allocations, - // and those will cause us to generate calls to the hot/cold interfaces - // unconditionally. If supports-hot-cold-new was not enabled in the LTO - // link then assume we don't want these calls (e.g. not linking with - // the appropriate library, or otherwise trying to disable this behavior). - for (auto &F : Mod) { - for (auto &BB : F) { - for (auto &I : BB) { - auto *CI = dyn_cast<CallBase>(&I); - if (!CI) - continue; - if (CI->hasFnAttr("memprof")) - CI->removeFnAttr("memprof"); - // Strip off all memprof metadata as it is no longer needed. - // Importantly, this avoids the addition of new memprof attributes - // after inlining propagation. - // TODO: If we support additional types of MemProf metadata beyond hot - // and cold, we will need to update the metadata based on the allocator - // APIs supported instead of completely stripping all. - CI->setMetadata(LLVMContext::MD_memprof, nullptr); - CI->setMetadata(LLVMContext::MD_callsite, nullptr); - } - } - } -} - Error LTO::runRegularLTO(AddStreamFn AddStream) { llvm::TimeTraceScope timeScope("Run regular LTO"); LLVMContext &CombinedCtx = RegularLTO.CombinedModule->getContext(); @@ -1346,8 +1314,6 @@ Error LTO::runRegularLTO(AddStreamFn AddStream) { } } - updateMemProfAttributes(*RegularLTO.CombinedModule, ThinLTO.CombinedIndex); - bool WholeProgramVisibilityEnabledInLTO = Conf.HasWholeProgramVisibility && // If validation is enabled, upgrade visibility only when all vtables diff --git a/llvm/lib/LTO/LTOBackend.cpp b/llvm/lib/LTO/LTOBackend.cpp index 11a7b32..280c3d1 100644 --- a/llvm/lib/LTO/LTOBackend.cpp +++ b/llvm/lib/LTO/LTOBackend.cpp @@ -726,7 +726,6 @@ Error lto::thinBackend(const Config &Conf, unsigned Task, AddStreamFn AddStream, } // Do this after any importing so that imported code is updated. 
- updateMemProfAttributes(Mod, CombinedIndex); updatePublicTypeTestCalls(Mod, CombinedIndex.withWholeProgramVisibility()); if (Conf.PostImportModuleHook && !Conf.PostImportModuleHook(Task, Mod)) diff --git a/llvm/lib/ObjectYAML/ELFYAML.cpp b/llvm/lib/ObjectYAML/ELFYAML.cpp index 421d6603..c3a27c9 100644 --- a/llvm/lib/ObjectYAML/ELFYAML.cpp +++ b/llvm/lib/ObjectYAML/ELFYAML.cpp @@ -488,6 +488,7 @@ void ScalarBitSetTraits<ELFYAML::ELF_EF>::bitset(IO &IO, BCaseMask(EF_HEXAGON_MACH_V5, EF_HEXAGON_MACH); BCaseMask(EF_HEXAGON_MACH_V55, EF_HEXAGON_MACH); BCaseMask(EF_HEXAGON_MACH_V60, EF_HEXAGON_MACH); + BCaseMask(EF_HEXAGON_MACH_V61, EF_HEXAGON_MACH); BCaseMask(EF_HEXAGON_MACH_V62, EF_HEXAGON_MACH); BCaseMask(EF_HEXAGON_MACH_V65, EF_HEXAGON_MACH); BCaseMask(EF_HEXAGON_MACH_V66, EF_HEXAGON_MACH); @@ -499,12 +500,21 @@ void ScalarBitSetTraits<ELFYAML::ELF_EF>::bitset(IO &IO, BCaseMask(EF_HEXAGON_MACH_V71T, EF_HEXAGON_MACH); BCaseMask(EF_HEXAGON_MACH_V73, EF_HEXAGON_MACH); BCaseMask(EF_HEXAGON_MACH_V75, EF_HEXAGON_MACH); + BCaseMask(EF_HEXAGON_MACH_V77, EF_HEXAGON_MACH); + BCaseMask(EF_HEXAGON_MACH_V79, EF_HEXAGON_MACH); + BCaseMask(EF_HEXAGON_MACH_V81, EF_HEXAGON_MACH); + BCaseMask(EF_HEXAGON_MACH_V83, EF_HEXAGON_MACH); + BCaseMask(EF_HEXAGON_MACH_V85, EF_HEXAGON_MACH); + BCaseMask(EF_HEXAGON_MACH_V87, EF_HEXAGON_MACH); + BCaseMask(EF_HEXAGON_MACH_V89, EF_HEXAGON_MACH); + BCaseMask(EF_HEXAGON_MACH_V91, EF_HEXAGON_MACH); BCaseMask(EF_HEXAGON_ISA_V2, EF_HEXAGON_ISA); BCaseMask(EF_HEXAGON_ISA_V3, EF_HEXAGON_ISA); BCaseMask(EF_HEXAGON_ISA_V4, EF_HEXAGON_ISA); BCaseMask(EF_HEXAGON_ISA_V5, EF_HEXAGON_ISA); BCaseMask(EF_HEXAGON_ISA_V55, EF_HEXAGON_ISA); BCaseMask(EF_HEXAGON_ISA_V60, EF_HEXAGON_ISA); + BCaseMask(EF_HEXAGON_ISA_V61, EF_HEXAGON_ISA); BCaseMask(EF_HEXAGON_ISA_V62, EF_HEXAGON_ISA); BCaseMask(EF_HEXAGON_ISA_V65, EF_HEXAGON_ISA); BCaseMask(EF_HEXAGON_ISA_V66, EF_HEXAGON_ISA); @@ -514,6 +524,14 @@ void ScalarBitSetTraits<ELFYAML::ELF_EF>::bitset(IO &IO, BCaseMask(EF_HEXAGON_ISA_V71, EF_HEXAGON_ISA); BCaseMask(EF_HEXAGON_ISA_V73, EF_HEXAGON_ISA); BCaseMask(EF_HEXAGON_ISA_V75, EF_HEXAGON_ISA); + BCaseMask(EF_HEXAGON_ISA_V77, EF_HEXAGON_ISA); + BCaseMask(EF_HEXAGON_ISA_V79, EF_HEXAGON_ISA); + BCaseMask(EF_HEXAGON_ISA_V81, EF_HEXAGON_ISA); + BCaseMask(EF_HEXAGON_ISA_V83, EF_HEXAGON_ISA); + BCaseMask(EF_HEXAGON_ISA_V85, EF_HEXAGON_ISA); + BCaseMask(EF_HEXAGON_ISA_V87, EF_HEXAGON_ISA); + BCaseMask(EF_HEXAGON_ISA_V89, EF_HEXAGON_ISA); + BCaseMask(EF_HEXAGON_ISA_V91, EF_HEXAGON_ISA); break; case ELF::EM_AVR: BCaseMask(EF_AVR_ARCH_AVR1, EF_AVR_ARCH_MASK); diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp index 53cf004..e45cac8 100644 --- a/llvm/lib/Passes/PassBuilder.cpp +++ b/llvm/lib/Passes/PassBuilder.cpp @@ -2027,13 +2027,13 @@ Error PassBuilder::parseModulePass(ModulePassManager &MPM, #define LOOPNEST_PASS(NAME, CREATE_PASS) \ if (Name == NAME) { \ MPM.addPass(createModuleToFunctionPassAdaptor( \ - createFunctionToLoopPassAdaptor(CREATE_PASS, false, false))); \ + createFunctionToLoopPassAdaptor(CREATE_PASS, false))); \ return Error::success(); \ } #define LOOP_PASS(NAME, CREATE_PASS) \ if (Name == NAME) { \ MPM.addPass(createModuleToFunctionPassAdaptor( \ - createFunctionToLoopPassAdaptor(CREATE_PASS, false, false))); \ + createFunctionToLoopPassAdaptor(CREATE_PASS, false))); \ return Error::success(); \ } #define LOOP_PASS_WITH_PARAMS(NAME, CLASS, CREATE_PASS, PARSER, PARAMS) \ @@ -2041,9 +2041,8 @@ Error PassBuilder::parseModulePass(ModulePassManager &MPM, auto Params 
= parsePassParameters(PARSER, Name, NAME); \ if (!Params) \ return Params.takeError(); \ - MPM.addPass( \ - createModuleToFunctionPassAdaptor(createFunctionToLoopPassAdaptor( \ - CREATE_PASS(Params.get()), false, false))); \ + MPM.addPass(createModuleToFunctionPassAdaptor( \ + createFunctionToLoopPassAdaptor(CREATE_PASS(Params.get()), false))); \ return Error::success(); \ } #include "PassRegistry.def" @@ -2142,13 +2141,13 @@ Error PassBuilder::parseCGSCCPass(CGSCCPassManager &CGPM, #define LOOPNEST_PASS(NAME, CREATE_PASS) \ if (Name == NAME) { \ CGPM.addPass(createCGSCCToFunctionPassAdaptor( \ - createFunctionToLoopPassAdaptor(CREATE_PASS, false, false))); \ + createFunctionToLoopPassAdaptor(CREATE_PASS, false))); \ return Error::success(); \ } #define LOOP_PASS(NAME, CREATE_PASS) \ if (Name == NAME) { \ CGPM.addPass(createCGSCCToFunctionPassAdaptor( \ - createFunctionToLoopPassAdaptor(CREATE_PASS, false, false))); \ + createFunctionToLoopPassAdaptor(CREATE_PASS, false))); \ return Error::success(); \ } #define LOOP_PASS_WITH_PARAMS(NAME, CLASS, CREATE_PASS, PARSER, PARAMS) \ @@ -2156,9 +2155,8 @@ Error PassBuilder::parseCGSCCPass(CGSCCPassManager &CGPM, auto Params = parsePassParameters(PARSER, Name, NAME); \ if (!Params) \ return Params.takeError(); \ - CGPM.addPass( \ - createCGSCCToFunctionPassAdaptor(createFunctionToLoopPassAdaptor( \ - CREATE_PASS(Params.get()), false, false))); \ + CGPM.addPass(createCGSCCToFunctionPassAdaptor( \ + createFunctionToLoopPassAdaptor(CREATE_PASS(Params.get()), false))); \ return Error::success(); \ } #include "PassRegistry.def" @@ -2191,11 +2189,8 @@ Error PassBuilder::parseFunctionPass(FunctionPassManager &FPM, return Err; // Add the nested pass manager with the appropriate adaptor. bool UseMemorySSA = (Name == "loop-mssa"); - bool UseBFI = llvm::any_of(InnerPipeline, [](auto Pipeline) { - return Pipeline.Name.contains("simple-loop-unswitch"); - }); - FPM.addPass(createFunctionToLoopPassAdaptor(std::move(LPM), UseMemorySSA, - UseBFI)); + FPM.addPass( + createFunctionToLoopPassAdaptor(std::move(LPM), UseMemorySSA)); return Error::success(); } if (Name == "machine-function") { @@ -2248,12 +2243,12 @@ Error PassBuilder::parseFunctionPass(FunctionPassManager &FPM, // The risk is that it may become obsolete if we're not careful. 
#define LOOPNEST_PASS(NAME, CREATE_PASS) \ if (Name == NAME) { \ - FPM.addPass(createFunctionToLoopPassAdaptor(CREATE_PASS, false, false)); \ + FPM.addPass(createFunctionToLoopPassAdaptor(CREATE_PASS, false)); \ return Error::success(); \ } #define LOOP_PASS(NAME, CREATE_PASS) \ if (Name == NAME) { \ - FPM.addPass(createFunctionToLoopPassAdaptor(CREATE_PASS, false, false)); \ + FPM.addPass(createFunctionToLoopPassAdaptor(CREATE_PASS, false)); \ return Error::success(); \ } #define LOOP_PASS_WITH_PARAMS(NAME, CLASS, CREATE_PASS, PARSER, PARAMS) \ @@ -2261,8 +2256,8 @@ Error PassBuilder::parseFunctionPass(FunctionPassManager &FPM, auto Params = parsePassParameters(PARSER, Name, NAME); \ if (!Params) \ return Params.takeError(); \ - FPM.addPass(createFunctionToLoopPassAdaptor(CREATE_PASS(Params.get()), \ - false, false)); \ + FPM.addPass( \ + createFunctionToLoopPassAdaptor(CREATE_PASS(Params.get()), false)); \ return Error::success(); \ } #include "PassRegistry.def" diff --git a/llvm/lib/Passes/PassBuilderPipelines.cpp b/llvm/lib/Passes/PassBuilderPipelines.cpp index fea0d25..bd03ac0 100644 --- a/llvm/lib/Passes/PassBuilderPipelines.cpp +++ b/llvm/lib/Passes/PassBuilderPipelines.cpp @@ -519,16 +519,14 @@ PassBuilder::buildO1FunctionSimplificationPipeline(OptimizationLevel Level, invokeLoopOptimizerEndEPCallbacks(LPM2, Level); FPM.addPass(createFunctionToLoopPassAdaptor(std::move(LPM1), - /*UseMemorySSA=*/true, - /*UseBlockFrequencyInfo=*/true)); + /*UseMemorySSA=*/true)); FPM.addPass( SimplifyCFGPass(SimplifyCFGOptions().convertSwitchRangeToICmp(true))); FPM.addPass(InstCombinePass()); // The loop passes in LPM2 (LoopFullUnrollPass) do not preserve MemorySSA. // *All* loop passes must preserve it, in order to be able to use it. FPM.addPass(createFunctionToLoopPassAdaptor(std::move(LPM2), - /*UseMemorySSA=*/false, - /*UseBlockFrequencyInfo=*/false)); + /*UseMemorySSA=*/false)); // Delete small array after loop unroll. FPM.addPass(SROAPass(SROAOptions::ModifyCFG)); @@ -710,8 +708,7 @@ PassBuilder::buildFunctionSimplificationPipeline(OptimizationLevel Level, invokeLoopOptimizerEndEPCallbacks(LPM2, Level); FPM.addPass(createFunctionToLoopPassAdaptor(std::move(LPM1), - /*UseMemorySSA=*/true, - /*UseBlockFrequencyInfo=*/true)); + /*UseMemorySSA=*/true)); FPM.addPass( SimplifyCFGPass(SimplifyCFGOptions().convertSwitchRangeToICmp(true))); FPM.addPass(InstCombinePass()); @@ -719,8 +716,7 @@ PassBuilder::buildFunctionSimplificationPipeline(OptimizationLevel Level, // LoopDeletionPass and LoopFullUnrollPass) do not preserve MemorySSA. // *All* loop passes must preserve it, in order to be able to use it. FPM.addPass(createFunctionToLoopPassAdaptor(std::move(LPM2), - /*UseMemorySSA=*/false, - /*UseBlockFrequencyInfo=*/false)); + /*UseMemorySSA=*/false)); // Delete small array after loop unroll. 
FPM.addPass(SROAPass(SROAOptions::ModifyCFG)); @@ -773,7 +769,7 @@ PassBuilder::buildFunctionSimplificationPipeline(OptimizationLevel Level, FPM.addPass(createFunctionToLoopPassAdaptor( LICMPass(PTO.LicmMssaOptCap, PTO.LicmMssaNoAccForPromotionCap, /*AllowSpeculation=*/true), - /*UseMemorySSA=*/true, /*UseBlockFrequencyInfo=*/false)); + /*UseMemorySSA=*/true)); FPM.addPass(CoroElidePass()); @@ -842,8 +838,7 @@ void PassBuilder::addPostPGOLoopRotation(ModulePassManager &MPM, createFunctionToLoopPassAdaptor( LoopRotatePass(EnableLoopHeaderDuplication || Level != OptimizationLevel::Oz), - /*UseMemorySSA=*/false, - /*UseBlockFrequencyInfo=*/false), + /*UseMemorySSA=*/false), PTO.EagerlyInvalidateAnalyses)); } } @@ -1358,8 +1353,7 @@ void PassBuilder::addVectorPasses(OptimizationLevel Level, LPM.addPass(SimpleLoopUnswitchPass(/* NonTrivial */ Level == OptimizationLevel::O3)); ExtraPasses.addPass( - createFunctionToLoopPassAdaptor(std::move(LPM), /*UseMemorySSA=*/true, - /*UseBlockFrequencyInfo=*/true)); + createFunctionToLoopPassAdaptor(std::move(LPM), /*UseMemorySSA=*/true)); ExtraPasses.addPass( SimplifyCFGPass(SimplifyCFGOptions().convertSwitchRangeToICmp(true))); ExtraPasses.addPass(InstCombinePass()); @@ -1438,7 +1432,7 @@ void PassBuilder::addVectorPasses(OptimizationLevel Level, FPM.addPass(createFunctionToLoopPassAdaptor( LICMPass(PTO.LicmMssaOptCap, PTO.LicmMssaNoAccForPromotionCap, /*AllowSpeculation=*/true), - /*UseMemorySSA=*/true, /*UseBlockFrequencyInfo=*/false)); + /*UseMemorySSA=*/true)); // Now that we've vectorized and unrolled loops, we may have more refined // alignment information, try to re-derive it here. @@ -1520,7 +1514,7 @@ PassBuilder::buildModuleOptimizationPipeline(OptimizationLevel Level, OptimizePM.addPass(createFunctionToLoopPassAdaptor( LICMPass(PTO.LicmMssaOptCap, PTO.LicmMssaNoAccForPromotionCap, /*AllowSpeculation=*/true), - /*USeMemorySSA=*/true, /*UseBlockFrequencyInfo=*/false)); + /*USeMemorySSA=*/true)); } OptimizePM.addPass(Float2IntPass()); @@ -1560,8 +1554,8 @@ PassBuilder::buildModuleOptimizationPipeline(OptimizationLevel Level, if (PTO.LoopInterchange) LPM.addPass(LoopInterchangePass()); - OptimizePM.addPass(createFunctionToLoopPassAdaptor( - std::move(LPM), /*UseMemorySSA=*/false, /*UseBlockFrequencyInfo=*/false)); + OptimizePM.addPass( + createFunctionToLoopPassAdaptor(std::move(LPM), /*UseMemorySSA=*/false)); // FIXME: This may not be the right place in the pipeline. // We need to have the data to support the right place. @@ -1658,6 +1652,16 @@ PassBuilder::buildPerModuleDefaultPipeline(OptimizationLevel Level, ModulePassManager MPM; + // Currently this pipeline is only invoked in an LTO pre link pass or when we + // are not running LTO. If that changes the below checks may need updating. + assert(isLTOPreLink(Phase) || Phase == ThinOrFullLTOPhase::None); + + // If we are invoking this in non-LTO mode, remove any MemProf related + // attributes and metadata, as we don't know whether we are linking with + // a library containing the necessary interfaces. + if (Phase == ThinOrFullLTOPhase::None) + MPM.addPass(MemProfRemoveInfo()); + // Convert @llvm.global.annotations to !annotation metadata. 
MPM.addPass(Annotation2MetadataPass()); @@ -1803,6 +1807,12 @@ ModulePassManager PassBuilder::buildThinLTODefaultPipeline( OptimizationLevel Level, const ModuleSummaryIndex *ImportSummary) { ModulePassManager MPM; + // If we are invoking this without a summary index noting that we are linking + // with a library containing the necessary APIs, remove any MemProf related + // attributes and metadata. + if (!ImportSummary || !ImportSummary->withSupportsHotColdNew()) + MPM.addPass(MemProfRemoveInfo()); + if (ImportSummary) { // For ThinLTO we must apply the context disambiguation decisions early, to // ensure we can correctly match the callsites to summary data. @@ -1874,6 +1884,12 @@ PassBuilder::buildLTODefaultPipeline(OptimizationLevel Level, invokeFullLinkTimeOptimizationEarlyEPCallbacks(MPM, Level); + // If we are invoking this without a summary index noting that we are linking + // with a library containing the necessary APIs, remove any MemProf related + // attributes and metadata. + if (!ExportSummary || !ExportSummary->withSupportsHotColdNew()) + MPM.addPass(MemProfRemoveInfo()); + // Create a function that performs CFI checks for cross-DSO calls with targets // in the current module. MPM.addPass(CrossDSOCFIPass()); @@ -2111,7 +2127,7 @@ PassBuilder::buildLTODefaultPipeline(OptimizationLevel Level, MainFPM.addPass(createFunctionToLoopPassAdaptor( LICMPass(PTO.LicmMssaOptCap, PTO.LicmMssaNoAccForPromotionCap, /*AllowSpeculation=*/true), - /*USeMemorySSA=*/true, /*UseBlockFrequencyInfo=*/false)); + /*USeMemorySSA=*/true)); if (RunNewGVN) MainFPM.addPass(NewGVNPass()); @@ -2141,8 +2157,8 @@ PassBuilder::buildLTODefaultPipeline(OptimizationLevel Level, PTO.ForgetAllSCEVInLoopUnroll)); // The loop passes in LPM (LoopFullUnrollPass) do not preserve MemorySSA. // *All* loop passes must preserve it, in order to be able to use it. - MainFPM.addPass(createFunctionToLoopPassAdaptor( - std::move(LPM), /*UseMemorySSA=*/false, /*UseBlockFrequencyInfo=*/true)); + MainFPM.addPass( + createFunctionToLoopPassAdaptor(std::move(LPM), /*UseMemorySSA=*/false)); MainFPM.addPass(LoopDistributePass()); diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def index 1b16525..884d8da 100644 --- a/llvm/lib/Passes/PassRegistry.def +++ b/llvm/lib/Passes/PassRegistry.def @@ -113,6 +113,7 @@ MODULE_PASS("pgo-force-function-attrs", ? PGOOpt->ColdOptType : PGOOptions::ColdFuncOpt::Default)) MODULE_PASS("memprof-context-disambiguation", MemProfContextDisambiguation()) +MODULE_PASS("memprof-remove-attributes", MemProfRemoveInfo()) MODULE_PASS("memprof-module", ModuleMemProfilerPass()) MODULE_PASS("mergefunc", MergeFunctionsPass()) MODULE_PASS("metarenamer", MetaRenamerPass()) diff --git a/llvm/lib/Remarks/BitstreamRemarkParser.h b/llvm/lib/Remarks/BitstreamRemarkParser.h index 4f66c47..914edd8 100644 --- a/llvm/lib/Remarks/BitstreamRemarkParser.h +++ b/llvm/lib/Remarks/BitstreamRemarkParser.h @@ -112,7 +112,7 @@ public: /// Helper to parse a META_BLOCK for a bitstream remark container. class BitstreamMetaParserHelper : public BitstreamBlockParserHelper<BitstreamMetaParserHelper> { - friend class BitstreamBlockParserHelper; + friend class BitstreamBlockParserHelper<BitstreamMetaParserHelper>; public: struct ContainerInfo { @@ -137,7 +137,7 @@ protected: /// Helper to parse a REMARK_BLOCK for a bitstream remark container. 
class BitstreamRemarkParserHelper : public BitstreamBlockParserHelper<BitstreamRemarkParserHelper> { - friend class BitstreamBlockParserHelper; + friend class BitstreamBlockParserHelper<BitstreamRemarkParserHelper>; protected: SmallVector<uint64_t, 5> Record; diff --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp index 8623c06..b4de79a 100644 --- a/llvm/lib/Support/APFloat.cpp +++ b/llvm/lib/Support/APFloat.cpp @@ -130,44 +130,46 @@ struct fltSemantics { bool hasSignBitInMSB = true; }; -static constexpr fltSemantics semIEEEhalf = {15, -14, 11, 16}; -static constexpr fltSemantics semBFloat = {127, -126, 8, 16}; -static constexpr fltSemantics semIEEEsingle = {127, -126, 24, 32}; -static constexpr fltSemantics semIEEEdouble = {1023, -1022, 53, 64}; -static constexpr fltSemantics semIEEEquad = {16383, -16382, 113, 128}; -static constexpr fltSemantics semFloat8E5M2 = {15, -14, 3, 8}; -static constexpr fltSemantics semFloat8E5M2FNUZ = { +constexpr fltSemantics APFloatBase::semIEEEhalf = {15, -14, 11, 16}; +constexpr fltSemantics APFloatBase::semBFloat = {127, -126, 8, 16}; +constexpr fltSemantics APFloatBase::semIEEEsingle = {127, -126, 24, 32}; +constexpr fltSemantics APFloatBase::semIEEEdouble = {1023, -1022, 53, 64}; +constexpr fltSemantics APFloatBase::semIEEEquad = {16383, -16382, 113, 128}; +constexpr fltSemantics APFloatBase::semFloat8E5M2 = {15, -14, 3, 8}; +constexpr fltSemantics APFloatBase::semFloat8E5M2FNUZ = { 15, -15, 3, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero}; -static constexpr fltSemantics semFloat8E4M3 = {7, -6, 4, 8}; -static constexpr fltSemantics semFloat8E4M3FN = { +constexpr fltSemantics APFloatBase::semFloat8E4M3 = {7, -6, 4, 8}; +constexpr fltSemantics APFloatBase::semFloat8E4M3FN = { 8, -6, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::AllOnes}; -static constexpr fltSemantics semFloat8E4M3FNUZ = { +constexpr fltSemantics APFloatBase::semFloat8E4M3FNUZ = { 7, -7, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero}; -static constexpr fltSemantics semFloat8E4M3B11FNUZ = { +constexpr fltSemantics APFloatBase::semFloat8E4M3B11FNUZ = { 4, -10, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero}; -static constexpr fltSemantics semFloat8E3M4 = {3, -2, 5, 8}; -static constexpr fltSemantics semFloatTF32 = {127, -126, 11, 19}; -static constexpr fltSemantics semFloat8E8M0FNU = {127, - -127, - 1, - 8, - fltNonfiniteBehavior::NanOnly, - fltNanEncoding::AllOnes, - false, - false, - false}; - -static constexpr fltSemantics semFloat6E3M2FN = { +constexpr fltSemantics APFloatBase::semFloat8E3M4 = {3, -2, 5, 8}; +constexpr fltSemantics APFloatBase::semFloatTF32 = {127, -126, 11, 19}; +constexpr fltSemantics APFloatBase::semFloat8E8M0FNU = { + 127, + -127, + 1, + 8, + fltNonfiniteBehavior::NanOnly, + fltNanEncoding::AllOnes, + false, + false, + false}; + +constexpr fltSemantics APFloatBase::semFloat6E3M2FN = { 4, -2, 3, 6, fltNonfiniteBehavior::FiniteOnly}; -static constexpr fltSemantics semFloat6E2M3FN = { +constexpr fltSemantics APFloatBase::semFloat6E2M3FN = { 2, 0, 4, 6, fltNonfiniteBehavior::FiniteOnly}; -static constexpr fltSemantics semFloat4E2M1FN = { +constexpr fltSemantics APFloatBase::semFloat4E2M1FN = { 2, 0, 2, 4, fltNonfiniteBehavior::FiniteOnly}; -static constexpr fltSemantics semX87DoubleExtended = {16383, -16382, 64, 80}; -static constexpr fltSemantics semBogus = {0, 0, 0, 0}; -static constexpr fltSemantics semPPCDoubleDouble = {-1, 0, 0, 128}; -static constexpr fltSemantics 
semPPCDoubleDoubleLegacy = {1023, -1022 + 53, - 53 + 53, 128}; +constexpr fltSemantics APFloatBase::semX87DoubleExtended = {16383, -16382, 64, + 80}; +constexpr fltSemantics APFloatBase::semBogus = {0, 0, 0, 0}; +constexpr fltSemantics APFloatBase::semPPCDoubleDouble = {-1, 0, 0, 128}; +constexpr fltSemantics APFloatBase::semPPCDoubleDoubleLegacy = { + 1023, -1022 + 53, 53 + 53, 128}; const llvm::fltSemantics &APFloatBase::EnumToSemantics(Semantics S) { switch (S) { @@ -261,36 +263,6 @@ APFloatBase::SemanticsToEnum(const llvm::fltSemantics &Sem) { llvm_unreachable("Unknown floating semantics"); } -const fltSemantics &APFloatBase::IEEEhalf() { return semIEEEhalf; } -const fltSemantics &APFloatBase::BFloat() { return semBFloat; } -const fltSemantics &APFloatBase::IEEEsingle() { return semIEEEsingle; } -const fltSemantics &APFloatBase::IEEEdouble() { return semIEEEdouble; } -const fltSemantics &APFloatBase::IEEEquad() { return semIEEEquad; } -const fltSemantics &APFloatBase::PPCDoubleDouble() { - return semPPCDoubleDouble; -} -const fltSemantics &APFloatBase::PPCDoubleDoubleLegacy() { - return semPPCDoubleDoubleLegacy; -} -const fltSemantics &APFloatBase::Float8E5M2() { return semFloat8E5M2; } -const fltSemantics &APFloatBase::Float8E5M2FNUZ() { return semFloat8E5M2FNUZ; } -const fltSemantics &APFloatBase::Float8E4M3() { return semFloat8E4M3; } -const fltSemantics &APFloatBase::Float8E4M3FN() { return semFloat8E4M3FN; } -const fltSemantics &APFloatBase::Float8E4M3FNUZ() { return semFloat8E4M3FNUZ; } -const fltSemantics &APFloatBase::Float8E4M3B11FNUZ() { - return semFloat8E4M3B11FNUZ; -} -const fltSemantics &APFloatBase::Float8E3M4() { return semFloat8E3M4; } -const fltSemantics &APFloatBase::FloatTF32() { return semFloatTF32; } -const fltSemantics &APFloatBase::Float8E8M0FNU() { return semFloat8E8M0FNU; } -const fltSemantics &APFloatBase::Float6E3M2FN() { return semFloat6E3M2FN; } -const fltSemantics &APFloatBase::Float6E2M3FN() { return semFloat6E2M3FN; } -const fltSemantics &APFloatBase::Float4E2M1FN() { return semFloat4E2M1FN; } -const fltSemantics &APFloatBase::x87DoubleExtended() { - return semX87DoubleExtended; -} -const fltSemantics &APFloatBase::Bogus() { return semBogus; } - bool APFloatBase::isRepresentableBy(const fltSemantics &A, const fltSemantics &B) { return A.maxExponent <= B.maxExponent && A.minExponent >= B.minExponent && @@ -1029,7 +1001,7 @@ void IEEEFloat::makeNaN(bool SNaN, bool Negative, const APInt *fill) { // For x87 extended precision, we want to make a NaN, not a // pseudo-NaN. Maybe we should expose the ability to make // pseudo-NaNs? 
- if (semantics == &semX87DoubleExtended) + if (semantics == &APFloatBase::semX87DoubleExtended) APInt::tcSetBit(significand, QNaNBit + 1); } @@ -1054,7 +1026,7 @@ IEEEFloat &IEEEFloat::operator=(IEEEFloat &&rhs) { category = rhs.category; sign = rhs.sign; - rhs.semantics = &semBogus; + rhs.semantics = &APFloatBase::semBogus; return *this; } @@ -1247,7 +1219,7 @@ IEEEFloat::IEEEFloat(const IEEEFloat &rhs) { assign(rhs); } -IEEEFloat::IEEEFloat(IEEEFloat &&rhs) : semantics(&semBogus) { +IEEEFloat::IEEEFloat(IEEEFloat &&rhs) : semantics(&APFloatBase::semBogus) { *this = std::move(rhs); } @@ -2607,8 +2579,8 @@ APFloat::opStatus IEEEFloat::convert(const fltSemantics &toSemantics, shift = toSemantics.precision - fromSemantics.precision; bool X86SpecialNan = false; - if (&fromSemantics == &semX87DoubleExtended && - &toSemantics != &semX87DoubleExtended && category == fcNaN && + if (&fromSemantics == &APFloatBase::semX87DoubleExtended && + &toSemantics != &APFloatBase::semX87DoubleExtended && category == fcNaN && (!(*significandParts() & 0x8000000000000000ULL) || !(*significandParts() & 0x4000000000000000ULL))) { // x86 has some unusual NaNs which cannot be represented in any other @@ -2694,7 +2666,7 @@ APFloat::opStatus IEEEFloat::convert(const fltSemantics &toSemantics, // For x87 extended precision, we want to make a NaN, not a special NaN if // the input wasn't special either. - if (!X86SpecialNan && semantics == &semX87DoubleExtended) + if (!X86SpecialNan && semantics == &APFloatBase::semX87DoubleExtended) APInt::tcSetBit(significandParts(), semantics->precision - 1); // Convert of sNaN creates qNaN and raises an exception (invalid op). @@ -3530,7 +3502,8 @@ hash_code hash_value(const IEEEFloat &Arg) { // the actual IEEE respresentations. We compensate for that here. APInt IEEEFloat::convertF80LongDoubleAPFloatToAPInt() const { - assert(semantics == (const llvm::fltSemantics*)&semX87DoubleExtended); + assert(semantics == + (const llvm::fltSemantics *)&APFloatBase::semX87DoubleExtended); assert(partCount()==2); uint64_t myexponent, mysignificand; @@ -3560,7 +3533,8 @@ APInt IEEEFloat::convertF80LongDoubleAPFloatToAPInt() const { } APInt IEEEFloat::convertPPCDoubleDoubleLegacyAPFloatToAPInt() const { - assert(semantics == (const llvm::fltSemantics *)&semPPCDoubleDoubleLegacy); + assert(semantics == + (const llvm::fltSemantics *)&APFloatBase::semPPCDoubleDoubleLegacy); assert(partCount()==2); uint64_t words[2]; @@ -3574,14 +3548,14 @@ APInt IEEEFloat::convertPPCDoubleDoubleLegacyAPFloatToAPInt() const { // Declare fltSemantics before APFloat that uses it (and // saves pointer to it) to ensure correct destruction order. 
fltSemantics extendedSemantics = *semantics; - extendedSemantics.minExponent = semIEEEdouble.minExponent; + extendedSemantics.minExponent = APFloatBase::semIEEEdouble.minExponent; IEEEFloat extended(*this); fs = extended.convert(extendedSemantics, rmNearestTiesToEven, &losesInfo); assert(fs == opOK && !losesInfo); (void)fs; IEEEFloat u(extended); - fs = u.convert(semIEEEdouble, rmNearestTiesToEven, &losesInfo); + fs = u.convert(APFloatBase::semIEEEdouble, rmNearestTiesToEven, &losesInfo); assert(fs == opOK || fs == opInexact); (void)fs; words[0] = *u.convertDoubleAPFloatToAPInt().getRawData(); @@ -3597,7 +3571,7 @@ APInt IEEEFloat::convertPPCDoubleDoubleLegacyAPFloatToAPInt() const { IEEEFloat v(extended); v.subtract(u, rmNearestTiesToEven); - fs = v.convert(semIEEEdouble, rmNearestTiesToEven, &losesInfo); + fs = v.convert(APFloatBase::semIEEEdouble, rmNearestTiesToEven, &losesInfo); assert(fs == opOK && !losesInfo); (void)fs; words[1] = *v.convertDoubleAPFloatToAPInt().getRawData(); @@ -3611,8 +3585,9 @@ APInt IEEEFloat::convertPPCDoubleDoubleLegacyAPFloatToAPInt() const { template <const fltSemantics &S> APInt IEEEFloat::convertIEEEFloatToAPInt() const { assert(semantics == &S); - const int bias = - (semantics == &semFloat8E8M0FNU) ? -S.minExponent : -(S.minExponent - 1); + const int bias = (semantics == &APFloatBase::semFloat8E8M0FNU) + ? -S.minExponent + : -(S.minExponent - 1); constexpr unsigned int trailing_significand_bits = S.precision - 1; constexpr int integer_bit_part = trailing_significand_bits / integerPartWidth; constexpr integerPart integer_bit = @@ -3677,87 +3652,87 @@ APInt IEEEFloat::convertIEEEFloatToAPInt() const { APInt IEEEFloat::convertQuadrupleAPFloatToAPInt() const { assert(partCount() == 2); - return convertIEEEFloatToAPInt<semIEEEquad>(); + return convertIEEEFloatToAPInt<APFloatBase::semIEEEquad>(); } APInt IEEEFloat::convertDoubleAPFloatToAPInt() const { assert(partCount()==1); - return convertIEEEFloatToAPInt<semIEEEdouble>(); + return convertIEEEFloatToAPInt<APFloatBase::semIEEEdouble>(); } APInt IEEEFloat::convertFloatAPFloatToAPInt() const { assert(partCount()==1); - return convertIEEEFloatToAPInt<semIEEEsingle>(); + return convertIEEEFloatToAPInt<APFloatBase::semIEEEsingle>(); } APInt IEEEFloat::convertBFloatAPFloatToAPInt() const { assert(partCount() == 1); - return convertIEEEFloatToAPInt<semBFloat>(); + return convertIEEEFloatToAPInt<APFloatBase::semBFloat>(); } APInt IEEEFloat::convertHalfAPFloatToAPInt() const { assert(partCount()==1); - return convertIEEEFloatToAPInt<semIEEEhalf>(); + return convertIEEEFloatToAPInt<APFloatBase::APFloatBase::semIEEEhalf>(); } APInt IEEEFloat::convertFloat8E5M2APFloatToAPInt() const { assert(partCount() == 1); - return convertIEEEFloatToAPInt<semFloat8E5M2>(); + return convertIEEEFloatToAPInt<APFloatBase::semFloat8E5M2>(); } APInt IEEEFloat::convertFloat8E5M2FNUZAPFloatToAPInt() const { assert(partCount() == 1); - return convertIEEEFloatToAPInt<semFloat8E5M2FNUZ>(); + return convertIEEEFloatToAPInt<APFloatBase::semFloat8E5M2FNUZ>(); } APInt IEEEFloat::convertFloat8E4M3APFloatToAPInt() const { assert(partCount() == 1); - return convertIEEEFloatToAPInt<semFloat8E4M3>(); + return convertIEEEFloatToAPInt<APFloatBase::semFloat8E4M3>(); } APInt IEEEFloat::convertFloat8E4M3FNAPFloatToAPInt() const { assert(partCount() == 1); - return convertIEEEFloatToAPInt<semFloat8E4M3FN>(); + return convertIEEEFloatToAPInt<APFloatBase::semFloat8E4M3FN>(); } APInt IEEEFloat::convertFloat8E4M3FNUZAPFloatToAPInt() const { assert(partCount() == 
1); - return convertIEEEFloatToAPInt<semFloat8E4M3FNUZ>(); + return convertIEEEFloatToAPInt<APFloatBase::semFloat8E4M3FNUZ>(); } APInt IEEEFloat::convertFloat8E4M3B11FNUZAPFloatToAPInt() const { assert(partCount() == 1); - return convertIEEEFloatToAPInt<semFloat8E4M3B11FNUZ>(); + return convertIEEEFloatToAPInt<APFloatBase::semFloat8E4M3B11FNUZ>(); } APInt IEEEFloat::convertFloat8E3M4APFloatToAPInt() const { assert(partCount() == 1); - return convertIEEEFloatToAPInt<semFloat8E3M4>(); + return convertIEEEFloatToAPInt<APFloatBase::semFloat8E3M4>(); } APInt IEEEFloat::convertFloatTF32APFloatToAPInt() const { assert(partCount() == 1); - return convertIEEEFloatToAPInt<semFloatTF32>(); + return convertIEEEFloatToAPInt<APFloatBase::semFloatTF32>(); } APInt IEEEFloat::convertFloat8E8M0FNUAPFloatToAPInt() const { assert(partCount() == 1); - return convertIEEEFloatToAPInt<semFloat8E8M0FNU>(); + return convertIEEEFloatToAPInt<APFloatBase::semFloat8E8M0FNU>(); } APInt IEEEFloat::convertFloat6E3M2FNAPFloatToAPInt() const { assert(partCount() == 1); - return convertIEEEFloatToAPInt<semFloat6E3M2FN>(); + return convertIEEEFloatToAPInt<APFloatBase::semFloat6E3M2FN>(); } APInt IEEEFloat::convertFloat6E2M3FNAPFloatToAPInt() const { assert(partCount() == 1); - return convertIEEEFloatToAPInt<semFloat6E2M3FN>(); + return convertIEEEFloatToAPInt<APFloatBase::semFloat6E2M3FN>(); } APInt IEEEFloat::convertFloat4E2M1FNAPFloatToAPInt() const { assert(partCount() == 1); - return convertIEEEFloatToAPInt<semFloat4E2M1FN>(); + return convertIEEEFloatToAPInt<APFloatBase::semFloat4E2M1FN>(); } // This function creates an APInt that is just a bit map of the floating @@ -3765,74 +3740,77 @@ APInt IEEEFloat::convertFloat4E2M1FNAPFloatToAPInt() const { // and treating the result as a normal integer is unlikely to be useful. 
APInt IEEEFloat::bitcastToAPInt() const { - if (semantics == (const llvm::fltSemantics*)&semIEEEhalf) + if (semantics == (const llvm::fltSemantics *)&APFloatBase::semIEEEhalf) return convertHalfAPFloatToAPInt(); - if (semantics == (const llvm::fltSemantics *)&semBFloat) + if (semantics == (const llvm::fltSemantics *)&APFloatBase::semBFloat) return convertBFloatAPFloatToAPInt(); - if (semantics == (const llvm::fltSemantics*)&semIEEEsingle) + if (semantics == (const llvm::fltSemantics *)&APFloatBase::semIEEEsingle) return convertFloatAPFloatToAPInt(); - if (semantics == (const llvm::fltSemantics*)&semIEEEdouble) + if (semantics == (const llvm::fltSemantics *)&APFloatBase::semIEEEdouble) return convertDoubleAPFloatToAPInt(); - if (semantics == (const llvm::fltSemantics*)&semIEEEquad) + if (semantics == (const llvm::fltSemantics *)&APFloatBase::semIEEEquad) return convertQuadrupleAPFloatToAPInt(); - if (semantics == (const llvm::fltSemantics *)&semPPCDoubleDoubleLegacy) + if (semantics == + (const llvm::fltSemantics *)&APFloatBase::semPPCDoubleDoubleLegacy) return convertPPCDoubleDoubleLegacyAPFloatToAPInt(); - if (semantics == (const llvm::fltSemantics *)&semFloat8E5M2) + if (semantics == (const llvm::fltSemantics *)&APFloatBase::semFloat8E5M2) return convertFloat8E5M2APFloatToAPInt(); - if (semantics == (const llvm::fltSemantics *)&semFloat8E5M2FNUZ) + if (semantics == (const llvm::fltSemantics *)&APFloatBase::semFloat8E5M2FNUZ) return convertFloat8E5M2FNUZAPFloatToAPInt(); - if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3) + if (semantics == (const llvm::fltSemantics *)&APFloatBase::semFloat8E4M3) return convertFloat8E4M3APFloatToAPInt(); - if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3FN) + if (semantics == (const llvm::fltSemantics *)&APFloatBase::semFloat8E4M3FN) return convertFloat8E4M3FNAPFloatToAPInt(); - if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3FNUZ) + if (semantics == (const llvm::fltSemantics *)&APFloatBase::semFloat8E4M3FNUZ) return convertFloat8E4M3FNUZAPFloatToAPInt(); - if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3B11FNUZ) + if (semantics == + (const llvm::fltSemantics *)&APFloatBase::semFloat8E4M3B11FNUZ) return convertFloat8E4M3B11FNUZAPFloatToAPInt(); - if (semantics == (const llvm::fltSemantics *)&semFloat8E3M4) + if (semantics == (const llvm::fltSemantics *)&APFloatBase::semFloat8E3M4) return convertFloat8E3M4APFloatToAPInt(); - if (semantics == (const llvm::fltSemantics *)&semFloatTF32) + if (semantics == (const llvm::fltSemantics *)&APFloatBase::semFloatTF32) return convertFloatTF32APFloatToAPInt(); - if (semantics == (const llvm::fltSemantics *)&semFloat8E8M0FNU) + if (semantics == (const llvm::fltSemantics *)&APFloatBase::semFloat8E8M0FNU) return convertFloat8E8M0FNUAPFloatToAPInt(); - if (semantics == (const llvm::fltSemantics *)&semFloat6E3M2FN) + if (semantics == (const llvm::fltSemantics *)&APFloatBase::semFloat6E3M2FN) return convertFloat6E3M2FNAPFloatToAPInt(); - if (semantics == (const llvm::fltSemantics *)&semFloat6E2M3FN) + if (semantics == (const llvm::fltSemantics *)&APFloatBase::semFloat6E2M3FN) return convertFloat6E2M3FNAPFloatToAPInt(); - if (semantics == (const llvm::fltSemantics *)&semFloat4E2M1FN) + if (semantics == (const llvm::fltSemantics *)&APFloatBase::semFloat4E2M1FN) return convertFloat4E2M1FNAPFloatToAPInt(); - assert(semantics == (const llvm::fltSemantics*)&semX87DoubleExtended && + assert(semantics == + (const llvm::fltSemantics *)&APFloatBase::semX87DoubleExtended && "unknown 
format!"); return convertF80LongDoubleAPFloatToAPInt(); } float IEEEFloat::convertToFloat() const { - assert(semantics == (const llvm::fltSemantics*)&semIEEEsingle && + assert(semantics == (const llvm::fltSemantics *)&APFloatBase::semIEEEsingle && "Float semantics are not IEEEsingle"); APInt api = bitcastToAPInt(); return api.bitsToFloat(); } double IEEEFloat::convertToDouble() const { - assert(semantics == (const llvm::fltSemantics*)&semIEEEdouble && + assert(semantics == (const llvm::fltSemantics *)&APFloatBase::semIEEEdouble && "Float semantics are not IEEEdouble"); APInt api = bitcastToAPInt(); return api.bitsToDouble(); @@ -3840,7 +3818,7 @@ double IEEEFloat::convertToDouble() const { #ifdef HAS_IEE754_FLOAT128 float128 IEEEFloat::convertToQuad() const { - assert(semantics == (const llvm::fltSemantics *)&semIEEEquad && + assert(semantics == (const llvm::fltSemantics *)&APFloatBase::semIEEEquad && "Float semantics are not IEEEquads"); APInt api = bitcastToAPInt(); return api.bitsToQuad(); @@ -3861,7 +3839,7 @@ void IEEEFloat::initFromF80LongDoubleAPInt(const APInt &api) { uint64_t mysignificand = i1; uint8_t myintegerbit = mysignificand >> 63; - initialize(&semX87DoubleExtended); + initialize(&APFloatBase::semX87DoubleExtended); assert(partCount()==2); sign = static_cast<unsigned int>(i2>>15); @@ -3893,14 +3871,16 @@ void IEEEFloat::initFromPPCDoubleDoubleLegacyAPInt(const APInt &api) { // Get the first double and convert to our format. initFromDoubleAPInt(APInt(64, i1)); - fs = convert(semPPCDoubleDoubleLegacy, rmNearestTiesToEven, &losesInfo); + fs = convert(APFloatBase::semPPCDoubleDoubleLegacy, rmNearestTiesToEven, + &losesInfo); assert(fs == opOK && !losesInfo); (void)fs; // Unless we have a special case, add in second double. if (isFiniteNonZero()) { - IEEEFloat v(semIEEEdouble, APInt(64, i2)); - fs = v.convert(semPPCDoubleDoubleLegacy, rmNearestTiesToEven, &losesInfo); + IEEEFloat v(APFloatBase::semIEEEdouble, APInt(64, i2)); + fs = v.convert(APFloatBase::semPPCDoubleDoubleLegacy, rmNearestTiesToEven, + &losesInfo); assert(fs == opOK && !losesInfo); (void)fs; @@ -3918,7 +3898,7 @@ void IEEEFloat::initFromFloat8E8M0FNUAPInt(const APInt &api) { uint64_t val = api.getRawData()[0]; uint64_t myexponent = (val & exponent_mask); - initialize(&semFloat8E8M0FNU); + initialize(&APFloatBase::semFloat8E8M0FNU); assert(partCount() == 1); // This format has unsigned representation only @@ -4025,109 +4005,109 @@ void IEEEFloat::initFromIEEEAPInt(const APInt &api) { } void IEEEFloat::initFromQuadrupleAPInt(const APInt &api) { - initFromIEEEAPInt<semIEEEquad>(api); + initFromIEEEAPInt<APFloatBase::semIEEEquad>(api); } void IEEEFloat::initFromDoubleAPInt(const APInt &api) { - initFromIEEEAPInt<semIEEEdouble>(api); + initFromIEEEAPInt<APFloatBase::semIEEEdouble>(api); } void IEEEFloat::initFromFloatAPInt(const APInt &api) { - initFromIEEEAPInt<semIEEEsingle>(api); + initFromIEEEAPInt<APFloatBase::semIEEEsingle>(api); } void IEEEFloat::initFromBFloatAPInt(const APInt &api) { - initFromIEEEAPInt<semBFloat>(api); + initFromIEEEAPInt<APFloatBase::semBFloat>(api); } void IEEEFloat::initFromHalfAPInt(const APInt &api) { - initFromIEEEAPInt<semIEEEhalf>(api); + initFromIEEEAPInt<APFloatBase::semIEEEhalf>(api); } void IEEEFloat::initFromFloat8E5M2APInt(const APInt &api) { - initFromIEEEAPInt<semFloat8E5M2>(api); + initFromIEEEAPInt<APFloatBase::semFloat8E5M2>(api); } void IEEEFloat::initFromFloat8E5M2FNUZAPInt(const APInt &api) { - initFromIEEEAPInt<semFloat8E5M2FNUZ>(api); + 
initFromIEEEAPInt<APFloatBase::semFloat8E5M2FNUZ>(api); } void IEEEFloat::initFromFloat8E4M3APInt(const APInt &api) { - initFromIEEEAPInt<semFloat8E4M3>(api); + initFromIEEEAPInt<APFloatBase::semFloat8E4M3>(api); } void IEEEFloat::initFromFloat8E4M3FNAPInt(const APInt &api) { - initFromIEEEAPInt<semFloat8E4M3FN>(api); + initFromIEEEAPInt<APFloatBase::semFloat8E4M3FN>(api); } void IEEEFloat::initFromFloat8E4M3FNUZAPInt(const APInt &api) { - initFromIEEEAPInt<semFloat8E4M3FNUZ>(api); + initFromIEEEAPInt<APFloatBase::semFloat8E4M3FNUZ>(api); } void IEEEFloat::initFromFloat8E4M3B11FNUZAPInt(const APInt &api) { - initFromIEEEAPInt<semFloat8E4M3B11FNUZ>(api); + initFromIEEEAPInt<APFloatBase::semFloat8E4M3B11FNUZ>(api); } void IEEEFloat::initFromFloat8E3M4APInt(const APInt &api) { - initFromIEEEAPInt<semFloat8E3M4>(api); + initFromIEEEAPInt<APFloatBase::semFloat8E3M4>(api); } void IEEEFloat::initFromFloatTF32APInt(const APInt &api) { - initFromIEEEAPInt<semFloatTF32>(api); + initFromIEEEAPInt<APFloatBase::semFloatTF32>(api); } void IEEEFloat::initFromFloat6E3M2FNAPInt(const APInt &api) { - initFromIEEEAPInt<semFloat6E3M2FN>(api); + initFromIEEEAPInt<APFloatBase::semFloat6E3M2FN>(api); } void IEEEFloat::initFromFloat6E2M3FNAPInt(const APInt &api) { - initFromIEEEAPInt<semFloat6E2M3FN>(api); + initFromIEEEAPInt<APFloatBase::semFloat6E2M3FN>(api); } void IEEEFloat::initFromFloat4E2M1FNAPInt(const APInt &api) { - initFromIEEEAPInt<semFloat4E2M1FN>(api); + initFromIEEEAPInt<APFloatBase::semFloat4E2M1FN>(api); } /// Treat api as containing the bits of a floating point number. void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) { assert(api.getBitWidth() == Sem->sizeInBits); - if (Sem == &semIEEEhalf) + if (Sem == &APFloatBase::semIEEEhalf) return initFromHalfAPInt(api); - if (Sem == &semBFloat) + if (Sem == &APFloatBase::semBFloat) return initFromBFloatAPInt(api); - if (Sem == &semIEEEsingle) + if (Sem == &APFloatBase::semIEEEsingle) return initFromFloatAPInt(api); - if (Sem == &semIEEEdouble) + if (Sem == &APFloatBase::semIEEEdouble) return initFromDoubleAPInt(api); - if (Sem == &semX87DoubleExtended) + if (Sem == &APFloatBase::semX87DoubleExtended) return initFromF80LongDoubleAPInt(api); - if (Sem == &semIEEEquad) + if (Sem == &APFloatBase::semIEEEquad) return initFromQuadrupleAPInt(api); - if (Sem == &semPPCDoubleDoubleLegacy) + if (Sem == &APFloatBase::semPPCDoubleDoubleLegacy) return initFromPPCDoubleDoubleLegacyAPInt(api); - if (Sem == &semFloat8E5M2) + if (Sem == &APFloatBase::semFloat8E5M2) return initFromFloat8E5M2APInt(api); - if (Sem == &semFloat8E5M2FNUZ) + if (Sem == &APFloatBase::semFloat8E5M2FNUZ) return initFromFloat8E5M2FNUZAPInt(api); - if (Sem == &semFloat8E4M3) + if (Sem == &APFloatBase::semFloat8E4M3) return initFromFloat8E4M3APInt(api); - if (Sem == &semFloat8E4M3FN) + if (Sem == &APFloatBase::semFloat8E4M3FN) return initFromFloat8E4M3FNAPInt(api); - if (Sem == &semFloat8E4M3FNUZ) + if (Sem == &APFloatBase::semFloat8E4M3FNUZ) return initFromFloat8E4M3FNUZAPInt(api); - if (Sem == &semFloat8E4M3B11FNUZ) + if (Sem == &APFloatBase::semFloat8E4M3B11FNUZ) return initFromFloat8E4M3B11FNUZAPInt(api); - if (Sem == &semFloat8E3M4) + if (Sem == &APFloatBase::semFloat8E3M4) return initFromFloat8E3M4APInt(api); - if (Sem == &semFloatTF32) + if (Sem == &APFloatBase::semFloatTF32) return initFromFloatTF32APInt(api); - if (Sem == &semFloat8E8M0FNU) + if (Sem == &APFloatBase::semFloat8E8M0FNU) return initFromFloat8E8M0FNUAPInt(api); - if (Sem == &semFloat6E3M2FN) + if (Sem == 
&APFloatBase::semFloat6E3M2FN) return initFromFloat6E3M2FNAPInt(api); - if (Sem == &semFloat6E2M3FN) + if (Sem == &APFloatBase::semFloat6E2M3FN) return initFromFloat6E2M3FNAPInt(api); - if (Sem == &semFloat4E2M1FN) + if (Sem == &APFloatBase::semFloat4E2M1FN) return initFromFloat4E2M1FNAPInt(api); llvm_unreachable("unsupported semantics"); @@ -4202,11 +4182,11 @@ IEEEFloat::IEEEFloat(const fltSemantics &Sem, const APInt &API) { } IEEEFloat::IEEEFloat(float f) { - initFromAPInt(&semIEEEsingle, APInt::floatToBits(f)); + initFromAPInt(&APFloatBase::semIEEEsingle, APInt::floatToBits(f)); } IEEEFloat::IEEEFloat(double d) { - initFromAPInt(&semIEEEdouble, APInt::doubleToBits(d)); + initFromAPInt(&APFloatBase::semIEEEdouble, APInt::doubleToBits(d)); } namespace { @@ -4815,38 +4795,40 @@ IEEEFloat frexp(const IEEEFloat &Val, int &Exp, roundingMode RM) { DoubleAPFloat::DoubleAPFloat(const fltSemantics &S) : Semantics(&S), - Floats(new APFloat[2]{APFloat(semIEEEdouble), APFloat(semIEEEdouble)}) { - assert(Semantics == &semPPCDoubleDouble); + Floats(new APFloat[2]{APFloat(APFloatBase::semIEEEdouble), + APFloat(APFloatBase::semIEEEdouble)}) { + assert(Semantics == &APFloatBase::semPPCDoubleDouble); } DoubleAPFloat::DoubleAPFloat(const fltSemantics &S, uninitializedTag) - : Semantics(&S), - Floats(new APFloat[2]{APFloat(semIEEEdouble, uninitialized), - APFloat(semIEEEdouble, uninitialized)}) { - assert(Semantics == &semPPCDoubleDouble); + : Semantics(&S), Floats(new APFloat[2]{ + APFloat(APFloatBase::semIEEEdouble, uninitialized), + APFloat(APFloatBase::semIEEEdouble, uninitialized)}) { + assert(Semantics == &APFloatBase::semPPCDoubleDouble); } DoubleAPFloat::DoubleAPFloat(const fltSemantics &S, integerPart I) - : Semantics(&S), Floats(new APFloat[2]{APFloat(semIEEEdouble, I), - APFloat(semIEEEdouble)}) { - assert(Semantics == &semPPCDoubleDouble); + : Semantics(&S), + Floats(new APFloat[2]{APFloat(APFloatBase::semIEEEdouble, I), + APFloat(APFloatBase::semIEEEdouble)}) { + assert(Semantics == &APFloatBase::semPPCDoubleDouble); } DoubleAPFloat::DoubleAPFloat(const fltSemantics &S, const APInt &I) : Semantics(&S), Floats(new APFloat[2]{ - APFloat(semIEEEdouble, APInt(64, I.getRawData()[0])), - APFloat(semIEEEdouble, APInt(64, I.getRawData()[1]))}) { - assert(Semantics == &semPPCDoubleDouble); + APFloat(APFloatBase::semIEEEdouble, APInt(64, I.getRawData()[0])), + APFloat(APFloatBase::semIEEEdouble, APInt(64, I.getRawData()[1]))}) { + assert(Semantics == &APFloatBase::semPPCDoubleDouble); } DoubleAPFloat::DoubleAPFloat(const fltSemantics &S, APFloat &&First, APFloat &&Second) : Semantics(&S), Floats(new APFloat[2]{std::move(First), std::move(Second)}) { - assert(Semantics == &semPPCDoubleDouble); - assert(&Floats[0].getSemantics() == &semIEEEdouble); - assert(&Floats[1].getSemantics() == &semIEEEdouble); + assert(Semantics == &APFloatBase::semPPCDoubleDouble); + assert(&Floats[0].getSemantics() == &APFloatBase::semIEEEdouble); + assert(&Floats[1].getSemantics() == &APFloatBase::semIEEEdouble); } DoubleAPFloat::DoubleAPFloat(const DoubleAPFloat &RHS) @@ -4854,14 +4836,14 @@ DoubleAPFloat::DoubleAPFloat(const DoubleAPFloat &RHS) Floats(RHS.Floats ? 
new APFloat[2]{APFloat(RHS.Floats[0]), APFloat(RHS.Floats[1])} : nullptr) { - assert(Semantics == &semPPCDoubleDouble); + assert(Semantics == &APFloatBase::semPPCDoubleDouble); } DoubleAPFloat::DoubleAPFloat(DoubleAPFloat &&RHS) : Semantics(RHS.Semantics), Floats(RHS.Floats) { - RHS.Semantics = &semBogus; + RHS.Semantics = &APFloatBase::semBogus; RHS.Floats = nullptr; - assert(Semantics == &semPPCDoubleDouble); + assert(Semantics == &APFloatBase::semPPCDoubleDouble); } DoubleAPFloat &DoubleAPFloat::operator=(const DoubleAPFloat &RHS) { @@ -5009,12 +4991,12 @@ APFloat::opStatus DoubleAPFloat::addWithSpecial(const DoubleAPFloat &LHS, APFloat A(LHS.Floats[0]), AA(LHS.Floats[1]), C(RHS.Floats[0]), CC(RHS.Floats[1]); - assert(&A.getSemantics() == &semIEEEdouble); - assert(&AA.getSemantics() == &semIEEEdouble); - assert(&C.getSemantics() == &semIEEEdouble); - assert(&CC.getSemantics() == &semIEEEdouble); - assert(&Out.Floats[0].getSemantics() == &semIEEEdouble); - assert(&Out.Floats[1].getSemantics() == &semIEEEdouble); + assert(&A.getSemantics() == &APFloatBase::semIEEEdouble); + assert(&AA.getSemantics() == &APFloatBase::semIEEEdouble); + assert(&C.getSemantics() == &APFloatBase::semIEEEdouble); + assert(&CC.getSemantics() == &APFloatBase::semIEEEdouble); + assert(&Out.Floats[0].getSemantics() == &APFloatBase::semIEEEdouble); + assert(&Out.Floats[1].getSemantics() == &APFloatBase::semIEEEdouble); return Out.addImpl(A, AA, C, CC, RM); } @@ -5119,28 +5101,32 @@ APFloat::opStatus DoubleAPFloat::multiply(const DoubleAPFloat &RHS, APFloat::opStatus DoubleAPFloat::divide(const DoubleAPFloat &RHS, APFloat::roundingMode RM) { - assert(Semantics == &semPPCDoubleDouble && "Unexpected Semantics"); - APFloat Tmp(semPPCDoubleDoubleLegacy, bitcastToAPInt()); - auto Ret = - Tmp.divide(APFloat(semPPCDoubleDoubleLegacy, RHS.bitcastToAPInt()), RM); - *this = DoubleAPFloat(semPPCDoubleDouble, Tmp.bitcastToAPInt()); + assert(Semantics == &APFloatBase::semPPCDoubleDouble && + "Unexpected Semantics"); + APFloat Tmp(APFloatBase::semPPCDoubleDoubleLegacy, bitcastToAPInt()); + auto Ret = Tmp.divide( + APFloat(APFloatBase::semPPCDoubleDoubleLegacy, RHS.bitcastToAPInt()), RM); + *this = DoubleAPFloat(APFloatBase::semPPCDoubleDouble, Tmp.bitcastToAPInt()); return Ret; } APFloat::opStatus DoubleAPFloat::remainder(const DoubleAPFloat &RHS) { - assert(Semantics == &semPPCDoubleDouble && "Unexpected Semantics"); - APFloat Tmp(semPPCDoubleDoubleLegacy, bitcastToAPInt()); - auto Ret = - Tmp.remainder(APFloat(semPPCDoubleDoubleLegacy, RHS.bitcastToAPInt())); - *this = DoubleAPFloat(semPPCDoubleDouble, Tmp.bitcastToAPInt()); + assert(Semantics == &APFloatBase::semPPCDoubleDouble && + "Unexpected Semantics"); + APFloat Tmp(APFloatBase::semPPCDoubleDoubleLegacy, bitcastToAPInt()); + auto Ret = Tmp.remainder( + APFloat(APFloatBase::semPPCDoubleDoubleLegacy, RHS.bitcastToAPInt())); + *this = DoubleAPFloat(APFloatBase::semPPCDoubleDouble, Tmp.bitcastToAPInt()); return Ret; } APFloat::opStatus DoubleAPFloat::mod(const DoubleAPFloat &RHS) { - assert(Semantics == &semPPCDoubleDouble && "Unexpected Semantics"); - APFloat Tmp(semPPCDoubleDoubleLegacy, bitcastToAPInt()); - auto Ret = Tmp.mod(APFloat(semPPCDoubleDoubleLegacy, RHS.bitcastToAPInt())); - *this = DoubleAPFloat(semPPCDoubleDouble, Tmp.bitcastToAPInt()); + assert(Semantics == &APFloatBase::semPPCDoubleDouble && + "Unexpected Semantics"); + APFloat Tmp(APFloatBase::semPPCDoubleDoubleLegacy, bitcastToAPInt()); + auto Ret = Tmp.mod( + APFloat(APFloatBase::semPPCDoubleDoubleLegacy, 
RHS.bitcastToAPInt())); + *this = DoubleAPFloat(APFloatBase::semPPCDoubleDouble, Tmp.bitcastToAPInt()); return Ret; } @@ -5148,17 +5134,21 @@ APFloat::opStatus DoubleAPFloat::fusedMultiplyAdd(const DoubleAPFloat &Multiplicand, const DoubleAPFloat &Addend, APFloat::roundingMode RM) { - assert(Semantics == &semPPCDoubleDouble && "Unexpected Semantics"); - APFloat Tmp(semPPCDoubleDoubleLegacy, bitcastToAPInt()); + assert(Semantics == &APFloatBase::semPPCDoubleDouble && + "Unexpected Semantics"); + APFloat Tmp(APFloatBase::semPPCDoubleDoubleLegacy, bitcastToAPInt()); auto Ret = Tmp.fusedMultiplyAdd( - APFloat(semPPCDoubleDoubleLegacy, Multiplicand.bitcastToAPInt()), - APFloat(semPPCDoubleDoubleLegacy, Addend.bitcastToAPInt()), RM); - *this = DoubleAPFloat(semPPCDoubleDouble, Tmp.bitcastToAPInt()); + APFloat(APFloatBase::semPPCDoubleDoubleLegacy, + Multiplicand.bitcastToAPInt()), + APFloat(APFloatBase::semPPCDoubleDoubleLegacy, Addend.bitcastToAPInt()), + RM); + *this = DoubleAPFloat(APFloatBase::semPPCDoubleDouble, Tmp.bitcastToAPInt()); return Ret; } APFloat::opStatus DoubleAPFloat::roundToIntegral(APFloat::roundingMode RM) { - assert(Semantics == &semPPCDoubleDouble && "Unexpected Semantics"); + assert(Semantics == &APFloatBase::semPPCDoubleDouble && + "Unexpected Semantics"); const APFloat &Hi = getFirst(); const APFloat &Lo = getSecond(); @@ -5309,22 +5299,28 @@ void DoubleAPFloat::makeZero(bool Neg) { } void DoubleAPFloat::makeLargest(bool Neg) { - assert(Semantics == &semPPCDoubleDouble && "Unexpected Semantics"); - Floats[0] = APFloat(semIEEEdouble, APInt(64, 0x7fefffffffffffffull)); - Floats[1] = APFloat(semIEEEdouble, APInt(64, 0x7c8ffffffffffffeull)); + assert(Semantics == &APFloatBase::semPPCDoubleDouble && + "Unexpected Semantics"); + Floats[0] = + APFloat(APFloatBase::semIEEEdouble, APInt(64, 0x7fefffffffffffffull)); + Floats[1] = + APFloat(APFloatBase::semIEEEdouble, APInt(64, 0x7c8ffffffffffffeull)); if (Neg) changeSign(); } void DoubleAPFloat::makeSmallest(bool Neg) { - assert(Semantics == &semPPCDoubleDouble && "Unexpected Semantics"); + assert(Semantics == &APFloatBase::semPPCDoubleDouble && + "Unexpected Semantics"); Floats[0].makeSmallest(Neg); Floats[1].makeZero(/* Neg = */ false); } void DoubleAPFloat::makeSmallestNormalized(bool Neg) { - assert(Semantics == &semPPCDoubleDouble && "Unexpected Semantics"); - Floats[0] = APFloat(semIEEEdouble, APInt(64, 0x0360000000000000ull)); + assert(Semantics == &APFloatBase::semPPCDoubleDouble && + "Unexpected Semantics"); + Floats[0] = + APFloat(APFloatBase::semIEEEdouble, APInt(64, 0x0360000000000000ull)); if (Neg) Floats[0].changeSign(); Floats[1].makeZero(/* Neg = */ false); @@ -5355,7 +5351,8 @@ hash_code hash_value(const DoubleAPFloat &Arg) { } APInt DoubleAPFloat::bitcastToAPInt() const { - assert(Semantics == &semPPCDoubleDouble && "Unexpected Semantics"); + assert(Semantics == &APFloatBase::semPPCDoubleDouble && + "Unexpected Semantics"); uint64_t Data[] = { Floats[0].bitcastToAPInt().getRawData()[0], Floats[1].bitcastToAPInt().getRawData()[0], @@ -5365,10 +5362,11 @@ APInt DoubleAPFloat::bitcastToAPInt() const { Expected<APFloat::opStatus> DoubleAPFloat::convertFromString(StringRef S, roundingMode RM) { - assert(Semantics == &semPPCDoubleDouble && "Unexpected Semantics"); - APFloat Tmp(semPPCDoubleDoubleLegacy); + assert(Semantics == &APFloatBase::semPPCDoubleDouble && + "Unexpected Semantics"); + APFloat Tmp(APFloatBase::semPPCDoubleDoubleLegacy); auto Ret = Tmp.convertFromString(S, RM); - *this = 
DoubleAPFloat(semPPCDoubleDouble, Tmp.bitcastToAPInt()); + *this = DoubleAPFloat(APFloatBase::semPPCDoubleDouble, Tmp.bitcastToAPInt()); return Ret; } @@ -5379,7 +5377,8 @@ Expected<APFloat::opStatus> DoubleAPFloat::convertFromString(StringRef S, // nextUp must choose the smallest output > input that follows these rules. // nexDown must choose the largest output < input that follows these rules. APFloat::opStatus DoubleAPFloat::next(bool nextDown) { - assert(Semantics == &semPPCDoubleDouble && "Unexpected Semantics"); + assert(Semantics == &APFloatBase::semPPCDoubleDouble && + "Unexpected Semantics"); // nextDown(x) = -nextUp(-x) if (nextDown) { changeSign(); @@ -5481,7 +5480,8 @@ APFloat::opStatus DoubleAPFloat::next(bool nextDown) { APFloat::opStatus DoubleAPFloat::convertToSignExtendedInteger( MutableArrayRef<integerPart> Input, unsigned int Width, bool IsSigned, roundingMode RM, bool *IsExact) const { - assert(Semantics == &semPPCDoubleDouble && "Unexpected Semantics"); + assert(Semantics == &APFloatBase::semPPCDoubleDouble && + "Unexpected Semantics"); // If Hi is not finite, or Lo is zero, the value is entirely represented // by Hi. Delegate to the simpler single-APFloat conversion. @@ -5761,8 +5761,9 @@ unsigned int DoubleAPFloat::convertToHexString(char *DST, unsigned int HexDigits, bool UpperCase, roundingMode RM) const { - assert(Semantics == &semPPCDoubleDouble && "Unexpected Semantics"); - return APFloat(semPPCDoubleDoubleLegacy, bitcastToAPInt()) + assert(Semantics == &APFloatBase::semPPCDoubleDouble && + "Unexpected Semantics"); + return APFloat(APFloatBase::semPPCDoubleDoubleLegacy, bitcastToAPInt()) .convertToHexString(DST, HexDigits, UpperCase, RM); } @@ -5799,7 +5800,8 @@ bool DoubleAPFloat::isLargest() const { } bool DoubleAPFloat::isInteger() const { - assert(Semantics == &semPPCDoubleDouble && "Unexpected Semantics"); + assert(Semantics == &APFloatBase::semPPCDoubleDouble && + "Unexpected Semantics"); return Floats[0].isInteger() && Floats[1].isInteger(); } @@ -5807,8 +5809,9 @@ void DoubleAPFloat::toString(SmallVectorImpl<char> &Str, unsigned FormatPrecision, unsigned FormatMaxPadding, bool TruncateZero) const { - assert(Semantics == &semPPCDoubleDouble && "Unexpected Semantics"); - APFloat(semPPCDoubleDoubleLegacy, bitcastToAPInt()) + assert(Semantics == &APFloatBase::semPPCDoubleDouble && + "Unexpected Semantics"); + APFloat(APFloatBase::semPPCDoubleDoubleLegacy, bitcastToAPInt()) .toString(Str, FormatPrecision, FormatMaxPadding, TruncateZero); } @@ -5840,14 +5843,17 @@ int ilogb(const DoubleAPFloat &Arg) { DoubleAPFloat scalbn(const DoubleAPFloat &Arg, int Exp, APFloat::roundingMode RM) { - assert(Arg.Semantics == &semPPCDoubleDouble && "Unexpected Semantics"); - return DoubleAPFloat(semPPCDoubleDouble, scalbn(Arg.Floats[0], Exp, RM), + assert(Arg.Semantics == &APFloatBase::PPCDoubleDouble() && + "Unexpected Semantics"); + return DoubleAPFloat(APFloatBase::PPCDoubleDouble(), + scalbn(Arg.Floats[0], Exp, RM), scalbn(Arg.Floats[1], Exp, RM)); } DoubleAPFloat frexp(const DoubleAPFloat &Arg, int &Exp, APFloat::roundingMode RM) { - assert(Arg.Semantics == &semPPCDoubleDouble && "Unexpected Semantics"); + assert(Arg.Semantics == &APFloatBase::PPCDoubleDouble() && + "Unexpected Semantics"); // Get the unbiased exponent e of the number, where |Arg| = m * 2^e for m in // [1.0, 2.0). 
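A minimal usage sketch of the frexp/scalbn pair touched in the hunks above, going through the public APFloat interface; this is not part of the patch, and the helper name and input value are illustrative assumptions:

#include "llvm/ADT/APFloat.h"
#include <cassert>
using namespace llvm;

// Split a double-double value into mantissa and exponent, then rebuild it.
static void roundTripFrexp() {
  APFloat Arg(APFloat::PPCDoubleDouble(), "6.0");
  int Exp = 0;
  // Mant * 2^Exp reconstructs Arg; scaling by powers of two is exact here.
  APFloat Mant = frexp(Arg, Exp, APFloat::rmNearestTiesToEven);
  APFloat Back = scalbn(Mant, Exp, APFloat::rmNearestTiesToEven);
  assert(Back.bitwiseIsEqual(Arg));
}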
@@ -5943,7 +5949,8 @@ DoubleAPFloat frexp(const DoubleAPFloat &Arg, int &Exp, } APFloat First = scalbn(Hi, -Exp, RM); - return DoubleAPFloat(semPPCDoubleDouble, std::move(First), std::move(Second)); + return DoubleAPFloat(APFloatBase::PPCDoubleDouble(), std::move(First), + std::move(Second)); } } // namespace detail @@ -5955,9 +5962,8 @@ APFloat::Storage::Storage(IEEEFloat F, const fltSemantics &Semantics) { } if (usesLayout<DoubleAPFloat>(Semantics)) { const fltSemantics& S = F.getSemantics(); - new (&Double) - DoubleAPFloat(Semantics, APFloat(std::move(F), S), - APFloat(semIEEEdouble)); + new (&Double) DoubleAPFloat(Semantics, APFloat(std::move(F), S), + APFloat(APFloatBase::IEEEdouble())); return; } llvm_unreachable("Unexpected semantics"); @@ -6065,8 +6071,9 @@ APFloat::opStatus APFloat::convert(const fltSemantics &ToSemantics, return U.IEEE.convert(ToSemantics, RM, losesInfo); if (usesLayout<IEEEFloat>(getSemantics()) && usesLayout<DoubleAPFloat>(ToSemantics)) { - assert(&ToSemantics == &semPPCDoubleDouble); - auto Ret = U.IEEE.convert(semPPCDoubleDoubleLegacy, RM, losesInfo); + assert(&ToSemantics == &APFloatBase::semPPCDoubleDouble); + auto Ret = + U.IEEE.convert(APFloatBase::semPPCDoubleDoubleLegacy, RM, losesInfo); *this = APFloat(ToSemantics, U.IEEE.bitcastToAPInt()); return Ret; } @@ -6113,13 +6120,15 @@ APFloat::opStatus APFloat::convertToInteger(APSInt &result, } double APFloat::convertToDouble() const { - if (&getSemantics() == (const llvm::fltSemantics *)&semIEEEdouble) + if (&getSemantics() == + (const llvm::fltSemantics *)&APFloatBase::semIEEEdouble) return getIEEE().convertToDouble(); assert(isRepresentableBy(getSemantics(), semIEEEdouble) && "Float semantics is not representable by IEEEdouble"); APFloat Temp = *this; bool LosesInfo; - opStatus St = Temp.convert(semIEEEdouble, rmNearestTiesToEven, &LosesInfo); + opStatus St = + Temp.convert(APFloatBase::semIEEEdouble, rmNearestTiesToEven, &LosesInfo); assert(!(St & opInexact) && !LosesInfo && "Unexpected imprecision"); (void)St; return Temp.getIEEE().convertToDouble(); @@ -6127,13 +6136,14 @@ double APFloat::convertToDouble() const { #ifdef HAS_IEE754_FLOAT128 float128 APFloat::convertToQuad() const { - if (&getSemantics() == (const llvm::fltSemantics *)&semIEEEquad) + if (&getSemantics() == (const llvm::fltSemantics *)&APFloatBase::semIEEEquad) return getIEEE().convertToQuad(); assert(isRepresentableBy(getSemantics(), semIEEEquad) && "Float semantics is not representable by IEEEquad"); APFloat Temp = *this; bool LosesInfo; - opStatus St = Temp.convert(semIEEEquad, rmNearestTiesToEven, &LosesInfo); + opStatus St = + Temp.convert(APFloatBase::semIEEEquad, rmNearestTiesToEven, &LosesInfo); assert(!(St & opInexact) && !LosesInfo && "Unexpected imprecision"); (void)St; return Temp.getIEEE().convertToQuad(); @@ -6141,18 +6151,84 @@ float128 APFloat::convertToQuad() const { #endif float APFloat::convertToFloat() const { - if (&getSemantics() == (const llvm::fltSemantics *)&semIEEEsingle) + if (&getSemantics() == + (const llvm::fltSemantics *)&APFloatBase::semIEEEsingle) return getIEEE().convertToFloat(); assert(isRepresentableBy(getSemantics(), semIEEEsingle) && "Float semantics is not representable by IEEEsingle"); APFloat Temp = *this; bool LosesInfo; - opStatus St = Temp.convert(semIEEEsingle, rmNearestTiesToEven, &LosesInfo); + opStatus St = + Temp.convert(APFloatBase::semIEEEsingle, rmNearestTiesToEven, &LosesInfo); assert(!(St & opInexact) && !LosesInfo && "Unexpected imprecision"); (void)St; return 
Temp.getIEEE().convertToFloat(); } +APFloat::Storage::~Storage() { + if (usesLayout<IEEEFloat>(*semantics)) { + IEEE.~IEEEFloat(); + return; + } + if (usesLayout<DoubleAPFloat>(*semantics)) { + Double.~DoubleAPFloat(); + return; + } + llvm_unreachable("Unexpected semantics"); +} + +APFloat::Storage::Storage(const APFloat::Storage &RHS) { + if (usesLayout<IEEEFloat>(*RHS.semantics)) { + new (this) IEEEFloat(RHS.IEEE); + return; + } + if (usesLayout<DoubleAPFloat>(*RHS.semantics)) { + new (this) DoubleAPFloat(RHS.Double); + return; + } + llvm_unreachable("Unexpected semantics"); +} + +APFloat::Storage::Storage(APFloat::Storage &&RHS) { + if (usesLayout<IEEEFloat>(*RHS.semantics)) { + new (this) IEEEFloat(std::move(RHS.IEEE)); + return; + } + if (usesLayout<DoubleAPFloat>(*RHS.semantics)) { + new (this) DoubleAPFloat(std::move(RHS.Double)); + return; + } + llvm_unreachable("Unexpected semantics"); +} + +APFloat::Storage &APFloat::Storage::operator=(const APFloat::Storage &RHS) { + if (usesLayout<IEEEFloat>(*semantics) && + usesLayout<IEEEFloat>(*RHS.semantics)) { + IEEE = RHS.IEEE; + } else if (usesLayout<DoubleAPFloat>(*semantics) && + usesLayout<DoubleAPFloat>(*RHS.semantics)) { + Double = RHS.Double; + } else if (this != &RHS) { + this->~Storage(); + new (this) Storage(RHS); + } + return *this; +} + +APFloat::Storage &APFloat::Storage::operator=(APFloat::Storage &&RHS) { + if (usesLayout<IEEEFloat>(*semantics) && + usesLayout<IEEEFloat>(*RHS.semantics)) { + IEEE = std::move(RHS.IEEE); + } else if (usesLayout<DoubleAPFloat>(*semantics) && + usesLayout<DoubleAPFloat>(*RHS.semantics)) { + Double = std::move(RHS.Double); + } else if (this != &RHS) { + this->~Storage(); + new (this) Storage(std::move(RHS)); + } + return *this; +} + } // namespace llvm #undef APFLOAT_DISPATCH_ON_SEMANTICS diff --git a/llvm/lib/Support/SourceMgr.cpp b/llvm/lib/Support/SourceMgr.cpp index f2bbaab..299615a 100644 --- a/llvm/lib/Support/SourceMgr.cpp +++ b/llvm/lib/Support/SourceMgr.cpp @@ -69,11 +69,11 @@ unsigned SourceMgr::AddIncludeFile(const std::string &Filename, ErrorOr<std::unique_ptr<MemoryBuffer>> SourceMgr::OpenIncludeFile(const std::string &Filename, std::string &IncludedFile) { - if (!FS) - reportFatalInternalError("Opening include file from SourceMgr without VFS"); + auto GetFile = [this](StringRef Path) { + return FS ? FS->getBufferForFile(Path) : MemoryBuffer::getFile(Path); + }; - ErrorOr<std::unique_ptr<MemoryBuffer>> NewBufOrErr = - FS->getBufferForFile(Filename); + ErrorOr<std::unique_ptr<MemoryBuffer>> NewBufOrErr = GetFile(Filename); SmallString<64> Buffer(Filename); // If the file didn't exist directly, see if it's in an include path. 
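A standalone sketch of the optional-VFS fallback that the GetFile lambda above implements (the free-function name is an assumption, not from this patch): when a virtual file system is configured it is consulted first, otherwise the real file system is used.

#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/VirtualFileSystem.h"
using namespace llvm;

// Read a file through the VFS if one is configured, else hit the real FS.
static ErrorOr<std::unique_ptr<MemoryBuffer>>
openIncludeBuffer(vfs::FileSystem *FS, StringRef Path) {
  return FS ? FS->getBufferForFile(Path) : MemoryBuffer::getFile(Path);
}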
@@ -81,7 +81,7 @@ SourceMgr::OpenIncludeFile(const std::string &Filename, ++i) { Buffer = IncludeDirectories[i]; sys::path::append(Buffer, Filename); - NewBufOrErr = FS->getBufferForFile(Buffer); + NewBufOrErr = GetFile(Buffer); } if (NewBufOrErr) diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp index 1b559a6..8ed4062 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp @@ -514,8 +514,8 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM, MVT::i64, Custom); setOperationAction(ISD::SELECT_CC, MVT::i64, Expand); - setOperationAction({ISD::SMIN, ISD::UMIN, ISD::SMAX, ISD::UMAX}, MVT::i32, - Legal); + setOperationAction({ISD::ABS, ISD::SMIN, ISD::UMIN, ISD::SMAX, ISD::UMAX}, + MVT::i32, Legal); setOperationAction( {ISD::CTTZ, ISD::CTTZ_ZERO_UNDEF, ISD::CTLZ, ISD::CTLZ_ZERO_UNDEF}, diff --git a/llvm/lib/Target/AMDGPU/DSInstructions.td b/llvm/lib/Target/AMDGPU/DSInstructions.td index d0ad120..b841171 100644 --- a/llvm/lib/Target/AMDGPU/DSInstructions.td +++ b/llvm/lib/Target/AMDGPU/DSInstructions.td @@ -1488,6 +1488,12 @@ let AssemblerPredicate = isGFX12Plus in { def : MnemonicAlias<"ds_load_tr_b64", "ds_load_tr8_b64">, Requires<[isGFX1250Plus]>; def : MnemonicAlias<"ds_load_tr_b128", "ds_load_tr16_b128">, Requires<[isGFX1250Plus]>; +// Additional aliases for ds load transpose instructions. +def : MnemonicAlias<"ds_load_b64_tr_b8", "ds_load_tr8_b64">, Requires<[isGFX125xOnly]>; +def : MnemonicAlias<"ds_load_b128_tr_b16", "ds_load_tr16_b128">, Requires<[isGFX125xOnly]>; +def : MnemonicAlias<"ds_load_b64_tr_b4", "ds_load_tr4_b64">, Requires<[isGFX125xOnly]>; +def : MnemonicAlias<"ds_load_b96_tr_b6", "ds_load_tr6_b96">, Requires<[isGFX125xOnly]>; + //===----------------------------------------------------------------------===// // GFX11. //===----------------------------------------------------------------------===// diff --git a/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp b/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp index e0375ea..e3f3aba 100644 --- a/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp +++ b/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp @@ -892,6 +892,7 @@ DecodeStatus AMDGPUDisassembler::getInstruction(MCInst &MI, uint64_t &Size, // have EXEC as implicit destination. Issue a warning if encoding for // vdst is not EXEC. if ((MCII->get(MI.getOpcode()).TSFlags & SIInstrFlags::VOP3) && + MCII->get(MI.getOpcode()).getNumDefs() == 0 && MCII->get(MI.getOpcode()).hasImplicitDefOfPhysReg(AMDGPU::EXEC)) { auto ExecEncoding = MRI.getEncodingValue(AMDGPU::EXEC_LO); if (Bytes_[0] != ExecEncoding) diff --git a/llvm/lib/Target/AMDGPU/FLATInstructions.td b/llvm/lib/Target/AMDGPU/FLATInstructions.td index 6de59be..8ea64d1 100644 --- a/llvm/lib/Target/AMDGPU/FLATInstructions.td +++ b/llvm/lib/Target/AMDGPU/FLATInstructions.td @@ -3711,6 +3711,12 @@ defm GLOBAL_LOAD_TR_B64_w32 : VFLAT_Real_AllAddr_gfx1250<0x058, "globa defm GLOBAL_LOAD_TR4_B64 : VFLAT_Real_AllAddr_gfx1250<0x073>; defm GLOBAL_LOAD_TR6_B96 : VFLAT_Real_AllAddr_gfx1250<0x074>; +// Additional aliases for global load transpose instructions. 
+def : MnemonicAlias<"global_load_b128_tr_b16", "global_load_tr16_b128">, Requires<[isGFX125xOnly]>; +def : MnemonicAlias<"global_load_b64_tr_b8", "global_load_tr8_b64">, Requires<[isGFX125xOnly]>; +def : MnemonicAlias<"global_load_b64_tr_b4", "global_load_tr4_b64">, Requires<[isGFX125xOnly]>; +def : MnemonicAlias<"global_load_b96_tr_b6", "global_load_tr6_b96">, Requires<[isGFX125xOnly]>; + defm FLAT_ATOMIC_ADD_F64 : VFLAT_Real_Atomics_gfx1250<0x055>; defm FLAT_ATOMIC_MIN_F64 : VFLAT_Real_Atomics_gfx1250<0x05b, "flat_atomic_min_num_f64">; defm FLAT_ATOMIC_MAX_F64 : VFLAT_Real_Atomics_gfx1250<0x05c, "flat_atomic_max_num_f64">; diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp b/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp index d516330..942e784 100644 --- a/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp +++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp @@ -9072,6 +9072,67 @@ void SIInstrInfo::movePackToVALU(SIInstrWorklist &Worklist, MachineOperand &Src1 = Inst.getOperand(2); const DebugLoc &DL = Inst.getDebugLoc(); + if (ST.useRealTrue16Insts()) { + Register SrcReg0, SrcReg1; + if (!Src0.isReg() || !RI.isVGPR(MRI, Src0.getReg())) { + SrcReg0 = MRI.createVirtualRegister(&AMDGPU::VGPR_32RegClass); + BuildMI(*MBB, Inst, DL, get(AMDGPU::V_MOV_B32_e32), SrcReg0).add(Src0); + } else { + SrcReg0 = Src0.getReg(); + } + + if (!Src1.isReg() || !RI.isVGPR(MRI, Src1.getReg())) { + SrcReg1 = MRI.createVirtualRegister(&AMDGPU::VGPR_32RegClass); + BuildMI(*MBB, Inst, DL, get(AMDGPU::V_MOV_B32_e32), SrcReg1).add(Src1); + } else { + SrcReg1 = Src1.getReg(); + } + + bool isSrc0Reg16 = MRI.constrainRegClass(SrcReg0, &AMDGPU::VGPR_16RegClass); + bool isSrc1Reg16 = MRI.constrainRegClass(SrcReg1, &AMDGPU::VGPR_16RegClass); + + auto NewMI = BuildMI(*MBB, Inst, DL, get(AMDGPU::REG_SEQUENCE), ResultReg); + switch (Inst.getOpcode()) { + case AMDGPU::S_PACK_LL_B32_B16: + NewMI + .addReg(SrcReg0, 0, + isSrc0Reg16 ? AMDGPU::NoSubRegister : AMDGPU::lo16) + .addImm(AMDGPU::lo16) + .addReg(SrcReg1, 0, + isSrc1Reg16 ? AMDGPU::NoSubRegister : AMDGPU::lo16) + .addImm(AMDGPU::hi16); + break; + case AMDGPU::S_PACK_LH_B32_B16: + NewMI + .addReg(SrcReg0, 0, + isSrc0Reg16 ? AMDGPU::NoSubRegister : AMDGPU::lo16) + .addImm(AMDGPU::lo16) + .addReg(SrcReg1, 0, AMDGPU::hi16) + .addImm(AMDGPU::hi16); + break; + case AMDGPU::S_PACK_HL_B32_B16: + NewMI.addReg(SrcReg0, 0, AMDGPU::hi16) + .addImm(AMDGPU::lo16) + .addReg(SrcReg1, 0, + isSrc1Reg16 ? 
AMDGPU::NoSubRegister : AMDGPU::lo16) + .addImm(AMDGPU::hi16); + break; + case AMDGPU::S_PACK_HH_B32_B16: + NewMI.addReg(SrcReg0, 0, AMDGPU::hi16) + .addImm(AMDGPU::lo16) + .addReg(SrcReg1, 0, AMDGPU::hi16) + .addImm(AMDGPU::hi16); + break; + default: + llvm_unreachable("unhandled s_pack_* instruction"); + } + + MachineOperand &Dest = Inst.getOperand(0); + MRI.replaceRegWith(Dest.getReg(), ResultReg); + addUsersToMoveToVALUWorklist(ResultReg, MRI, Worklist); + return; + } + switch (Inst.getOpcode()) { case AMDGPU::S_PACK_LL_B32_B16: { Register ImmReg = MRI.createVirtualRegister(&AMDGPU::VGPR_32RegClass); @@ -10565,6 +10626,59 @@ bool SIInstrInfo::optimizeCompareInstr(MachineInstr &CmpInstr, Register SrcReg, if (SrcReg2 && !getFoldableImm(SrcReg2, *MRI, CmpValue)) return false; + const auto optimizeCmpSelect = [&CmpInstr, SrcReg, CmpValue, MRI, + this]() -> bool { + if (CmpValue != 0) + return false; + + MachineInstr *Def = MRI->getUniqueVRegDef(SrcReg); + if (!Def || Def->getParent() != CmpInstr.getParent()) + return false; + + bool CanOptimize = false; + + // For S_OP that set SCC = DST!=0, do the transformation + // + // s_cmp_lg_* (S_OP ...), 0 => (S_OP ...) + if (setsSCCifResultIsNonZero(*Def)) + CanOptimize = true; + + // s_cmp_lg_* is redundant because the SCC input value for S_CSELECT* has + // the same value that will be calculated by s_cmp_lg_* + // + // s_cmp_lg_* (S_CSELECT* (non-zero imm), 0), 0 => (S_CSELECT* (non-zero + // imm), 0) + if (Def->getOpcode() == AMDGPU::S_CSELECT_B32 || + Def->getOpcode() == AMDGPU::S_CSELECT_B64) { + bool Op1IsNonZeroImm = + Def->getOperand(1).isImm() && Def->getOperand(1).getImm() != 0; + bool Op2IsZeroImm = + Def->getOperand(2).isImm() && Def->getOperand(2).getImm() == 0; + if (Op1IsNonZeroImm && Op2IsZeroImm) + CanOptimize = true; + } + + if (!CanOptimize) + return false; + + MachineInstr *KillsSCC = nullptr; + for (MachineInstr &MI : + make_range(std::next(Def->getIterator()), CmpInstr.getIterator())) { + if (MI.modifiesRegister(AMDGPU::SCC, &RI)) + return false; + if (MI.killsRegister(AMDGPU::SCC, &RI)) + KillsSCC = &MI; + } + + if (MachineOperand *SccDef = + Def->findRegisterDefOperand(AMDGPU::SCC, /*TRI=*/nullptr)) + SccDef->setIsDead(false); + if (KillsSCC) + KillsSCC->clearRegisterKills(AMDGPU::SCC, /*TRI=*/nullptr); + CmpInstr.eraseFromParent(); + return true; + }; + const auto optimizeCmpAnd = [&CmpInstr, SrcReg, CmpValue, MRI, this](int64_t ExpectedValue, unsigned SrcSize, bool IsReversible, bool IsSigned) -> bool { @@ -10639,16 +10753,20 @@ bool SIInstrInfo::optimizeCompareInstr(MachineInstr &CmpInstr, Register SrcReg, if (IsReversedCC && !MRI->hasOneNonDBGUse(DefReg)) return false; - for (auto I = std::next(Def->getIterator()), E = CmpInstr.getIterator(); - I != E; ++I) { - if (I->modifiesRegister(AMDGPU::SCC, &RI) || - I->killsRegister(AMDGPU::SCC, &RI)) + MachineInstr *KillsSCC = nullptr; + for (MachineInstr &MI : + make_range(std::next(Def->getIterator()), CmpInstr.getIterator())) { + if (MI.modifiesRegister(AMDGPU::SCC, &RI)) return false; + if (MI.killsRegister(AMDGPU::SCC, &RI)) + KillsSCC = &MI; } MachineOperand *SccDef = Def->findRegisterDefOperand(AMDGPU::SCC, /*TRI=*/nullptr); SccDef->setIsDead(false); + if (KillsSCC) + KillsSCC->clearRegisterKills(AMDGPU::SCC, /*TRI=*/nullptr); CmpInstr.eraseFromParent(); if (!MRI->use_nodbg_empty(DefReg)) { @@ -10692,7 +10810,7 @@ bool SIInstrInfo::optimizeCompareInstr(MachineInstr &CmpInstr, Register SrcReg, case AMDGPU::S_CMP_LG_I32: case AMDGPU::S_CMPK_LG_U32: case 
AMDGPU::S_CMPK_LG_I32: - return optimizeCmpAnd(0, 32, true, false); + return optimizeCmpAnd(0, 32, true, false) || optimizeCmpSelect(); case AMDGPU::S_CMP_GT_U32: case AMDGPU::S_CMPK_GT_U32: return optimizeCmpAnd(0, 32, false, false); @@ -10700,7 +10818,7 @@ bool SIInstrInfo::optimizeCompareInstr(MachineInstr &CmpInstr, Register SrcReg, case AMDGPU::S_CMPK_GT_I32: return optimizeCmpAnd(0, 32, false, true); case AMDGPU::S_CMP_LG_U64: - return optimizeCmpAnd(0, 64, true, false); + return optimizeCmpAnd(0, 64, true, false) || optimizeCmpSelect(); } return false; diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.h b/llvm/lib/Target/AMDGPU/SIInstrInfo.h index e979eeb..ee99a74 100644 --- a/llvm/lib/Target/AMDGPU/SIInstrInfo.h +++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.h @@ -709,6 +709,30 @@ public: } } + static bool setsSCCifResultIsNonZero(const MachineInstr &MI) { + if (!MI.findRegisterDefOperand(AMDGPU::SCC, /*TRI=*/nullptr)) + return false; + // Compares have no result + if (MI.isCompare()) + return false; + switch (MI.getOpcode()) { + default: + return true; + case AMDGPU::S_ADD_I32: + case AMDGPU::S_ADD_U32: + case AMDGPU::S_ADDC_U32: + case AMDGPU::S_SUB_I32: + case AMDGPU::S_SUB_U32: + case AMDGPU::S_SUBB_U32: + case AMDGPU::S_MIN_I32: + case AMDGPU::S_MIN_U32: + case AMDGPU::S_MAX_I32: + case AMDGPU::S_MAX_U32: + case AMDGPU::S_ADDK_I32: + return false; + } + } + static bool isEXP(const MachineInstr &MI) { return MI.getDesc().TSFlags & SIInstrFlags::EXP; } @@ -879,6 +903,11 @@ public: MI.getOpcode() != AMDGPU::V_ACCVGPR_READ_B32_e64; } + bool isMFMA(uint16_t Opcode) const { + return isMAI(Opcode) && Opcode != AMDGPU::V_ACCVGPR_WRITE_B32_e64 && + Opcode != AMDGPU::V_ACCVGPR_READ_B32_e64; + } + static bool isDOT(const MachineInstr &MI) { return MI.getDesc().TSFlags & SIInstrFlags::IsDOT; } @@ -895,6 +924,10 @@ public: return isMFMA(MI) || isWMMA(MI) || isSWMMAC(MI); } + bool isMFMAorWMMA(uint16_t Opcode) const { + return isMFMA(Opcode) || isWMMA(Opcode) || isSWMMAC(Opcode); + } + static bool isSWMMAC(const MachineInstr &MI) { return MI.getDesc().TSFlags & SIInstrFlags::IsSWMMAC; } diff --git a/llvm/lib/Target/AMDGPU/SIPreEmitPeephole.cpp b/llvm/lib/Target/AMDGPU/SIPreEmitPeephole.cpp index 01a40c1..7431e11 100644 --- a/llvm/lib/Target/AMDGPU/SIPreEmitPeephole.cpp +++ b/llvm/lib/Target/AMDGPU/SIPreEmitPeephole.cpp @@ -47,9 +47,6 @@ private: const MachineBasicBlock &From, const MachineBasicBlock &To) const; bool removeExeczBranch(MachineInstr &MI, MachineBasicBlock &SrcMBB); - // Check if the machine instruction being processed is a supported packed - // instruction. - bool isUnpackingSupportedInstr(MachineInstr &MI) const; // Creates a list of packed instructions following an MFMA that are suitable // for unpacking. void collectUnpackingCandidates(MachineInstr &BeginMI, @@ -454,23 +451,6 @@ bool SIPreEmitPeephole::removeExeczBranch(MachineInstr &MI, return true; } -// If support is extended to new operations, add tests in -// llvm/test/CodeGen/AMDGPU/unpack-non-coissue-insts-post-ra-scheduler.mir. 
-bool SIPreEmitPeephole::isUnpackingSupportedInstr(MachineInstr &MI) const { - if (!TII->isNeverCoissue(MI)) - return false; - unsigned Opcode = MI.getOpcode(); - switch (Opcode) { - case AMDGPU::V_PK_ADD_F32: - case AMDGPU::V_PK_MUL_F32: - case AMDGPU::V_PK_FMA_F32: - return true; - default: - return false; - } - llvm_unreachable("Fully covered switch"); -} - bool SIPreEmitPeephole::canUnpackingClobberRegister(const MachineInstr &MI) { unsigned OpCode = MI.getOpcode(); Register DstReg = MI.getOperand(0).getReg(); @@ -612,10 +592,13 @@ void SIPreEmitPeephole::collectUnpackingCandidates( for (auto I = std::next(BeginMI.getIterator()); I != E; ++I) { MachineInstr &Instr = *I; + uint16_t UnpackedOpCode = mapToUnpackedOpcode(Instr); + bool IsUnpackable = + !(UnpackedOpCode == std::numeric_limits<uint16_t>::max()); if (Instr.isMetaInstruction()) continue; if ((Instr.isTerminator()) || - (TII->isNeverCoissue(Instr) && !isUnpackingSupportedInstr(Instr)) || + (TII->isNeverCoissue(Instr) && !IsUnpackable) || (SIInstrInfo::modifiesModeRegister(Instr) && Instr.modifiesRegister(AMDGPU::EXEC, TRI))) return; @@ -639,7 +622,7 @@ void SIPreEmitPeephole::collectUnpackingCandidates( if (TRI->regsOverlap(MFMADef, InstrMO.getReg())) return; } - if (!isUnpackingSupportedInstr(Instr)) + if (!IsUnpackable) continue; if (canUnpackingClobberRegister(Instr)) @@ -687,8 +670,8 @@ MachineInstrBuilder SIPreEmitPeephole::createUnpackedMI(MachineInstr &I, bool IsHiBits) { MachineBasicBlock &MBB = *I.getParent(); const DebugLoc &DL = I.getDebugLoc(); - const MachineOperand *SrcMO1 = TII->getNamedOperand(I, AMDGPU::OpName::src0); - const MachineOperand *SrcMO2 = TII->getNamedOperand(I, AMDGPU::OpName::src1); + const MachineOperand *SrcMO0 = TII->getNamedOperand(I, AMDGPU::OpName::src0); + const MachineOperand *SrcMO1 = TII->getNamedOperand(I, AMDGPU::OpName::src1); Register DstReg = I.getOperand(0).getReg(); unsigned OpCode = I.getOpcode(); Register UnpackedDstReg = IsHiBits ? TRI->getSubReg(DstReg, AMDGPU::sub1) @@ -702,15 +685,15 @@ MachineInstrBuilder SIPreEmitPeephole::createUnpackedMI(MachineInstr &I, MachineInstrBuilder NewMI = BuildMI(MBB, I, DL, TII->get(UnpackedOpcode)); NewMI.addDef(UnpackedDstReg); // vdst - addOperandAndMods(NewMI, Src0Mods, IsHiBits, *SrcMO1); - addOperandAndMods(NewMI, Src1Mods, IsHiBits, *SrcMO2); + addOperandAndMods(NewMI, Src0Mods, IsHiBits, *SrcMO0); + addOperandAndMods(NewMI, Src1Mods, IsHiBits, *SrcMO1); if (AMDGPU::hasNamedOperand(OpCode, AMDGPU::OpName::src2)) { - const MachineOperand *SrcMO3 = + const MachineOperand *SrcMO2 = TII->getNamedOperand(I, AMDGPU::OpName::src2); unsigned Src2Mods = TII->getNamedOperand(I, AMDGPU::OpName::src2_modifiers)->getImm(); - addOperandAndMods(NewMI, Src2Mods, IsHiBits, *SrcMO3); + addOperandAndMods(NewMI, Src2Mods, IsHiBits, *SrcMO2); } NewMI.addImm(ClampVal); // clamp // Packed instructions do not support output modifiers. safe to assign them 0 @@ -787,9 +770,13 @@ bool SIPreEmitPeephole::run(MachineFunction &MF) { // TODO: Fold this into previous block, if possible. Evaluate and handle any // side effects. + + // Perform the extra MF scans only for supported archs + if (!ST.hasGFX940Insts()) + return Changed; for (MachineBasicBlock &MBB : MF) { - // Unpack packed instructions overlapped by MFMAs. This allows the compiler - // to co-issue unpacked instructions with MFMA + // Unpack packed instructions overlapped by MFMAs. 
This allows the + // compiler to co-issue unpacked instructions with MFMA auto SchedModel = TII->getSchedModel(); SetVector<MachineInstr *> InstrsToUnpack; for (auto &MI : make_early_inc_range(MBB.instrs())) { diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp index 67ea2dd..35e1127 100644 --- a/llvm/lib/Target/ARM/ARMISelLowering.cpp +++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp @@ -21287,21 +21287,28 @@ bool ARMTargetLowering::useLoadStackGuardNode(const Module &M) const { } void ARMTargetLowering::insertSSPDeclarations(Module &M) const { + // MSVC CRT provides functionalities for stack protection. RTLIB::LibcallImpl SecurityCheckCookieLibcall = getLibcallImpl(RTLIB::SECURITY_CHECK_COOKIE); - if (SecurityCheckCookieLibcall == RTLIB::Unsupported) - return TargetLowering::insertSSPDeclarations(M); - // MSVC CRT has a global variable holding security cookie. - M.getOrInsertGlobal("__security_cookie", - PointerType::getUnqual(M.getContext())); + RTLIB::LibcallImpl SecurityCookieVar = + getLibcallImpl(RTLIB::STACK_CHECK_GUARD); + if (SecurityCheckCookieLibcall != RTLIB::Unsupported && + SecurityCookieVar != RTLIB::Unsupported) { + // MSVC CRT has a global variable holding security cookie. + M.getOrInsertGlobal(getLibcallImplName(SecurityCookieVar), + PointerType::getUnqual(M.getContext())); - // MSVC CRT has a function to validate security cookie. - FunctionCallee SecurityCheckCookie = M.getOrInsertFunction( - getLibcallImplName(SecurityCheckCookieLibcall), - Type::getVoidTy(M.getContext()), PointerType::getUnqual(M.getContext())); - if (Function *F = dyn_cast<Function>(SecurityCheckCookie.getCallee())) - F->addParamAttr(0, Attribute::AttrKind::InReg); + // MSVC CRT has a function to validate security cookie. + FunctionCallee SecurityCheckCookie = + M.getOrInsertFunction(getLibcallImplName(SecurityCheckCookieLibcall), + Type::getVoidTy(M.getContext()), + PointerType::getUnqual(M.getContext())); + if (Function *F = dyn_cast<Function>(SecurityCheckCookie.getCallee())) + F->addParamAttr(0, Attribute::AttrKind::InReg); + } + + TargetLowering::insertSSPDeclarations(M); } Function *ARMTargetLowering::getSSPStackGuardCheck(const Module &M) const { diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index 272c21f..2f1a7ad 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -749,7 +749,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, setTruncStoreAction(VT, MVT::i1, Expand); } - // Disable generations of extload/truncstore for v2i16/v2i8. The generic + // Disable generations of extload/truncstore for v2i32/v2i16/v2i8. The generic // expansion for these nodes when they are unaligned is incorrect if the // type is a vector. // @@ -757,7 +757,11 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, // TargetLowering::expandUnalignedLoad/Store. setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, MVT::v2i16, MVT::v2i8, Expand); + setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, MVT::v2i32, + {MVT::v2i8, MVT::v2i16}, Expand); setTruncStoreAction(MVT::v2i16, MVT::v2i8, Expand); + setTruncStoreAction(MVT::v2i32, MVT::v2i16, Expand); + setTruncStoreAction(MVT::v2i32, MVT::v2i8, Expand); // Register custom handling for illegal type loads/stores. 
We'll try to custom // lower almost all illegal types and logic in the lowering will discard cases diff --git a/llvm/lib/Target/RISCV/RISCVFeatures.td b/llvm/lib/Target/RISCV/RISCVFeatures.td index 5ceb477..19992e6 100644 --- a/llvm/lib/Target/RISCV/RISCVFeatures.td +++ b/llvm/lib/Target/RISCV/RISCVFeatures.td @@ -695,6 +695,9 @@ def HasStdExtZvfbfa : Predicate<"Subtarget->hasStdExtZvfbfa()">, def FeatureStdExtZvfbfmin : RISCVExtension<1, 0, "Vector BF16 Converts", [FeatureStdExtZve32f]>; +def HasStdExtZvfbfmin : Predicate<"Subtarget->hasStdExtZvfbfmin()">, + AssemblerPredicate<(all_of FeatureStdExtZvfbfmin), + "'Zvfbfmin' (Vector BF16 Converts)">; def FeatureStdExtZvfbfwma : RISCVExtension<1, 0, "Vector BF16 widening mul-add", diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index eb87558..169465e 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -24830,7 +24830,8 @@ bool RISCVTargetLowering::isIntDivCheap(EVT VT, AttributeList Attr) const { // instruction, as it is usually smaller than the alternative sequence. // TODO: Add vector division? bool OptSize = Attr.hasFnAttr(Attribute::MinSize); - return OptSize && !VT.isVector(); + return OptSize && !VT.isVector() && + VT.getSizeInBits() <= getMaxDivRemBitWidthSupported(); } bool RISCVTargetLowering::preferScalarizeSplat(SDNode *N) const { diff --git a/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp b/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp index 1b7cb9b..636e31c 100644 --- a/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp +++ b/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp @@ -699,7 +699,8 @@ public: "Can't encode VTYPE for uninitialized or unknown"); if (TWiden != 0) return RISCVVType::encodeXSfmmVType(SEW, TWiden, AltFmt); - return RISCVVType::encodeVTYPE(VLMul, SEW, TailAgnostic, MaskAgnostic); + return RISCVVType::encodeVTYPE(VLMul, SEW, TailAgnostic, MaskAgnostic, + AltFmt); } bool hasSEWLMULRatioOnly() const { return SEWLMULRatioOnly; } diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp index ddb53a2..12f776b 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp @@ -3775,11 +3775,13 @@ std::string RISCVInstrInfo::createMIROperandComment( #define CASE_VFMA_OPCODE_VV(OP) \ CASE_VFMA_OPCODE_LMULS_MF4(OP, VV, E16): \ + case CASE_VFMA_OPCODE_LMULS_MF4(OP##_ALT, VV, E16): \ case CASE_VFMA_OPCODE_LMULS_MF2(OP, VV, E32): \ case CASE_VFMA_OPCODE_LMULS_M1(OP, VV, E64) #define CASE_VFMA_SPLATS(OP) \ CASE_VFMA_OPCODE_LMULS_MF4(OP, VFPR16, E16): \ + case CASE_VFMA_OPCODE_LMULS_MF4(OP##_ALT, VFPR16, E16): \ case CASE_VFMA_OPCODE_LMULS_MF2(OP, VFPR32, E32): \ case CASE_VFMA_OPCODE_LMULS_M1(OP, VFPR64, E64) // clang-format on @@ -4003,11 +4005,13 @@ bool RISCVInstrInfo::findCommutedOpIndices(const MachineInstr &MI, #define CASE_VFMA_CHANGE_OPCODE_VV(OLDOP, NEWOP) \ CASE_VFMA_CHANGE_OPCODE_LMULS_MF4(OLDOP, NEWOP, VV, E16) \ + CASE_VFMA_CHANGE_OPCODE_LMULS_MF4(OLDOP##_ALT, NEWOP##_ALT, VV, E16) \ CASE_VFMA_CHANGE_OPCODE_LMULS_MF2(OLDOP, NEWOP, VV, E32) \ CASE_VFMA_CHANGE_OPCODE_LMULS_M1(OLDOP, NEWOP, VV, E64) #define CASE_VFMA_CHANGE_OPCODE_SPLATS(OLDOP, NEWOP) \ CASE_VFMA_CHANGE_OPCODE_LMULS_MF4(OLDOP, NEWOP, VFPR16, E16) \ + CASE_VFMA_CHANGE_OPCODE_LMULS_MF4(OLDOP##_ALT, NEWOP##_ALT, VFPR16, E16) \ CASE_VFMA_CHANGE_OPCODE_LMULS_MF2(OLDOP, NEWOP, VFPR32, E32) \ CASE_VFMA_CHANGE_OPCODE_LMULS_M1(OLDOP, NEWOP, VFPR64, E64) // clang-format on @@ -4469,6 
+4473,20 @@ bool RISCVInstrInfo::simplifyInstruction(MachineInstr &MI) const { CASE_FP_WIDEOP_CHANGE_OPCODE_COMMON(OP, M2, E32) \ CASE_FP_WIDEOP_CHANGE_OPCODE_COMMON(OP, M4, E16) \ CASE_FP_WIDEOP_CHANGE_OPCODE_COMMON(OP, M4, E32) \ + +#define CASE_FP_WIDEOP_OPCODE_LMULS_ALT(OP) \ + CASE_FP_WIDEOP_OPCODE_COMMON(OP, MF4, E16): \ + case CASE_FP_WIDEOP_OPCODE_COMMON(OP, MF2, E16): \ + case CASE_FP_WIDEOP_OPCODE_COMMON(OP, M1, E16): \ + case CASE_FP_WIDEOP_OPCODE_COMMON(OP, M2, E16): \ + case CASE_FP_WIDEOP_OPCODE_COMMON(OP, M4, E16) + +#define CASE_FP_WIDEOP_CHANGE_OPCODE_LMULS_ALT(OP) \ + CASE_FP_WIDEOP_CHANGE_OPCODE_COMMON(OP, MF4, E16) \ + CASE_FP_WIDEOP_CHANGE_OPCODE_COMMON(OP, MF2, E16) \ + CASE_FP_WIDEOP_CHANGE_OPCODE_COMMON(OP, M1, E16) \ + CASE_FP_WIDEOP_CHANGE_OPCODE_COMMON(OP, M2, E16) \ + CASE_FP_WIDEOP_CHANGE_OPCODE_COMMON(OP, M4, E16) // clang-format on MachineInstr *RISCVInstrInfo::convertToThreeAddress(MachineInstr &MI, @@ -4478,6 +4496,8 @@ MachineInstr *RISCVInstrInfo::convertToThreeAddress(MachineInstr &MI, switch (MI.getOpcode()) { default: return nullptr; + case CASE_FP_WIDEOP_OPCODE_LMULS_ALT(FWADD_ALT_WV): + case CASE_FP_WIDEOP_OPCODE_LMULS_ALT(FWSUB_ALT_WV): case CASE_FP_WIDEOP_OPCODE_LMULS(FWADD_WV): case CASE_FP_WIDEOP_OPCODE_LMULS(FWSUB_WV): { assert(RISCVII::hasVecPolicyOp(MI.getDesc().TSFlags) && @@ -4494,6 +4514,8 @@ MachineInstr *RISCVInstrInfo::convertToThreeAddress(MachineInstr &MI, llvm_unreachable("Unexpected opcode"); CASE_FP_WIDEOP_CHANGE_OPCODE_LMULS(FWADD_WV) CASE_FP_WIDEOP_CHANGE_OPCODE_LMULS(FWSUB_WV) + CASE_FP_WIDEOP_CHANGE_OPCODE_LMULS_ALT(FWADD_ALT_WV) + CASE_FP_WIDEOP_CHANGE_OPCODE_LMULS_ALT(FWSUB_ALT_WV) } // clang-format on diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td index 65865ce..eb3c9b0 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td @@ -5862,20 +5862,6 @@ multiclass VPatConversionWF_VF<string intrinsic, string instruction, } } -multiclass VPatConversionWF_VF_BF<string intrinsic, string instruction, - bit isSEWAware = 0> { - foreach fvtiToFWti = AllWidenableBF16ToFloatVectors in - { - defvar fvti = fvtiToFWti.Vti; - defvar fwti = fvtiToFWti.Wti; - let Predicates = !listconcat(GetVTypePredicates<fvti>.Predicates, - GetVTypePredicates<fwti>.Predicates) in - defm : VPatConversion<intrinsic, instruction, "V", - fwti.Vector, fvti.Vector, fwti.Mask, fvti.Log2SEW, - fvti.LMul, fwti.RegClass, fvti.RegClass, isSEWAware>; - } -} - multiclass VPatConversionVI_WF<string intrinsic, string instruction> { foreach vtiToWti = AllWidenableIntToFloatVectors in { defvar vti = vtiToWti.Vti; @@ -5969,20 +5955,6 @@ multiclass VPatConversionVF_WF_RTZ<string intrinsic, string instruction, } } -multiclass VPatConversionVF_WF_BF_RM<string intrinsic, string instruction, - bit isSEWAware = 0> { - foreach fvtiToFWti = AllWidenableBF16ToFloatVectors in { - defvar fvti = fvtiToFWti.Vti; - defvar fwti = fvtiToFWti.Wti; - let Predicates = !listconcat(GetVTypePredicates<fvti>.Predicates, - GetVTypePredicates<fwti>.Predicates) in - defm : VPatConversionRoundingMode<intrinsic, instruction, "W", - fvti.Vector, fwti.Vector, fvti.Mask, fvti.Log2SEW, - fvti.LMul, fvti.RegClass, fwti.RegClass, - isSEWAware>; - } -} - multiclass VPatCompare_VI<string intrinsic, string inst, ImmLeaf ImmType> { foreach vti = AllIntegerVectors in { diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZvfbf.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZvfbf.td index 0be9eab..9358486 
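The RISCVInstrInfoZvfbf.td changes that follow add the bf16 pseudos and patterns; the conversions they select (vfwcvtbf16.f.f.v and vfncvtbf16.f.f.w, plus their _ALT counterparts) are simple because bf16 is just the upper half of an IEEE binary32 value, so widening is exact and only narrowing needs a rounding mode (the FRM_DYN operands in the patterns below). A small host-side model, assuming round-to-nearest-even for the narrowing direction and ignoring NaN payload details:

// bf16 <-> f32 conversion model mirroring vfwcvtbf16.f.f.v / vfncvtbf16.f.f.w.
#include <cassert>
#include <cstdint>
#include <cstring>

float bf16_to_f32(uint16_t b) {
  uint32_t bits = static_cast<uint32_t>(b) << 16;  // widen: exact by construction
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

uint16_t f32_to_bf16_rne(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  // Round the discarded 16 bits to nearest, ties to even (NaNs not handled).
  uint32_t rounding = 0x7FFFu + ((bits >> 16) & 1u);
  return static_cast<uint16_t>((bits + rounding) >> 16);
}

int main() {
  assert(bf16_to_f32(0x3F80) == 1.0f);           // 1.0 round-trips exactly
  assert(f32_to_bf16_rne(1.0f) == 0x3F80);
  assert(f32_to_bf16_rne(1.0078125f) == 0x3F81); // 1 + 2^-7, the next bf16 value
}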
100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoZvfbf.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZvfbf.td @@ -36,7 +36,7 @@ defm VFWMACCBF16_V : VWMAC_FV_V_F<"vfwmaccbf16", 0b111011>; //===----------------------------------------------------------------------===// // Pseudo instructions //===----------------------------------------------------------------------===// -let Predicates = [HasStdExtZvfbfminOrZvfofp8min] in { +let Predicates = [HasStdExtZvfbfmin] in { defm PseudoVFWCVTBF16_F_F : VPseudoVWCVTD_V; defm PseudoVFNCVTBF16_F_F : VPseudoVNCVTD_W_RM; } @@ -44,10 +44,364 @@ let Predicates = [HasStdExtZvfbfminOrZvfofp8min] in { let mayRaiseFPException = true, Predicates = [HasStdExtZvfbfwma] in defm PseudoVFWMACCBF16 : VPseudoVWMAC_VV_VF_BF_RM; +defset list<VTypeInfoToWide> AllWidenableIntToBF16Vectors = { + def : VTypeInfoToWide<VI8MF8, VBF16MF4>; + def : VTypeInfoToWide<VI8MF4, VBF16MF2>; + def : VTypeInfoToWide<VI8MF2, VBF16M1>; + def : VTypeInfoToWide<VI8M1, VBF16M2>; + def : VTypeInfoToWide<VI8M2, VBF16M4>; + def : VTypeInfoToWide<VI8M4, VBF16M8>; +} + +multiclass VPseudoVALU_VV_VF_RM_BF16 { + foreach m = MxListF in { + defm "" : VPseudoBinaryFV_VV_RM<m, 16/*sew*/>, + SchedBinary<"WriteVFALUV", "ReadVFALUV", "ReadVFALUV", m.MX, + 16/*sew*/, forcePassthruRead=true>; + } + + defvar f = SCALAR_F16; + foreach m = f.MxList in { + defm "" : VPseudoBinaryV_VF_RM<m, f, f.SEW>, + SchedBinary<"WriteVFALUF", "ReadVFALUV", "ReadVFALUF", m.MX, + f.SEW, forcePassthruRead=true>; + } +} + +multiclass VPseudoVALU_VF_RM_BF16 { + defvar f = SCALAR_F16; + foreach m = f.MxList in { + defm "" : VPseudoBinaryV_VF_RM<m, f, f.SEW>, + SchedBinary<"WriteVFALUF", "ReadVFALUV", "ReadVFALUF", m.MX, + f.SEW, forcePassthruRead=true>; + } +} + +multiclass VPseudoVFWALU_VV_VF_RM_BF16 { + foreach m = MxListFW in { + defm "" : VPseudoBinaryW_VV_RM<m, sew=16>, + SchedBinary<"WriteVFWALUV", "ReadVFWALUV", "ReadVFWALUV", m.MX, + 16/*sew*/, forcePassthruRead=true>; + } + + defvar f = SCALAR_F16; + foreach m = f.MxListFW in { + defm "" : VPseudoBinaryW_VF_RM<m, f, sew=f.SEW>, + SchedBinary<"WriteVFWALUF", "ReadVFWALUV", "ReadVFWALUF", m.MX, + f.SEW, forcePassthruRead=true>; + } +} + +multiclass VPseudoVFWALU_WV_WF_RM_BF16 { + foreach m = MxListFW in { + defm "" : VPseudoBinaryW_WV_RM<m, sew=16>, + SchedBinary<"WriteVFWALUV", "ReadVFWALUV", "ReadVFWALUV", m.MX, + 16/*sew*/, forcePassthruRead=true>; + } + defvar f = SCALAR_F16; + foreach m = f.MxListFW in { + defm "" : VPseudoBinaryW_WF_RM<m, f, sew=f.SEW>, + SchedBinary<"WriteVFWALUF", "ReadVFWALUV", "ReadVFWALUF", m.MX, + f.SEW, forcePassthruRead=true>; + } +} + +multiclass VPseudoVFMUL_VV_VF_RM_BF16 { + foreach m = MxListF in { + defm "" : VPseudoBinaryFV_VV_RM<m, 16/*sew*/>, + SchedBinary<"WriteVFMulV", "ReadVFMulV", "ReadVFMulV", m.MX, + 16/*sew*/, forcePassthruRead=true>; + } + + defvar f = SCALAR_F16; + foreach m = f.MxList in { + defm "" : VPseudoBinaryV_VF_RM<m, f, f.SEW>, + SchedBinary<"WriteVFMulF", "ReadVFMulV", "ReadVFMulF", m.MX, + f.SEW, forcePassthruRead=true>; + } +} + +multiclass VPseudoVWMUL_VV_VF_RM_BF16 { + foreach m = MxListFW in { + defm "" : VPseudoBinaryW_VV_RM<m, sew=16>, + SchedBinary<"WriteVFWMulV", "ReadVFWMulV", "ReadVFWMulV", m.MX, + 16/*sew*/, forcePassthruRead=true>; + } + + defvar f = SCALAR_F16; + foreach m = f.MxListFW in { + defm "" : VPseudoBinaryW_VF_RM<m, f, sew=f.SEW>, + SchedBinary<"WriteVFWMulF", "ReadVFWMulV", "ReadVFWMulF", m.MX, + f.SEW, forcePassthruRead=true>; + } +} + +multiclass VPseudoVMAC_VV_VF_AAXA_RM_BF16 { + foreach m 
= MxListF in { + defm "" : VPseudoTernaryV_VV_AAXA_RM<m, 16/*sew*/>, + SchedTernary<"WriteVFMulAddV", "ReadVFMulAddV", "ReadVFMulAddV", + "ReadVFMulAddV", m.MX, 16/*sew*/>; + } + + defvar f = SCALAR_F16; + foreach m = f.MxList in { + defm "" : VPseudoTernaryV_VF_AAXA_RM<m, f, f.SEW>, + SchedTernary<"WriteVFMulAddF", "ReadVFMulAddV", "ReadVFMulAddF", + "ReadVFMulAddV", m.MX, f.SEW>; + } +} + +multiclass VPseudoVWMAC_VV_VF_RM_BF16 { + foreach m = MxListFW in { + defm "" : VPseudoTernaryW_VV_RM<m, sew=16>, + SchedTernary<"WriteVFWMulAddV", "ReadVFWMulAddV", + "ReadVFWMulAddV", "ReadVFWMulAddV", m.MX, 16/*sew*/>; + } + + defvar f = SCALAR_F16; + foreach m = f.MxListFW in { + defm "" : VPseudoTernaryW_VF_RM<m, f, sew=f.SEW>, + SchedTernary<"WriteVFWMulAddF", "ReadVFWMulAddV", + "ReadVFWMulAddF", "ReadVFWMulAddV", m.MX, f.SEW>; + } +} + +multiclass VPseudoVRCP_V_BF16 { + foreach m = MxListF in { + defvar mx = m.MX; + let VLMul = m.value in { + def "_V_" # mx # "_E16" + : VPseudoUnaryNoMask<m.vrclass, m.vrclass>, + SchedUnary<"WriteVFRecpV", "ReadVFRecpV", mx, 16/*sew*/, + forcePassthruRead=true>; + def "_V_" # mx # "_E16_MASK" + : VPseudoUnaryMask<m.vrclass, m.vrclass>, + RISCVMaskedPseudo<MaskIdx = 2>, + SchedUnary<"WriteVFRecpV", "ReadVFRecpV", mx, 16/*sew*/, + forcePassthruRead=true>; + } + } +} + +multiclass VPseudoVRCP_V_RM_BF16 { + foreach m = MxListF in { + defvar mx = m.MX; + let VLMul = m.value in { + def "_V_" # mx # "_E16" + : VPseudoUnaryNoMaskRoundingMode<m.vrclass, m.vrclass>, + SchedUnary<"WriteVFRecpV", "ReadVFRecpV", mx, 16/*sew*/, + forcePassthruRead=true>; + def "_V_" # mx # "_E16_MASK" + : VPseudoUnaryMaskRoundingMode<m.vrclass, m.vrclass>, + RISCVMaskedPseudo<MaskIdx = 2>, + SchedUnary<"WriteVFRecpV", "ReadVFRecpV", mx, 16/*sew*/, + forcePassthruRead=true>; + } + } +} + +multiclass VPseudoVMAX_VV_VF_BF16 { + foreach m = MxListF in { + defm "" : VPseudoBinaryV_VV<m, sew=16>, + SchedBinary<"WriteVFMinMaxV", "ReadVFMinMaxV", "ReadVFMinMaxV", + m.MX, 16/*sew*/, forcePassthruRead=true>; + } + + defvar f = SCALAR_F16; + foreach m = f.MxList in { + defm "" : VPseudoBinaryV_VF<m, f, f.SEW>, + SchedBinary<"WriteVFMinMaxF", "ReadVFMinMaxV", "ReadVFMinMaxF", + m.MX, f.SEW, forcePassthruRead=true>; + } +} + +multiclass VPseudoVSGNJ_VV_VF_BF16 { + foreach m = MxListF in { + defm "" : VPseudoBinaryV_VV<m, sew=16>, + SchedBinary<"WriteVFSgnjV", "ReadVFSgnjV", "ReadVFSgnjV", m.MX, + 16/*sew*/, forcePassthruRead=true>; + } + + defvar f = SCALAR_F16; + foreach m = f.MxList in { + defm "" : VPseudoBinaryV_VF<m, f, f.SEW>, + SchedBinary<"WriteVFSgnjF", "ReadVFSgnjV", "ReadVFSgnjF", m.MX, + f.SEW, forcePassthruRead=true>; + } +} + +multiclass VPseudoVWCVTF_V_BF16 { + defvar constraint = "@earlyclobber $rd"; + foreach m = MxListW in + defm _V : VPseudoConversion<m.wvrclass, m.vrclass, m, constraint, sew=8, + TargetConstraintType=3>, + SchedUnary<"WriteVFWCvtIToFV", "ReadVFWCvtIToFV", m.MX, 8/*sew*/, + forcePassthruRead=true>; +} + +multiclass VPseudoVWCVTD_V_BF16 { + defvar constraint = "@earlyclobber $rd"; + foreach m = MxListFW in + defm _V : VPseudoConversion<m.wvrclass, m.vrclass, m, constraint, sew=16, + TargetConstraintType=3>, + SchedUnary<"WriteVFWCvtFToFV", "ReadVFWCvtFToFV", m.MX, 16/*sew*/, + forcePassthruRead=true>; +} + +multiclass VPseudoVNCVTD_W_BF16 { + defvar constraint = "@earlyclobber $rd"; + foreach m = MxListFW in + defm _W : VPseudoConversion<m.vrclass, m.wvrclass, m, constraint, sew=16, + TargetConstraintType=2>, + SchedUnary<"WriteVFNCvtFToFV", "ReadVFNCvtFToFV", m.MX, 
16/*sew*/, + forcePassthruRead=true>; +} + +multiclass VPseudoVNCVTD_W_RM_BF16 { + defvar constraint = "@earlyclobber $rd"; + foreach m = MxListFW in + defm _W : VPseudoConversionRoundingMode<m.vrclass, m.wvrclass, m, + constraint, sew=16, + TargetConstraintType=2>, + SchedUnary<"WriteVFNCvtFToFV", "ReadVFNCvtFToFV", m.MX, 16/*sew*/, + forcePassthruRead=true>; +} + +let Predicates = [HasStdExtZvfbfa], AltFmtType = IS_ALTFMT in { +let mayRaiseFPException = true in { +defm PseudoVFADD_ALT : VPseudoVALU_VV_VF_RM_BF16; +defm PseudoVFSUB_ALT : VPseudoVALU_VV_VF_RM_BF16; +defm PseudoVFRSUB_ALT : VPseudoVALU_VF_RM_BF16; +} + +let mayRaiseFPException = true in { +defm PseudoVFWADD_ALT : VPseudoVFWALU_VV_VF_RM_BF16; +defm PseudoVFWSUB_ALT : VPseudoVFWALU_VV_VF_RM_BF16; +defm PseudoVFWADD_ALT : VPseudoVFWALU_WV_WF_RM_BF16; +defm PseudoVFWSUB_ALT : VPseudoVFWALU_WV_WF_RM_BF16; +} + +let mayRaiseFPException = true in +defm PseudoVFMUL_ALT : VPseudoVFMUL_VV_VF_RM_BF16; + +let mayRaiseFPException = true in +defm PseudoVFWMUL_ALT : VPseudoVWMUL_VV_VF_RM_BF16; + +let mayRaiseFPException = true in { +defm PseudoVFMACC_ALT : VPseudoVMAC_VV_VF_AAXA_RM_BF16; +defm PseudoVFNMACC_ALT : VPseudoVMAC_VV_VF_AAXA_RM_BF16; +defm PseudoVFMSAC_ALT : VPseudoVMAC_VV_VF_AAXA_RM_BF16; +defm PseudoVFNMSAC_ALT : VPseudoVMAC_VV_VF_AAXA_RM_BF16; +defm PseudoVFMADD_ALT : VPseudoVMAC_VV_VF_AAXA_RM_BF16; +defm PseudoVFNMADD_ALT : VPseudoVMAC_VV_VF_AAXA_RM_BF16; +defm PseudoVFMSUB_ALT : VPseudoVMAC_VV_VF_AAXA_RM_BF16; +defm PseudoVFNMSUB_ALT : VPseudoVMAC_VV_VF_AAXA_RM_BF16; +} + +let mayRaiseFPException = true in { +defm PseudoVFWMACC_ALT : VPseudoVWMAC_VV_VF_RM_BF16; +defm PseudoVFWNMACC_ALT : VPseudoVWMAC_VV_VF_RM_BF16; +defm PseudoVFWMSAC_ALT : VPseudoVWMAC_VV_VF_RM_BF16; +defm PseudoVFWNMSAC_ALT : VPseudoVWMAC_VV_VF_RM_BF16; +} + +let mayRaiseFPException = true in +defm PseudoVFRSQRT7_ALT : VPseudoVRCP_V_BF16; + +let mayRaiseFPException = true in +defm PseudoVFREC7_ALT : VPseudoVRCP_V_RM_BF16; + +let mayRaiseFPException = true in { +defm PseudoVFMIN_ALT : VPseudoVMAX_VV_VF_BF16; +defm PseudoVFMAX_ALT : VPseudoVMAX_VV_VF_BF16; +} + +defm PseudoVFSGNJ_ALT : VPseudoVSGNJ_VV_VF_BF16; +defm PseudoVFSGNJN_ALT : VPseudoVSGNJ_VV_VF_BF16; +defm PseudoVFSGNJX_ALT : VPseudoVSGNJ_VV_VF_BF16; + +let mayRaiseFPException = true in { +defm PseudoVMFEQ_ALT : VPseudoVCMPM_VV_VF; +defm PseudoVMFNE_ALT : VPseudoVCMPM_VV_VF; +defm PseudoVMFLT_ALT : VPseudoVCMPM_VV_VF; +defm PseudoVMFLE_ALT : VPseudoVCMPM_VV_VF; +defm PseudoVMFGT_ALT : VPseudoVCMPM_VF; +defm PseudoVMFGE_ALT : VPseudoVCMPM_VF; +} + +defm PseudoVFCLASS_ALT : VPseudoVCLS_V; + +defm PseudoVFMERGE_ALT : VPseudoVMRG_FM; + +defm PseudoVFMV_V_ALT : VPseudoVMV_F; + +let mayRaiseFPException = true in { +defm PseudoVFWCVT_F_XU_ALT : VPseudoVWCVTF_V_BF16; +defm PseudoVFWCVT_F_X_ALT : VPseudoVWCVTF_V_BF16; + +defm PseudoVFWCVT_F_F_ALT : VPseudoVWCVTD_V_BF16; +} // mayRaiseFPException = true + +let mayRaiseFPException = true in { +let hasSideEffects = 0, hasPostISelHook = 1 in { +defm PseudoVFNCVT_XU_F_ALT : VPseudoVNCVTI_W_RM; +defm PseudoVFNCVT_X_F_ALT : VPseudoVNCVTI_W_RM; +} + +defm PseudoVFNCVT_RTZ_XU_F_ALT : VPseudoVNCVTI_W; +defm PseudoVFNCVT_RTZ_X_F_ALT : VPseudoVNCVTI_W; + +defm PseudoVFNCVT_F_F_ALT : VPseudoVNCVTD_W_RM_BF16; + +defm PseudoVFNCVT_ROD_F_F_ALT : VPseudoVNCVTD_W_BF16; +} // mayRaiseFPException = true + +let mayLoad = 0, mayStore = 0, hasSideEffects = 0 in { + defvar f = SCALAR_F16; + let HasSEWOp = 1, BaseInstr = VFMV_F_S in + def "PseudoVFMV_" # f.FX # "_S_ALT" : + 
RISCVVPseudo<(outs f.fprclass:$rd), (ins VR:$rs2, sew:$sew)>, + Sched<[WriteVMovFS, ReadVMovFS]>; + let HasVLOp = 1, HasSEWOp = 1, BaseInstr = VFMV_S_F, isReMaterializable = 1, + Constraints = "$rd = $passthru" in + def "PseudoVFMV_S_" # f.FX # "_ALT" : + RISCVVPseudo<(outs VR:$rd), + (ins VR:$passthru, f.fprclass:$rs1, AVL:$vl, sew:$sew)>, + Sched<[WriteVMovSF, ReadVMovSF_V, ReadVMovSF_F]>; +} + +defm PseudoVFSLIDE1UP_ALT : VPseudoVSLD1_VF<"@earlyclobber $rd">; +defm PseudoVFSLIDE1DOWN_ALT : VPseudoVSLD1_VF; +} // Predicates = [HasStdExtZvfbfa], AltFmtType = IS_ALTFMT + //===----------------------------------------------------------------------===// // Patterns //===----------------------------------------------------------------------===// -let Predicates = [HasStdExtZvfbfminOrZvfofp8min] in { +multiclass VPatConversionWF_VF_BF<string intrinsic, string instruction, + bit isSEWAware = 0> { + foreach fvtiToFWti = AllWidenableBF16ToFloatVectors in + { + defvar fvti = fvtiToFWti.Vti; + defvar fwti = fvtiToFWti.Wti; + defm : VPatConversion<intrinsic, instruction, "V", + fwti.Vector, fvti.Vector, fwti.Mask, fvti.Log2SEW, + fvti.LMul, fwti.RegClass, fvti.RegClass, isSEWAware>; + } +} + +multiclass VPatConversionVF_WF_BF_RM<string intrinsic, string instruction, + bit isSEWAware = 0> { + foreach fvtiToFWti = AllWidenableBF16ToFloatVectors in { + defvar fvti = fvtiToFWti.Vti; + defvar fwti = fvtiToFWti.Wti; + defm : VPatConversionRoundingMode<intrinsic, instruction, "W", + fvti.Vector, fwti.Vector, fvti.Mask, fvti.Log2SEW, + fvti.LMul, fvti.RegClass, fwti.RegClass, + isSEWAware>; + } +} + +let Predicates = [HasStdExtZvfbfmin] in { defm : VPatConversionWF_VF_BF<"int_riscv_vfwcvtbf16_f_f_v", "PseudoVFWCVTBF16_F_F", isSEWAware=1>; defm : VPatConversionVF_WF_BF_RM<"int_riscv_vfncvtbf16_f_f_w", @@ -56,7 +410,6 @@ let Predicates = [HasStdExtZvfbfminOrZvfofp8min] in { foreach fvtiToFWti = AllWidenableBF16ToFloatVectors in { defvar fvti = fvtiToFWti.Vti; defvar fwti = fvtiToFWti.Wti; - let Predicates = [HasVInstructionsBF16Minimal] in def : Pat<(fwti.Vector (any_riscv_fpextend_vl (fvti.Vector fvti.RegClass:$rs1), (fvti.Mask VMV0:$vm), @@ -66,18 +419,16 @@ let Predicates = [HasStdExtZvfbfminOrZvfofp8min] in { (fvti.Mask VMV0:$vm), GPR:$vl, fvti.Log2SEW, TA_MA)>; - let Predicates = [HasVInstructionsBF16Minimal] in - def : Pat<(fvti.Vector (any_riscv_fpround_vl - (fwti.Vector fwti.RegClass:$rs1), - (fwti.Mask VMV0:$vm), VLOpFrag)), - (!cast<Instruction>("PseudoVFNCVTBF16_F_F_W_"#fvti.LMul.MX#"_E"#fvti.SEW#"_MASK") - (fvti.Vector (IMPLICIT_DEF)), fwti.RegClass:$rs1, - (fwti.Mask VMV0:$vm), - // Value to indicate no rounding mode change in - // RISCVInsertReadWriteCSR - FRM_DYN, - GPR:$vl, fvti.Log2SEW, TA_MA)>; - let Predicates = [HasVInstructionsBF16Minimal] in + def : Pat<(fvti.Vector (any_riscv_fpround_vl + (fwti.Vector fwti.RegClass:$rs1), + (fwti.Mask VMV0:$vm), VLOpFrag)), + (!cast<Instruction>("PseudoVFNCVTBF16_F_F_W_"#fvti.LMul.MX#"_E"#fvti.SEW#"_MASK") + (fvti.Vector (IMPLICIT_DEF)), fwti.RegClass:$rs1, + (fwti.Mask VMV0:$vm), + // Value to indicate no rounding mode change in + // RISCVInsertReadWriteCSR + FRM_DYN, + GPR:$vl, fvti.Log2SEW, TA_MA)>; def : Pat<(fvti.Vector (fpround (fwti.Vector fwti.RegClass:$rs1))), (!cast<Instruction>("PseudoVFNCVTBF16_F_F_W_"#fvti.LMul.MX#"_E"#fvti.SEW) (fvti.Vector (IMPLICIT_DEF)), @@ -87,6 +438,130 @@ let Predicates = [HasStdExtZvfbfminOrZvfofp8min] in { FRM_DYN, fvti.AVL, fvti.Log2SEW, TA_MA)>; } + + defm : VPatUnaryV_V_AnyMask<"int_riscv_vcompress", 
"PseudoVCOMPRESS", AllBF16Vectors>; + defm : VPatBinaryV_VV_VX_VI_INT<"int_riscv_vrgather", "PseudoVRGATHER", + AllBF16Vectors, uimm5>; + defm : VPatBinaryV_VV_INT_EEW<"int_riscv_vrgatherei16_vv", "PseudoVRGATHEREI16", + eew=16, vtilist=AllBF16Vectors>; + defm : VPatTernaryV_VX_VI<"int_riscv_vslideup", "PseudoVSLIDEUP", AllBF16Vectors, uimm5>; + defm : VPatTernaryV_VX_VI<"int_riscv_vslidedown", "PseudoVSLIDEDOWN", AllBF16Vectors, uimm5>; + + foreach fvti = AllBF16Vectors in { + defm : VPatBinaryCarryInTAIL<"int_riscv_vmerge", "PseudoVMERGE", "VVM", + fvti.Vector, + fvti.Vector, fvti.Vector, fvti.Mask, + fvti.Log2SEW, fvti.LMul, fvti.RegClass, + fvti.RegClass, fvti.RegClass>; + defm : VPatBinaryCarryInTAIL<"int_riscv_vfmerge", "PseudoVFMERGE", + "V"#fvti.ScalarSuffix#"M", + fvti.Vector, + fvti.Vector, fvti.Scalar, fvti.Mask, + fvti.Log2SEW, fvti.LMul, fvti.RegClass, + fvti.RegClass, fvti.ScalarRegClass>; + defvar instr = !cast<Instruction>("PseudoVMERGE_VIM_"#fvti.LMul.MX); + def : Pat<(fvti.Vector (int_riscv_vfmerge (fvti.Vector fvti.RegClass:$passthru), + (fvti.Vector fvti.RegClass:$rs2), + (fvti.Scalar (fpimm0)), + (fvti.Mask VMV0:$vm), VLOpFrag)), + (instr fvti.RegClass:$passthru, fvti.RegClass:$rs2, 0, + (fvti.Mask VMV0:$vm), GPR:$vl, fvti.Log2SEW)>; + + defvar ivti = GetIntVTypeInfo<fvti>.Vti; + def : Pat<(fvti.Vector (vselect (fvti.Mask VMV0:$vm), fvti.RegClass:$rs1, + fvti.RegClass:$rs2)), + (!cast<Instruction>("PseudoVMERGE_VVM_"#fvti.LMul.MX) + (fvti.Vector (IMPLICIT_DEF)), + fvti.RegClass:$rs2, fvti.RegClass:$rs1, (fvti.Mask VMV0:$vm), + fvti.AVL, fvti.Log2SEW)>; + + def : Pat<(fvti.Vector (vselect (fvti.Mask VMV0:$vm), + (SplatFPOp (SelectScalarFPAsInt (XLenVT GPR:$imm))), + fvti.RegClass:$rs2)), + (!cast<Instruction>("PseudoVMERGE_VXM_"#fvti.LMul.MX) + (fvti.Vector (IMPLICIT_DEF)), + fvti.RegClass:$rs2, GPR:$imm, (fvti.Mask VMV0:$vm), fvti.AVL, fvti.Log2SEW)>; + + def : Pat<(fvti.Vector (vselect (fvti.Mask VMV0:$vm), + (SplatFPOp (fvti.Scalar fpimm0)), + fvti.RegClass:$rs2)), + (!cast<Instruction>("PseudoVMERGE_VIM_"#fvti.LMul.MX) + (fvti.Vector (IMPLICIT_DEF)), + fvti.RegClass:$rs2, 0, (fvti.Mask VMV0:$vm), fvti.AVL, fvti.Log2SEW)>; + + def : Pat<(fvti.Vector (vselect (fvti.Mask VMV0:$vm), + (SplatFPOp fvti.ScalarRegClass:$rs1), + fvti.RegClass:$rs2)), + (!cast<Instruction>("PseudoVFMERGE_V"#fvti.ScalarSuffix#"M_"#fvti.LMul.MX) + (fvti.Vector (IMPLICIT_DEF)), + fvti.RegClass:$rs2, + (fvti.Scalar fvti.ScalarRegClass:$rs1), + (fvti.Mask VMV0:$vm), fvti.AVL, fvti.Log2SEW)>; + + def : Pat<(fvti.Vector (riscv_vmerge_vl (fvti.Mask VMV0:$vm), + fvti.RegClass:$rs1, + fvti.RegClass:$rs2, + fvti.RegClass:$passthru, + VLOpFrag)), + (!cast<Instruction>("PseudoVMERGE_VVM_"#fvti.LMul.MX) + fvti.RegClass:$passthru, fvti.RegClass:$rs2, fvti.RegClass:$rs1, (fvti.Mask VMV0:$vm), + GPR:$vl, fvti.Log2SEW)>; + + def : Pat<(fvti.Vector (riscv_vmerge_vl (fvti.Mask VMV0:$vm), + (SplatFPOp (SelectScalarFPAsInt (XLenVT GPR:$imm))), + fvti.RegClass:$rs2, + fvti.RegClass:$passthru, + VLOpFrag)), + (!cast<Instruction>("PseudoVMERGE_VXM_"#fvti.LMul.MX) + fvti.RegClass:$passthru, fvti.RegClass:$rs2, GPR:$imm, (fvti.Mask VMV0:$vm), + GPR:$vl, fvti.Log2SEW)>; + + + def : Pat<(fvti.Vector (riscv_vmerge_vl (fvti.Mask VMV0:$vm), + (SplatFPOp (fvti.Scalar fpimm0)), + fvti.RegClass:$rs2, + fvti.RegClass:$passthru, + VLOpFrag)), + (!cast<Instruction>("PseudoVMERGE_VIM_"#fvti.LMul.MX) + fvti.RegClass:$passthru, fvti.RegClass:$rs2, 0, (fvti.Mask VMV0:$vm), + GPR:$vl, fvti.Log2SEW)>; + + def : Pat<(fvti.Vector 
(riscv_vmerge_vl (fvti.Mask VMV0:$vm), + (SplatFPOp fvti.ScalarRegClass:$rs1), + fvti.RegClass:$rs2, + fvti.RegClass:$passthru, + VLOpFrag)), + (!cast<Instruction>("PseudoVFMERGE_V"#fvti.ScalarSuffix#"M_"#fvti.LMul.MX) + fvti.RegClass:$passthru, fvti.RegClass:$rs2, + (fvti.Scalar fvti.ScalarRegClass:$rs1), + (fvti.Mask VMV0:$vm), GPR:$vl, fvti.Log2SEW)>; + + def : Pat<(fvti.Vector + (riscv_vrgather_vv_vl fvti.RegClass:$rs2, + (ivti.Vector fvti.RegClass:$rs1), + fvti.RegClass:$passthru, + (fvti.Mask VMV0:$vm), + VLOpFrag)), + (!cast<Instruction>("PseudoVRGATHER_VV_"# fvti.LMul.MX#"_E"# fvti.SEW#"_MASK") + fvti.RegClass:$passthru, fvti.RegClass:$rs2, fvti.RegClass:$rs1, + (fvti.Mask VMV0:$vm), GPR:$vl, fvti.Log2SEW, TAIL_AGNOSTIC)>; + def : Pat<(fvti.Vector (riscv_vrgather_vx_vl fvti.RegClass:$rs2, GPR:$rs1, + fvti.RegClass:$passthru, + (fvti.Mask VMV0:$vm), + VLOpFrag)), + (!cast<Instruction>("PseudoVRGATHER_VX_"# fvti.LMul.MX#"_MASK") + fvti.RegClass:$passthru, fvti.RegClass:$rs2, GPR:$rs1, + (fvti.Mask VMV0:$vm), GPR:$vl, fvti.Log2SEW, TAIL_AGNOSTIC)>; + def : Pat<(fvti.Vector + (riscv_vrgather_vx_vl fvti.RegClass:$rs2, + uimm5:$imm, + fvti.RegClass:$passthru, + (fvti.Mask VMV0:$vm), + VLOpFrag)), + (!cast<Instruction>("PseudoVRGATHER_VI_"# fvti.LMul.MX#"_MASK") + fvti.RegClass:$passthru, fvti.RegClass:$rs2, uimm5:$imm, + (fvti.Mask VMV0:$vm), GPR:$vl, fvti.Log2SEW, TAIL_AGNOSTIC)>; + } } let Predicates = [HasStdExtZvfbfwma] in { @@ -97,3 +572,224 @@ let Predicates = [HasStdExtZvfbfwma] in { defm : VPatWidenFPMulAccSDNode_VV_VF_RM<"PseudoVFWMACCBF16", AllWidenableBF16ToFloatVectors>; } + +multiclass VPatConversionVI_VF_BF16<string intrinsic, string instruction> { + foreach fvti = AllBF16Vectors in { + defvar ivti = GetIntVTypeInfo<fvti>.Vti; + let Predicates = !listconcat(GetVTypePredicates<fvti>.Predicates, + GetVTypePredicates<ivti>.Predicates) in + defm : VPatConversion<intrinsic, instruction, "V", + ivti.Vector, fvti.Vector, ivti.Mask, fvti.Log2SEW, + fvti.LMul, ivti.RegClass, fvti.RegClass>; + } +} + +multiclass VPatConversionWF_VI_BF16<string intrinsic, string instruction, + bit isSEWAware = 0> { + foreach vtiToWti = AllWidenableIntToBF16Vectors in { + defvar vti = vtiToWti.Vti; + defvar fwti = vtiToWti.Wti; + let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates, + GetVTypePredicates<fwti>.Predicates) in + defm : VPatConversion<intrinsic, instruction, "V", + fwti.Vector, vti.Vector, fwti.Mask, vti.Log2SEW, + vti.LMul, fwti.RegClass, vti.RegClass, isSEWAware>; + } +} + +multiclass VPatConversionWF_VF_BF16<string intrinsic, string instruction, + bit isSEWAware = 0> { + foreach fvtiToFWti = AllWidenableBF16ToFloatVectors in { + defvar fvti = fvtiToFWti.Vti; + defvar fwti = fvtiToFWti.Wti; + let Predicates = !listconcat(GetVTypeMinimalPredicates<fvti>.Predicates, + GetVTypeMinimalPredicates<fwti>.Predicates) in + defm : VPatConversion<intrinsic, instruction, "V", + fwti.Vector, fvti.Vector, fwti.Mask, fvti.Log2SEW, + fvti.LMul, fwti.RegClass, fvti.RegClass, isSEWAware>; + } +} + +multiclass VPatConversionVI_WF_BF16<string intrinsic, string instruction> { + foreach vtiToWti = AllWidenableIntToBF16Vectors in { + defvar vti = vtiToWti.Vti; + defvar fwti = vtiToWti.Wti; + let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates, + GetVTypePredicates<fwti>.Predicates) in + defm : VPatConversion<intrinsic, instruction, "W", + vti.Vector, fwti.Vector, vti.Mask, vti.Log2SEW, + vti.LMul, vti.RegClass, fwti.RegClass>; + } +} + +multiclass VPatConversionVI_WF_RM_BF16<string 
intrinsic, string instruction> { + foreach vtiToWti = AllWidenableIntToBF16Vectors in { + defvar vti = vtiToWti.Vti; + defvar fwti = vtiToWti.Wti; + let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates, + GetVTypePredicates<fwti>.Predicates) in + defm : VPatConversionRoundingMode<intrinsic, instruction, "W", + vti.Vector, fwti.Vector, vti.Mask, vti.Log2SEW, + vti.LMul, vti.RegClass, fwti.RegClass>; + } +} + +multiclass VPatConversionVF_WF_BF16<string intrinsic, string instruction, + bit isSEWAware = 0> { + foreach fvtiToFWti = AllWidenableBF16ToFloatVectors in { + defvar fvti = fvtiToFWti.Vti; + defvar fwti = fvtiToFWti.Wti; + let Predicates = !listconcat(GetVTypePredicates<fvti>.Predicates, + GetVTypePredicates<fwti>.Predicates) in + defm : VPatConversion<intrinsic, instruction, "W", + fvti.Vector, fwti.Vector, fvti.Mask, fvti.Log2SEW, + fvti.LMul, fvti.RegClass, fwti.RegClass, isSEWAware>; + } +} + +let Predicates = [HasStdExtZvfbfa] in { +defm : VPatBinaryV_VV_VX_RM<"int_riscv_vfadd", "PseudoVFADD_ALT", + AllBF16Vectors, isSEWAware = 1>; +defm : VPatBinaryV_VV_VX_RM<"int_riscv_vfsub", "PseudoVFSUB_ALT", + AllBF16Vectors, isSEWAware = 1>; +defm : VPatBinaryV_VX_RM<"int_riscv_vfrsub", "PseudoVFRSUB_ALT", + AllBF16Vectors, isSEWAware = 1>; +defm : VPatBinaryW_VV_VX_RM<"int_riscv_vfwadd", "PseudoVFWADD_ALT", + AllWidenableBF16ToFloatVectors, isSEWAware=1>; +defm : VPatBinaryW_VV_VX_RM<"int_riscv_vfwsub", "PseudoVFWSUB_ALT", + AllWidenableBF16ToFloatVectors, isSEWAware=1>; +defm : VPatBinaryW_WV_WX_RM<"int_riscv_vfwadd_w", "PseudoVFWADD_ALT", + AllWidenableBF16ToFloatVectors, isSEWAware=1>; +defm : VPatBinaryW_WV_WX_RM<"int_riscv_vfwsub_w", "PseudoVFWSUB_ALT", + AllWidenableBF16ToFloatVectors, isSEWAware=1>; +defm : VPatBinaryV_VV_VX_RM<"int_riscv_vfmul", "PseudoVFMUL_ALT", + AllBF16Vectors, isSEWAware=1>; +defm : VPatBinaryW_VV_VX_RM<"int_riscv_vfwmul", "PseudoVFWMUL_ALT", + AllWidenableBF16ToFloatVectors, isSEWAware=1>; +defm : VPatTernaryV_VV_VX_AAXA_RM<"int_riscv_vfmacc", "PseudoVFMACC_ALT", + AllBF16Vectors, isSEWAware=1>; +defm : VPatTernaryV_VV_VX_AAXA_RM<"int_riscv_vfnmacc", "PseudoVFNMACC_ALT", + AllBF16Vectors, isSEWAware=1>; +defm : VPatTernaryV_VV_VX_AAXA_RM<"int_riscv_vfmsac", "PseudoVFMSAC_ALT", + AllBF16Vectors, isSEWAware=1>; +defm : VPatTernaryV_VV_VX_AAXA_RM<"int_riscv_vfnmsac", "PseudoVFNMSAC_ALT", + AllBF16Vectors, isSEWAware=1>; +defm : VPatTernaryV_VV_VX_AAXA_RM<"int_riscv_vfmadd", "PseudoVFMADD_ALT", + AllBF16Vectors, isSEWAware=1>; +defm : VPatTernaryV_VV_VX_AAXA_RM<"int_riscv_vfnmadd", "PseudoVFNMADD_ALT", + AllBF16Vectors, isSEWAware=1>; +defm : VPatTernaryV_VV_VX_AAXA_RM<"int_riscv_vfmsub", "PseudoVFMSUB_ALT", + AllBF16Vectors, isSEWAware=1>; +defm : VPatTernaryV_VV_VX_AAXA_RM<"int_riscv_vfnmsub", "PseudoVFNMSUB_ALT", + AllBF16Vectors, isSEWAware=1>; +defm : VPatTernaryW_VV_VX_RM<"int_riscv_vfwmacc", "PseudoVFWMACC_ALT", + AllWidenableBF16ToFloatVectors, isSEWAware=1>; +defm : VPatTernaryW_VV_VX_RM<"int_riscv_vfwnmacc", "PseudoVFWNMACC_ALT", + AllWidenableBF16ToFloatVectors, isSEWAware=1>; +defm : VPatTernaryW_VV_VX_RM<"int_riscv_vfwmsac", "PseudoVFWMSAC_ALT", + AllWidenableBF16ToFloatVectors, isSEWAware=1>; +defm : VPatTernaryW_VV_VX_RM<"int_riscv_vfwnmsac", "PseudoVFWNMSAC_ALT", + AllWidenableBF16ToFloatVectors, isSEWAware=1>; +defm : VPatUnaryV_V<"int_riscv_vfrsqrt7", "PseudoVFRSQRT7_ALT", + AllBF16Vectors, isSEWAware=1>; +defm : VPatUnaryV_V_RM<"int_riscv_vfrec7", "PseudoVFREC7_ALT", + AllBF16Vectors, isSEWAware=1>; +defm : 
VPatBinaryV_VV_VX<"int_riscv_vfmin", "PseudoVFMIN_ALT", + AllBF16Vectors, isSEWAware=1>; +defm : VPatBinaryV_VV_VX<"int_riscv_vfmax", "PseudoVFMAX_ALT", + AllBF16Vectors, isSEWAware=1>; +defm : VPatBinaryV_VV_VX<"int_riscv_vfsgnj", "PseudoVFSGNJ_ALT", + AllBF16Vectors, isSEWAware=1>; +defm : VPatBinaryV_VV_VX<"int_riscv_vfsgnjn", "PseudoVFSGNJN_ALT", + AllBF16Vectors, isSEWAware=1>; +defm : VPatBinaryV_VV_VX<"int_riscv_vfsgnjx", "PseudoVFSGNJX_ALT", + AllBF16Vectors, isSEWAware=1>; +defm : VPatBinaryM_VV_VX<"int_riscv_vmfeq", "PseudoVMFEQ_ALT", AllBF16Vectors>; +defm : VPatBinaryM_VV_VX<"int_riscv_vmfle", "PseudoVMFLE_ALT", AllBF16Vectors>; +defm : VPatBinaryM_VV_VX<"int_riscv_vmflt", "PseudoVMFLT_ALT", AllBF16Vectors>; +defm : VPatBinaryM_VV_VX<"int_riscv_vmfne", "PseudoVMFNE_ALT", AllBF16Vectors>; +defm : VPatBinaryM_VX<"int_riscv_vmfgt", "PseudoVMFGT_ALT", AllBF16Vectors>; +defm : VPatBinaryM_VX<"int_riscv_vmfge", "PseudoVMFGE_ALT", AllBF16Vectors>; +defm : VPatBinarySwappedM_VV<"int_riscv_vmfgt", "PseudoVMFLT_ALT", AllBF16Vectors>; +defm : VPatBinarySwappedM_VV<"int_riscv_vmfge", "PseudoVMFLE_ALT", AllBF16Vectors>; +defm : VPatConversionVI_VF_BF16<"int_riscv_vfclass", "PseudoVFCLASS_ALT">; +foreach vti = AllBF16Vectors in { + let Predicates = GetVTypePredicates<vti>.Predicates in + defm : VPatBinaryCarryInTAIL<"int_riscv_vfmerge", "PseudoVFMERGE_ALT", + "V"#vti.ScalarSuffix#"M", + vti.Vector, + vti.Vector, vti.Scalar, vti.Mask, + vti.Log2SEW, vti.LMul, vti.RegClass, + vti.RegClass, vti.ScalarRegClass>; +} +defm : VPatConversionWF_VI_BF16<"int_riscv_vfwcvt_f_xu_v", "PseudoVFWCVT_F_XU_ALT", + isSEWAware=1>; +defm : VPatConversionWF_VI_BF16<"int_riscv_vfwcvt_f_x_v", "PseudoVFWCVT_F_X_ALT", + isSEWAware=1>; +defm : VPatConversionWF_VF_BF16<"int_riscv_vfwcvt_f_f_v", "PseudoVFWCVT_F_F_ALT", + isSEWAware=1>; +defm : VPatConversionVI_WF_RM_BF16<"int_riscv_vfncvt_xu_f_w", "PseudoVFNCVT_XU_F_ALT">; +defm : VPatConversionVI_WF_RM_BF16<"int_riscv_vfncvt_x_f_w", "PseudoVFNCVT_X_F_ALT">; +defm : VPatConversionVI_WF_BF16<"int_riscv_vfncvt_rtz_xu_f_w", "PseudoVFNCVT_RTZ_XU_F_ALT">; +defm : VPatConversionVI_WF_BF16<"int_riscv_vfncvt_rtz_x_f_w", "PseudoVFNCVT_RTZ_X_F_ALT">; +defm : VPatConversionVF_WF_RM<"int_riscv_vfncvt_f_f_w", "PseudoVFNCVT_F_F_ALT", + AllWidenableBF16ToFloatVectors, isSEWAware=1>; +defm : VPatConversionVF_WF_BF16<"int_riscv_vfncvt_rod_f_f_w", "PseudoVFNCVT_ROD_F_F_ALT", + isSEWAware=1>; +defm : VPatBinaryV_VX<"int_riscv_vfslide1up", "PseudoVFSLIDE1UP_ALT", AllBF16Vectors>; +defm : VPatBinaryV_VX<"int_riscv_vfslide1down", "PseudoVFSLIDE1DOWN_ALT", AllBF16Vectors>; + +foreach fvti = AllBF16Vectors in { + defvar ivti = GetIntVTypeInfo<fvti>.Vti; + let Predicates = GetVTypePredicates<ivti>.Predicates in { + // 13.16. Vector Floating-Point Move Instruction + // If we're splatting fpimm0, use vmv.v.x vd, x0. 
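The comment that ends the line above is the key to the splat patterns that follow: +0.0 in bf16, as in every IEEE format, is the all-zero bit pattern, so a floating-point splat of it can be selected as an integer splat of zero (vmv.v.i / vmv.v.x with x0), and other constants as a vmv.v.x of their raw bits via SelectScalarFPAsInt, without touching any FP state. A quick standalone check of that encoding fact:

// bf16 is the high half of binary32, and +0.0 encodes as all zero bits, so
// splatting bf16 +0.0 is bit-identical to splatting the integer 0.
#include <cassert>
#include <cstdint>
#include <cstring>

uint16_t f32_to_bf16_trunc(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}

int main() {
  assert(f32_to_bf16_trunc(0.0f) == 0x0000);   // +0.0: all zeros, vmv.v.i 0 works
  assert(f32_to_bf16_trunc(-0.0f) == 0x8000);  // -0.0 is not, and fpimm0 does not match it
}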
+ def : Pat<(fvti.Vector (riscv_vfmv_v_f_vl + fvti.Vector:$passthru, (fvti.Scalar (fpimm0)), VLOpFrag)), + (!cast<Instruction>("PseudoVMV_V_I_"#fvti.LMul.MX) + $passthru, 0, GPR:$vl, fvti.Log2SEW, TU_MU)>; + def : Pat<(fvti.Vector (riscv_vfmv_v_f_vl + fvti.Vector:$passthru, (fvti.Scalar (SelectScalarFPAsInt (XLenVT GPR:$imm))), VLOpFrag)), + (!cast<Instruction>("PseudoVMV_V_X_"#fvti.LMul.MX) + $passthru, GPR:$imm, GPR:$vl, fvti.Log2SEW, TU_MU)>; + } + + let Predicates = GetVTypePredicates<fvti>.Predicates in { + def : Pat<(fvti.Vector (riscv_vfmv_v_f_vl + fvti.Vector:$passthru, (fvti.Scalar fvti.ScalarRegClass:$rs2), VLOpFrag)), + (!cast<Instruction>("PseudoVFMV_V_ALT_" # fvti.ScalarSuffix # "_" # + fvti.LMul.MX) + $passthru, (fvti.Scalar fvti.ScalarRegClass:$rs2), + GPR:$vl, fvti.Log2SEW, TU_MU)>; + } +} + +foreach vti = NoGroupBF16Vectors in { + let Predicates = GetVTypePredicates<vti>.Predicates in { + def : Pat<(vti.Vector (riscv_vfmv_s_f_vl (vti.Vector vti.RegClass:$passthru), + (vti.Scalar (fpimm0)), + VLOpFrag)), + (PseudoVMV_S_X $passthru, (XLenVT X0), GPR:$vl, vti.Log2SEW)>; + def : Pat<(vti.Vector (riscv_vfmv_s_f_vl (vti.Vector vti.RegClass:$passthru), + (vti.Scalar (SelectScalarFPAsInt (XLenVT GPR:$imm))), + VLOpFrag)), + (PseudoVMV_S_X $passthru, GPR:$imm, GPR:$vl, vti.Log2SEW)>; + def : Pat<(vti.Vector (riscv_vfmv_s_f_vl (vti.Vector vti.RegClass:$passthru), + vti.ScalarRegClass:$rs1, + VLOpFrag)), + (!cast<Instruction>("PseudoVFMV_S_"#vti.ScalarSuffix#"_ALT") + vti.RegClass:$passthru, + (vti.Scalar vti.ScalarRegClass:$rs1), GPR:$vl, vti.Log2SEW)>; + } + + defvar vfmv_f_s_inst = !cast<Instruction>(!strconcat("PseudoVFMV_", + vti.ScalarSuffix, + "_S_ALT")); + // Only pattern-match extract-element operations where the index is 0. Any + // other index will have been custom-lowered to slide the vector correctly + // into place. + let Predicates = GetVTypePredicates<vti>.Predicates in + def : Pat<(vti.Scalar (extractelt (vti.Vector vti.RegClass:$rs2), 0)), + (vfmv_f_s_inst vti.RegClass:$rs2, vti.Log2SEW)>; +} +} // Predicates = [HasStdExtZvfbfa] diff --git a/llvm/lib/Target/RISCV/RISCVSubtarget.h b/llvm/lib/Target/RISCV/RISCVSubtarget.h index 6acf799..334db4b 100644 --- a/llvm/lib/Target/RISCV/RISCVSubtarget.h +++ b/llvm/lib/Target/RISCV/RISCVSubtarget.h @@ -288,9 +288,12 @@ public: bool hasVInstructionsI64() const { return HasStdExtZve64x; } bool hasVInstructionsF16Minimal() const { return HasStdExtZvfhmin; } bool hasVInstructionsF16() const { return HasStdExtZvfh; } - bool hasVInstructionsBF16Minimal() const { return HasStdExtZvfbfmin; } + bool hasVInstructionsBF16Minimal() const { + return HasStdExtZvfbfmin || HasStdExtZvfbfa; + } bool hasVInstructionsF32() const { return HasStdExtZve32f; } bool hasVInstructionsF64() const { return HasStdExtZve64d; } + bool hasVInstructionsBF16() const { return HasStdExtZvfbfa; } // F16 and F64 both require F32. 
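The extract-element pattern above only needs to handle index 0 because, as its comment says, every other index has already been custom-lowered into a slide; a standalone sketch of why that is enough (vslidedown moves element i of the source into lane 0, where vfmv.f.s can read it):

// extract(v, idx) == lane 0 of slidedown(v, idx), so one lane-0 read pattern
// plus the generic slide lowering covers every index.
#include <cassert>
#include <cstddef>
#include <vector>

std::vector<float> slidedown(const std::vector<float> &v, size_t offset) {
  std::vector<float> out(v.size(), 0.0f);
  for (size_t i = 0; i + offset < v.size(); ++i)
    out[i] = v[i + offset];  // vd[i] = vs2[i + offset]
  return out;
}

int main() {
  std::vector<float> v = {1.5f, 2.5f, 3.5f, 4.5f};
  for (size_t idx = 0; idx < v.size(); ++idx)
    assert(slidedown(v, idx)[0] == v[idx]);
}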
bool hasVInstructionsAnyF() const { return hasVInstructionsF32(); } bool hasVInstructionsFullMultiply() const { return HasStdExtV; } diff --git a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp index 56a6168..640b014 100644 --- a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp @@ -78,6 +78,8 @@ public: void outputExecutionModeFromNumthreadsAttribute( const MCRegister &Reg, const Attribute &Attr, SPIRV::ExecutionMode::ExecutionMode EM); + void outputExecutionModeFromEnableMaximalReconvergenceAttr( + const MCRegister &Reg, const SPIRVSubtarget &ST); void outputExecutionMode(const Module &M); void outputAnnotations(const Module &M); void outputModuleSections(); @@ -495,6 +497,20 @@ void SPIRVAsmPrinter::outputExecutionModeFromNumthreadsAttribute( outputMCInst(Inst); } +void SPIRVAsmPrinter::outputExecutionModeFromEnableMaximalReconvergenceAttr( + const MCRegister &Reg, const SPIRVSubtarget &ST) { + assert(ST.canUseExtension(SPIRV::Extension::SPV_KHR_maximal_reconvergence) && + "Function called when SPV_KHR_maximal_reconvergence is not enabled."); + + MCInst Inst; + Inst.setOpcode(SPIRV::OpExecutionMode); + Inst.addOperand(MCOperand::createReg(Reg)); + unsigned EM = + static_cast<unsigned>(SPIRV::ExecutionMode::MaximallyReconvergesKHR); + Inst.addOperand(MCOperand::createImm(EM)); + outputMCInst(Inst); +} + void SPIRVAsmPrinter::outputExecutionMode(const Module &M) { NamedMDNode *Node = M.getNamedMetadata("spirv.ExecutionMode"); if (Node) { @@ -551,6 +567,10 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) { if (Attribute Attr = F.getFnAttribute("hlsl.numthreads"); Attr.isValid()) outputExecutionModeFromNumthreadsAttribute( FReg, Attr, SPIRV::ExecutionMode::LocalSize); + if (Attribute Attr = F.getFnAttribute("enable-maximal-reconvergence"); + Attr.getValueAsBool()) { + outputExecutionModeFromEnableMaximalReconvergenceAttr(FReg, *ST); + } if (MDNode *Node = F.getMetadata("work_group_size_hint")) outputExecutionModeFromMDNode(FReg, Node, SPIRV::ExecutionMode::LocalSizeHint, 3, 1); diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp index 5f3ed86..96f5dee 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp @@ -153,7 +153,9 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>> SPIRV::Extension::Extension:: SPV_EXT_relaxed_printf_string_address_space}, {"SPV_INTEL_predicated_io", - SPIRV::Extension::Extension::SPV_INTEL_predicated_io}}; + SPIRV::Extension::Extension::SPV_INTEL_predicated_io}, + {"SPV_KHR_maximal_reconvergence", + SPIRV::Extension::Extension::SPV_KHR_maximal_reconvergence}}; bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName, StringRef ArgValue, diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp index c6c6182..a151fd2 100644 --- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp @@ -1392,19 +1392,19 @@ void SPIRVEmitIntrinsics::preprocessCompositeConstants(IRBuilder<> &B) { Constant *AggrConst = nullptr; Type *ResTy = nullptr; if (auto *COp = dyn_cast<ConstantVector>(Op)) { - AggrConst = cast<Constant>(COp); + AggrConst = COp; ResTy = COp->getType(); } else if (auto *COp = dyn_cast<ConstantArray>(Op)) { - AggrConst = cast<Constant>(COp); + AggrConst = COp; ResTy = B.getInt32Ty(); } else if (auto *COp = dyn_cast<ConstantStruct>(Op)) { - AggrConst = 
cast<Constant>(COp); + AggrConst = COp; ResTy = B.getInt32Ty(); } else if (auto *COp = dyn_cast<ConstantDataArray>(Op)) { - AggrConst = cast<Constant>(COp); + AggrConst = COp; ResTy = B.getInt32Ty(); } else if (auto *COp = dyn_cast<ConstantAggregateZero>(Op)) { - AggrConst = cast<Constant>(COp); + AggrConst = COp; ResTy = Op->getType()->isVectorTy() ? COp->getType() : B.getInt32Ty(); } if (AggrConst) { diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp index 5144fb1..61a0bbe 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -1200,6 +1200,23 @@ void addOpAccessChainReqs(const MachineInstr &Instr, return; } + bool IsNonUniform = + hasNonUniformDecoration(Instr.getOperand(0).getReg(), MRI); + + auto FirstIndexReg = Instr.getOperand(3).getReg(); + bool FirstIndexIsConstant = + Subtarget.getInstrInfo()->isConstantInstr(*MRI.getVRegDef(FirstIndexReg)); + + if (StorageClass == SPIRV::StorageClass::StorageClass::StorageBuffer) { + if (IsNonUniform) + Handler.addRequirements( + SPIRV::Capability::StorageBufferArrayNonUniformIndexingEXT); + else if (!FirstIndexIsConstant) + Handler.addRequirements( + SPIRV::Capability::StorageBufferArrayDynamicIndexing); + return; + } + Register PointeeTypeReg = ResTypeInst->getOperand(2).getReg(); MachineInstr *PointeeType = MRI.getUniqueVRegDef(PointeeTypeReg); if (PointeeType->getOpcode() != SPIRV::OpTypeImage && @@ -1208,27 +1225,25 @@ void addOpAccessChainReqs(const MachineInstr &Instr, return; } - bool IsNonUniform = - hasNonUniformDecoration(Instr.getOperand(0).getReg(), MRI); if (isUniformTexelBuffer(PointeeType)) { if (IsNonUniform) Handler.addRequirements( SPIRV::Capability::UniformTexelBufferArrayNonUniformIndexingEXT); - else + else if (!FirstIndexIsConstant) Handler.addRequirements( SPIRV::Capability::UniformTexelBufferArrayDynamicIndexingEXT); } else if (isInputAttachment(PointeeType)) { if (IsNonUniform) Handler.addRequirements( SPIRV::Capability::InputAttachmentArrayNonUniformIndexingEXT); - else + else if (!FirstIndexIsConstant) Handler.addRequirements( SPIRV::Capability::InputAttachmentArrayDynamicIndexingEXT); } else if (isStorageTexelBuffer(PointeeType)) { if (IsNonUniform) Handler.addRequirements( SPIRV::Capability::StorageTexelBufferArrayNonUniformIndexingEXT); - else + else if (!FirstIndexIsConstant) Handler.addRequirements( SPIRV::Capability::StorageTexelBufferArrayDynamicIndexingEXT); } else if (isSampledImage(PointeeType) || @@ -1237,14 +1252,14 @@ void addOpAccessChainReqs(const MachineInstr &Instr, if (IsNonUniform) Handler.addRequirements( SPIRV::Capability::SampledImageArrayNonUniformIndexingEXT); - else + else if (!FirstIndexIsConstant) Handler.addRequirements( SPIRV::Capability::SampledImageArrayDynamicIndexing); } else if (isStorageImage(PointeeType)) { if (IsNonUniform) Handler.addRequirements( SPIRV::Capability::StorageImageArrayNonUniformIndexingEXT); - else + else if (!FirstIndexIsConstant) Handler.addRequirements( SPIRV::Capability::StorageImageArrayDynamicIndexing); } @@ -2155,6 +2170,9 @@ static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI, SPIRV::OperandCategory::ExecutionModeOperand, SPIRV::ExecutionMode::LocalSize, ST); } + if (F.getFnAttribute("enable-maximal-reconvergence").getValueAsBool()) { + MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_maximal_reconvergence); + } if (F.getMetadata("work_group_size_hint")) MAI.Reqs.getAndAddRequirements( 
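The addOpAccessChainReqs changes above reduce to a three-way decision per resource class: a non-uniform index requires the matching *NonUniformIndexing capability, a merely dynamic (non-constant) first index now only requires the *DynamicIndexing capability, and a constant uniform index requires nothing extra, with StorageBuffer handled up front. A compressed sketch of that table using the capability names from the hunks; the Resource enum and helper below are stand-ins for the pass's isUniformTexelBuffer/isStorageImage/... classification, not real LLVM API:

// Decision sketch: which indexing capability an OpAccessChain pulls in.
#include <cassert>
#include <optional>
#include <string>

enum class Resource { StorageBuffer, UniformTexelBuffer, StorageTexelBuffer,
                      InputAttachment, SampledImage, StorageImage };

std::optional<std::string> requiredCapability(Resource R, bool IsNonUniform,
                                              bool FirstIndexIsConstant) {
  auto pick = [&](const char *NonUniform, const char *Dynamic)
      -> std::optional<std::string> {
    if (IsNonUniform)
      return NonUniform;
    if (!FirstIndexIsConstant)
      return Dynamic;
    return std::nullopt;  // constant uniform index: no extra capability
  };
  switch (R) {
  case Resource::StorageBuffer:
    return pick("StorageBufferArrayNonUniformIndexingEXT",
                "StorageBufferArrayDynamicIndexing");
  case Resource::UniformTexelBuffer:
    return pick("UniformTexelBufferArrayNonUniformIndexingEXT",
                "UniformTexelBufferArrayDynamicIndexingEXT");
  case Resource::StorageTexelBuffer:
    return pick("StorageTexelBufferArrayNonUniformIndexingEXT",
                "StorageTexelBufferArrayDynamicIndexingEXT");
  case Resource::InputAttachment:
    return pick("InputAttachmentArrayNonUniformIndexingEXT",
                "InputAttachmentArrayDynamicIndexingEXT");
  case Resource::SampledImage:
    return pick("SampledImageArrayNonUniformIndexingEXT",
                "SampledImageArrayDynamicIndexing");
  case Resource::StorageImage:
    return pick("StorageImageArrayNonUniformIndexingEXT",
                "StorageImageArrayDynamicIndexing");
  }
  return std::nullopt;
}

int main() {
  assert(!requiredCapability(Resource::StorageBuffer, false, true));
  assert(*requiredCapability(Resource::StorageBuffer, false, false) ==
         "StorageBufferArrayDynamicIndexing");
}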
SPIRV::OperandCategory::ExecutionModeOperand, diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td index 2625642..7d08b29 100644 --- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td +++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td @@ -386,6 +386,7 @@ defm SPV_KHR_float_controls2 : ExtensionOperand<124, [EnvVulkan, EnvOpenCL]>; defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125, [EnvOpenCL]>; defm SPV_KHR_bfloat16 : ExtensionOperand<126, [EnvVulkan, EnvOpenCL]>; defm SPV_INTEL_predicated_io : ExtensionOperand<127, [EnvOpenCL]>; +defm SPV_KHR_maximal_reconvergence : ExtensionOperand<128, [EnvVulkan]>; //===----------------------------------------------------------------------===// // Multiclass used to define Capabilities enum values and at the same time @@ -698,7 +699,7 @@ defm IntersectionNV: ExecutionModelOperand<5314, [RayTracingNV]>; defm AnyHitNV: ExecutionModelOperand<5315, [RayTracingNV]>; defm ClosestHitNV: ExecutionModelOperand<5316, [RayTracingNV]>; defm MissNV: ExecutionModelOperand<5317, [RayTracingNV]>; -defm CallableNV: ExecutionModelOperand<5318, [RayTracingNV]>; +defm CallableNV : ExecutionModelOperand<5318, [RayTracingNV]>; //===----------------------------------------------------------------------===// // Multiclass used to define MemoryModel enum values and at the same time @@ -805,6 +806,7 @@ defm RoundingModeRTNINTEL : ExecutionModeOperand<5621, [RoundToInfinityINTEL]>; defm FloatingPointModeALTINTEL : ExecutionModeOperand<5622, [FloatingPointModeINTEL]>; defm FloatingPointModeIEEEINTEL : ExecutionModeOperand<5623, [FloatingPointModeINTEL]>; defm FPFastMathDefault : ExecutionModeOperand<6028, [FloatControls2]>; +defm MaximallyReconvergesKHR : ExecutionModeOperand<6023, [Shader]>; //===----------------------------------------------------------------------===// // Multiclass used to define StorageClass enum values and at the same time diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index a0b64ff..b05d7c7 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -29755,65 +29755,30 @@ static SDValue LowervXi8MulWithUNPCK(SDValue A, SDValue B, const SDLoc &dl, const X86Subtarget &Subtarget, SelectionDAG &DAG, SDValue *Low = nullptr) { - unsigned NumElts = VT.getVectorNumElements(); - // For vXi8 we will unpack the low and high half of each 128 bit lane to widen // to a vXi16 type. Do the multiplies, shift the results and pack the half // lane results back together. // We'll take different approaches for signed and unsigned. - // For unsigned we'll use punpcklbw/punpckhbw to put zero extend the bytes - // and use pmullw to calculate the full 16-bit product. + // For unsigned we'll use punpcklbw/punpckhbw to zero extend the bytes to + // words and use pmullw to calculate the full 16-bit product. // For signed we'll use punpcklbw/punpckbw to extend the bytes to words and // shift them left into the upper byte of each word. This allows us to use // pmulhw to calculate the full 16-bit product. This trick means we don't // need to sign extend the bytes to use pmullw. 
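The rewritten LowervXi8MulWithUNPCK (which also drops the separate constant-RHS unpack path) still rests on the two identities its comment describes: zero-extended unsigned bytes never overflow the 16-bit pmullw product, and signed bytes placed in the high byte of each word make pmulhw return the exact 8x8 signed product, because (a*2^8)*(b*2^8) = a*b*2^16 and pmulhw keeps exactly the top 16 bits. An exhaustive scalar check, with pmullw/pmulhw modeled per lane:

// Scalar model of the vXi8 multiply lowering tricks.
#include <cassert>
#include <cstdint>

uint16_t pmullw(uint16_t a, uint16_t b) {            // low 16 bits of product
  return static_cast<uint16_t>(a * b);
}
int16_t pmulhw(int16_t a, int16_t b) {               // high 16 bits of product
  return static_cast<int16_t>((static_cast<int32_t>(a) * b) >> 16);
}

int main() {
  for (int a = -128; a < 128; ++a) {
    for (int b = -128; b < 128; ++b) {
      // Unsigned: zero-extended bytes, the full product fits in 16 bits.
      uint8_t ua = static_cast<uint8_t>(a), ub = static_cast<uint8_t>(b);
      assert(pmullw(ua, ub) == ua * ub);              // no bits lost: 255*255 < 2^16
      // Signed: bytes shifted into the high byte, pmulhw yields a*b exactly.
      int16_t sa = static_cast<int16_t>(a * 256);
      int16_t sb = static_cast<int16_t>(b * 256);
      assert(pmulhw(sa, sb) == static_cast<int16_t>(a * b));
    }
  }
}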
- - MVT ExVT = MVT::getVectorVT(MVT::i16, NumElts / 2); + MVT ExVT = MVT::getVectorVT(MVT::i16, VT.getVectorNumElements() / 2); SDValue Zero = DAG.getConstant(0, dl, VT); - SDValue ALo, AHi; + SDValue ALo, AHi, BLo, BHi; if (IsSigned) { ALo = DAG.getBitcast(ExVT, getUnpackl(DAG, dl, VT, Zero, A)); - AHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, Zero, A)); - } else { - ALo = DAG.getBitcast(ExVT, getUnpackl(DAG, dl, VT, A, Zero)); - AHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, A, Zero)); - } - - SDValue BLo, BHi; - if (ISD::isBuildVectorOfConstantSDNodes(B.getNode())) { - // If the RHS is a constant, manually unpackl/unpackh and extend. - SmallVector<SDValue, 16> LoOps, HiOps; - for (unsigned i = 0; i != NumElts; i += 16) { - for (unsigned j = 0; j != 8; ++j) { - SDValue LoOp = B.getOperand(i + j); - SDValue HiOp = B.getOperand(i + j + 8); - - if (IsSigned) { - LoOp = DAG.getAnyExtOrTrunc(LoOp, dl, MVT::i16); - HiOp = DAG.getAnyExtOrTrunc(HiOp, dl, MVT::i16); - LoOp = DAG.getNode(ISD::SHL, dl, MVT::i16, LoOp, - DAG.getConstant(8, dl, MVT::i16)); - HiOp = DAG.getNode(ISD::SHL, dl, MVT::i16, HiOp, - DAG.getConstant(8, dl, MVT::i16)); - } else { - LoOp = DAG.getZExtOrTrunc(LoOp, dl, MVT::i16); - HiOp = DAG.getZExtOrTrunc(HiOp, dl, MVT::i16); - } - - LoOps.push_back(LoOp); - HiOps.push_back(HiOp); - } - } - - BLo = DAG.getBuildVector(ExVT, dl, LoOps); - BHi = DAG.getBuildVector(ExVT, dl, HiOps); - } else if (IsSigned) { BLo = DAG.getBitcast(ExVT, getUnpackl(DAG, dl, VT, Zero, B)); + AHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, Zero, A)); BHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, Zero, B)); } else { + ALo = DAG.getBitcast(ExVT, getUnpackl(DAG, dl, VT, A, Zero)); BLo = DAG.getBitcast(ExVT, getUnpackl(DAG, dl, VT, B, Zero)); + AHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, A, Zero)); BHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, B, Zero)); } @@ -29826,7 +29791,7 @@ static SDValue LowervXi8MulWithUNPCK(SDValue A, SDValue B, const SDLoc &dl, if (Low) *Low = getPack(DAG, Subtarget, dl, VT, RLo, RHi); - return getPack(DAG, Subtarget, dl, VT, RLo, RHi, /*PackHiHalf*/ true); + return getPack(DAG, Subtarget, dl, VT, RLo, RHi, /*PackHiHalf=*/true); } static SDValue LowerMULH(SDValue Op, const X86Subtarget &Subtarget, @@ -44848,10 +44813,16 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode( } case X86ISD::PCMPGT: // icmp sgt(0, R) == ashr(R, BitWidth-1). - // iff we only need the sign bit then we can use R directly. - if (OriginalDemandedBits.isSignMask() && - ISD::isBuildVectorAllZeros(Op.getOperand(0).getNode())) - return TLO.CombineTo(Op, Op.getOperand(1)); + if (ISD::isBuildVectorAllZeros(Op.getOperand(0).getNode())) { + // iff we only need the signbit then we can use R directly. + if (OriginalDemandedBits.isSignMask()) + return TLO.CombineTo(Op, Op.getOperand(1)); + // otherwise we just need R's signbit for the comparison. + APInt SignMask = APInt::getSignMask(BitWidth); + if (SimplifyDemandedBits(Op.getOperand(1), SignMask, OriginalDemandedElts, + Known, TLO, Depth + 1)) + return true; + } break; case X86ISD::MOVMSK: { SDValue Src = Op.getOperand(0); @@ -47761,6 +47732,15 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG, DL, DAG, Subtarget)) return V; + // If the sign bit is known then BLENDV can be folded away. 
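Both X86 changes around here are sign-bit facts. pcmpgt(0, R) is all-ones exactly when R is negative, i.e. it equals an arithmetic shift of R by BitWidth-1, so only R's sign bit is ever demanded (the new SimplifyDemandedBits call) and R itself suffices when only the sign bit of the result is wanted. BLENDV selects per element purely on the sign bit of its condition, so a condition with a known sign folds the blend to one operand (the hunk that follows). A one-lane scalar model of both:

// One-lane models of the two sign-bit folds (arithmetic right shift assumed,
// as on all mainstream targets).
#include <cassert>
#include <cstdint>

int32_t pcmpgt(int32_t a, int32_t b) { return a > b ? -1 : 0; }
int32_t blendv(int32_t mask, int32_t t, int32_t f) { return mask < 0 ? t : f; }

int main() {
  const int32_t samples[] = {0, 1, -1, 42, -42, INT32_MAX, INT32_MIN};
  for (int32_t r : samples) {
    assert(pcmpgt(0, r) == (r >> 31));          // same value, same sign bit as r
    assert(blendv(r, 10, 20) == (r < 0 ? 10 : 20));  // only the mask sign matters
  }
}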
+ if (N->getOpcode() == X86ISD::BLENDV) { + KnownBits KnownCond = DAG.computeKnownBits(Cond); + if (KnownCond.isNegative()) + return LHS; + if (KnownCond.isNonNegative()) + return RHS; + } + if (N->getOpcode() == ISD::VSELECT || N->getOpcode() == X86ISD::BLENDV) { SmallVector<int, 64> CondMask; if (createShuffleMaskFromVSELECT(CondMask, Cond, diff --git a/llvm/lib/Target/X86/X86ISelLoweringCall.cpp b/llvm/lib/Target/X86/X86ISelLoweringCall.cpp index 6dd43b2..37d7772 100644 --- a/llvm/lib/Target/X86/X86ISelLoweringCall.cpp +++ b/llvm/lib/Target/X86/X86ISelLoweringCall.cpp @@ -606,16 +606,24 @@ Value *X86TargetLowering::getIRStackGuard(IRBuilderBase &IRB) const { void X86TargetLowering::insertSSPDeclarations(Module &M) const { // MSVC CRT provides functionalities for stack protection. - if (Subtarget.getTargetTriple().isWindowsMSVCEnvironment() || - Subtarget.getTargetTriple().isWindowsItaniumEnvironment()) { + RTLIB::LibcallImpl SecurityCheckCookieLibcall = + getLibcallImpl(RTLIB::SECURITY_CHECK_COOKIE); + + RTLIB::LibcallImpl SecurityCookieVar = + getLibcallImpl(RTLIB::STACK_CHECK_GUARD); + if (SecurityCheckCookieLibcall != RTLIB::Unsupported && + SecurityCookieVar != RTLIB::Unsupported) { + // MSVC CRT provides functionalities for stack protection. // MSVC CRT has a global variable holding security cookie. - M.getOrInsertGlobal("__security_cookie", + M.getOrInsertGlobal(getLibcallImplName(SecurityCookieVar), PointerType::getUnqual(M.getContext())); // MSVC CRT has a function to validate security cookie. - FunctionCallee SecurityCheckCookie = M.getOrInsertFunction( - "__security_check_cookie", Type::getVoidTy(M.getContext()), - PointerType::getUnqual(M.getContext())); + FunctionCallee SecurityCheckCookie = + M.getOrInsertFunction(getLibcallImplName(SecurityCheckCookieLibcall), + Type::getVoidTy(M.getContext()), + PointerType::getUnqual(M.getContext())); + if (Function *F = dyn_cast<Function>(SecurityCheckCookie.getCallee())) { F->setCallingConv(CallingConv::X86_FastCall); F->addParamAttr(0, Attribute::AttrKind::InReg); diff --git a/llvm/lib/Target/X86/X86InstrInfo.cpp b/llvm/lib/Target/X86/X86InstrInfo.cpp index 1d2cd39..5c23f91 100644 --- a/llvm/lib/Target/X86/X86InstrInfo.cpp +++ b/llvm/lib/Target/X86/X86InstrInfo.cpp @@ -10809,39 +10809,27 @@ void X86InstrInfo::buildClearRegister(Register Reg, MachineBasicBlock &MBB, if (!ST.hasSSE1()) return; - // PXOR is safe to use because it doesn't affect flags. - BuildMI(MBB, Iter, DL, get(X86::PXORrr), Reg) - .addReg(Reg, RegState::Undef) - .addReg(Reg, RegState::Undef); + BuildMI(MBB, Iter, DL, get(X86::V_SET0), Reg); } else if (X86::VR256RegClass.contains(Reg)) { // YMM# if (!ST.hasAVX()) return; - // VPXOR is safe to use because it doesn't affect flags. - BuildMI(MBB, Iter, DL, get(X86::VPXORrr), Reg) - .addReg(Reg, RegState::Undef) - .addReg(Reg, RegState::Undef); + BuildMI(MBB, Iter, DL, get(X86::AVX_SET0), Reg); } else if (X86::VR512RegClass.contains(Reg)) { // ZMM# if (!ST.hasAVX512()) return; - // VPXORY is safe to use because it doesn't affect flags. - BuildMI(MBB, Iter, DL, get(X86::VPXORYrr), Reg) - .addReg(Reg, RegState::Undef) - .addReg(Reg, RegState::Undef); + BuildMI(MBB, Iter, DL, get(X86::AVX512_512_SET0), Reg); } else if (X86::VK1RegClass.contains(Reg) || X86::VK2RegClass.contains(Reg) || X86::VK4RegClass.contains(Reg) || X86::VK8RegClass.contains(Reg) || X86::VK16RegClass.contains(Reg)) { if (!ST.hasVLX()) return; - // KXOR is safe to use because it doesn't affect flags. - unsigned Op = ST.hasBWI() ? 
X86::KXORQkk : X86::KXORWkk; - BuildMI(MBB, Iter, DL, get(Op), Reg) - .addReg(Reg, RegState::Undef) - .addReg(Reg, RegState::Undef); + unsigned Op = ST.hasBWI() ? X86::KSET0Q : X86::KSET0W; + BuildMI(MBB, Iter, DL, get(Op), Reg); } } diff --git a/llvm/lib/Target/X86/X86MCInstLower.cpp b/llvm/lib/Target/X86/X86MCInstLower.cpp index 1fca466f..713d504 100644 --- a/llvm/lib/Target/X86/X86MCInstLower.cpp +++ b/llvm/lib/Target/X86/X86MCInstLower.cpp @@ -1928,6 +1928,17 @@ static void addConstantComments(const MachineInstr *MI, #define INSTR_CASE(Prefix, Instr, Suffix, Postfix) \ case X86::Prefix##Instr##Suffix##rm##Postfix: +#define CASE_AVX512_ARITH_RM(Instr) \ + INSTR_CASE(V, Instr, Z128, ) \ + INSTR_CASE(V, Instr, Z128, k) \ + INSTR_CASE(V, Instr, Z128, kz) \ + INSTR_CASE(V, Instr, Z256, ) \ + INSTR_CASE(V, Instr, Z256, k) \ + INSTR_CASE(V, Instr, Z256, kz) \ + INSTR_CASE(V, Instr, Z, ) \ + INSTR_CASE(V, Instr, Z, k) \ + INSTR_CASE(V, Instr, Z, kz) + #define CASE_ARITH_RM(Instr) \ INSTR_CASE(, Instr, , ) /* SSE */ \ INSTR_CASE(V, Instr, , ) /* AVX-128 */ \ @@ -1943,22 +1954,12 @@ static void addConstantComments(const MachineInstr *MI, INSTR_CASE(V, Instr, Z, kz) // TODO: Add additional instructions when useful. - CASE_ARITH_RM(PMADDUBSW) { - unsigned SrcIdx = getSrcIdx(MI, 1); - if (auto *C = X86::getConstantFromPool(*MI, SrcIdx + 1)) { - std::string Comment; - raw_string_ostream CS(Comment); - unsigned VectorWidth = - X86::getVectorRegisterWidth(MI->getDesc().operands()[0]); - CS << "["; - printConstant(C, VectorWidth, CS); - CS << "]"; - OutStreamer.AddComment(CS.str()); - } - break; - } - + CASE_ARITH_RM(PMADDUBSW) CASE_ARITH_RM(PMADDWD) + CASE_ARITH_RM(PMULDQ) + CASE_ARITH_RM(PMULUDQ) + CASE_ARITH_RM(PMULLD) + CASE_AVX512_ARITH_RM(PMULLQ) CASE_ARITH_RM(PMULLW) CASE_ARITH_RM(PMULHW) CASE_ARITH_RM(PMULHUW) diff --git a/llvm/lib/TargetParser/RISCVISAInfo.cpp b/llvm/lib/TargetParser/RISCVISAInfo.cpp index 9268df2..31126cc 100644 --- a/llvm/lib/TargetParser/RISCVISAInfo.cpp +++ b/llvm/lib/TargetParser/RISCVISAInfo.cpp @@ -887,7 +887,7 @@ void RISCVISAInfo::updateImplication() { } static constexpr StringLiteral CombineIntoExts[] = { - {"b"}, {"zk"}, {"zkn"}, {"zks"}, {"zvkn"}, + {"a"}, {"b"}, {"zk"}, {"zkn"}, {"zks"}, {"zvkn"}, {"zvknc"}, {"zvkng"}, {"zvks"}, {"zvksc"}, {"zvksg"}, }; diff --git a/llvm/lib/Transforms/CFGuard/CFGuard.cpp b/llvm/lib/Transforms/CFGuard/CFGuard.cpp index b73a0ce..4645670 100644 --- a/llvm/lib/Transforms/CFGuard/CFGuard.cpp +++ b/llvm/lib/Transforms/CFGuard/CFGuard.cpp @@ -147,7 +147,7 @@ public: private: // Only add checks if the module has the cfguard=2 flag. - int cfguard_module_flag = 0; + int CFGuardModuleFlag = 0; StringRef GuardFnName; Mechanism GuardMechanism = Mechanism::Check; FunctionType *GuardFnType = nullptr; @@ -162,9 +162,7 @@ public: static char ID; // Default constructor required for the INITIALIZE_PASS macro. 
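The RISCVISAInfo.cpp change above adds "a" to CombineIntoExts, so an arch string spelled out with the atomic sub-extensions normalizes back to the short umbrella form. A standalone sketch of that normalization step; the implication sets here are illustrative (only "a" and "b" shown), not the full table the parser uses:

// Once every sub-extension an umbrella extension implies is present, the
// umbrella name is added back, giving the canonical short spelling.
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

void combineIntoExts(std::set<std::string> &Exts) {
  static const std::map<std::string, std::vector<std::string>> Combos = {
      {"a", {"zaamo", "zalrsc"}},
      {"b", {"zba", "zbb", "zbs"}},
  };
  for (const auto &[Umbrella, Parts] : Combos) {
    bool HaveAll = true;
    for (const std::string &P : Parts)
      HaveAll = HaveAll && Exts.count(P) != 0;
    if (HaveAll)
      Exts.insert(Umbrella);
  }
}

int main() {
  std::set<std::string> Exts = {"i", "zaamo", "zalrsc"};
  combineIntoExts(Exts);
  assert(Exts.count("a"));  // rv..._zaamo_zalrsc now reports "a" as well
}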
- CFGuard(CFGuardImpl::Mechanism M) : FunctionPass(ID), Impl(M) { - initializeCFGuardPass(*PassRegistry::getPassRegistry()); - } + CFGuard(CFGuardImpl::Mechanism M) : FunctionPass(ID), Impl(M) {} bool doInitialization(Module &M) override { return Impl.doInitialization(M); } bool runOnFunction(Function &F) override { return Impl.runOnFunction(F); } @@ -173,7 +171,6 @@ public: } // end anonymous namespace void CFGuardImpl::insertCFGuardCheck(CallBase *CB) { - assert(CB->getModule()->getTargetTriple().isOSWindows() && "Only applicable for Windows targets"); assert(CB->isIndirectCall() && @@ -202,7 +199,6 @@ void CFGuardImpl::insertCFGuardCheck(CallBase *CB) { } void CFGuardImpl::insertCFGuardDispatch(CallBase *CB) { - assert(CB->getModule()->getTargetTriple().isOSWindows() && "Only applicable for Windows targets"); assert(CB->isIndirectCall() && @@ -236,14 +232,13 @@ void CFGuardImpl::insertCFGuardDispatch(CallBase *CB) { } bool CFGuardImpl::doInitialization(Module &M) { - // Check if this module has the cfguard flag and read its value. if (auto *MD = mdconst::extract_or_null<ConstantInt>(M.getModuleFlag("cfguard"))) - cfguard_module_flag = MD->getZExtValue(); + CFGuardModuleFlag = MD->getZExtValue(); // Skip modules for which CFGuard checks have been disabled. - if (cfguard_module_flag != 2) + if (CFGuardModuleFlag != 2) return false; // Set up prototypes for the guard check and dispatch functions. @@ -264,9 +259,8 @@ bool CFGuardImpl::doInitialization(Module &M) { } bool CFGuardImpl::runOnFunction(Function &F) { - // Skip modules for which CFGuard checks have been disabled. - if (cfguard_module_flag != 2) + if (CFGuardModuleFlag != 2) return false; SmallVector<CallBase *, 8> IndirectCalls; @@ -286,19 +280,16 @@ bool CFGuardImpl::runOnFunction(Function &F) { } // If no checks are needed, return early. - if (IndirectCalls.empty()) { + if (IndirectCalls.empty()) return false; - } // For each indirect call/invoke, add the appropriate dispatch or check. if (GuardMechanism == Mechanism::Dispatch) { - for (CallBase *CB : IndirectCalls) { + for (CallBase *CB : IndirectCalls) insertCFGuardDispatch(CB); - } } else { - for (CallBase *CB : IndirectCalls) { + for (CallBase *CB : IndirectCalls) insertCFGuardCheck(CB); - } } return true; diff --git a/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp b/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp index 5066a99..894d83f 100644 --- a/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp +++ b/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp @@ -6150,3 +6150,42 @@ void MemProfContextDisambiguation::run( IndexCallsiteContextGraph CCG(Index, isPrevailing); CCG.process(); } + +// Strips MemProf attributes and metadata. Can be invoked by the pass pipeline +// when we don't have an index that has recorded that we are linking with +// allocation libraries containing the necessary APIs for downstream +// transformations. +PreservedAnalyses MemProfRemoveInfo::run(Module &M, ModuleAnalysisManager &AM) { + // The profile matcher applies hotness attributes directly for allocations, + // and those will cause us to generate calls to the hot/cold interfaces + // unconditionally. If supports-hot-cold-new was not enabled in the LTO + // link then assume we don't want these calls (e.g. not linking with + // the appropriate library, or otherwise trying to disable this behavior). 
+ bool Changed = false; + for (auto &F : M) { + for (auto &BB : F) { + for (auto &I : BB) { + auto *CI = dyn_cast<CallBase>(&I); + if (!CI) + continue; + if (CI->hasFnAttr("memprof")) { + CI->removeFnAttr("memprof"); + Changed = true; + } + if (!CI->hasMetadata(LLVMContext::MD_callsite)) { + assert(!CI->hasMetadata(LLVMContext::MD_memprof)); + continue; + } + // Strip off all memprof metadata as it is no longer needed. + // Importantly, this avoids the addition of new memprof attributes + // after inlining propagation. + CI->setMetadata(LLVMContext::MD_memprof, nullptr); + CI->setMetadata(LLVMContext::MD_callsite, nullptr); + Changed = true; + } + } + } + if (!Changed) + return PreservedAnalyses::all(); + return PreservedAnalyses::none(); +} diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 07ad65c..fba1ccf 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -1481,13 +1481,13 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp, return new ICmpInst(Pred, Y, ConstantInt::get(SrcTy, C.logBase2())); } - if (Cmp.isEquality() && Trunc->hasOneUse()) { + if (Cmp.isEquality() && (Trunc->hasOneUse() || Trunc->hasNoUnsignedWrap())) { // Canonicalize to a mask and wider compare if the wide type is suitable: // (trunc X to i8) == C --> (X & 0xff) == (zext C) if (!SrcTy->isVectorTy() && shouldChangeType(DstBits, SrcBits)) { Constant *Mask = ConstantInt::get(SrcTy, APInt::getLowBitsSet(SrcBits, DstBits)); - Value *And = Builder.CreateAnd(X, Mask); + Value *And = Trunc->hasNoUnsignedWrap() ? X : Builder.CreateAnd(X, Mask); Constant *WideC = ConstantInt::get(SrcTy, C.zext(SrcBits)); return new ICmpInst(Pred, And, WideC); } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 09cb225..a8eb9b9 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -3757,6 +3757,10 @@ static Instruction *foldBitCeil(SelectInst &SI, IRBuilderBase &Builder, // (x < y) ? -1 : zext(x > y) // (x > y) ? 1 : sext(x != y) // (x > y) ? 1 : sext(x < y) +// (x == y) ? 0 : (x > y ? 1 : -1) +// (x == y) ? 0 : (x < y ? -1 : 1) +// Special case: x == C ? 0 : (x > C - 1 ? 1 : -1) +// Special case: x == C ? 0 : (x < C + 1 ? -1 : 1) // Into ucmp/scmp(x, y), where signedness is determined by the signedness // of the comparison in the original sequence. Instruction *InstCombinerImpl::foldSelectToCmp(SelectInst &SI) { @@ -3849,6 +3853,44 @@ Instruction *InstCombinerImpl::foldSelectToCmp(SelectInst &SI) { } } + // Special cases with constants: x == C ? 0 : (x > C-1 ? 1 : -1) + if (Pred == ICmpInst::ICMP_EQ && match(TV, m_Zero())) { + const APInt *C; + if (match(RHS, m_APInt(C))) { + CmpPredicate InnerPred; + Value *InnerRHS; + const APInt *InnerTV, *InnerFV; + if (match(FV, + m_Select(m_ICmp(InnerPred, m_Specific(LHS), m_Value(InnerRHS)), + m_APInt(InnerTV), m_APInt(InnerFV)))) { + + // x == C ? 0 : (x > C-1 ? 1 : -1) + if (ICmpInst::isGT(InnerPred) && InnerTV->isOne() && + InnerFV->isAllOnes()) { + IsSigned = ICmpInst::isSigned(InnerPred); + bool CanSubOne = IsSigned ? !C->isMinSignedValue() : !C->isMinValue(); + if (CanSubOne) { + APInt Cminus1 = *C - 1; + if (match(InnerRHS, m_SpecificInt(Cminus1))) + Replace = true; + } + } + + // x == C ? 0 : (x < C+1 ? 
-1 : 1) + if (ICmpInst::isLT(InnerPred) && InnerTV->isAllOnes() && + InnerFV->isOne()) { + IsSigned = ICmpInst::isSigned(InnerPred); + bool CanAddOne = IsSigned ? !C->isMaxSignedValue() : !C->isMaxValue(); + if (CanAddOne) { + APInt Cplus1 = *C + 1; + if (match(InnerRHS, m_SpecificInt(Cplus1))) + Replace = true; + } + } + } + } + } + Intrinsic::ID IID = IsSigned ? Intrinsic::scmp : Intrinsic::ucmp; if (Replace) return replaceInstUsesWith( diff --git a/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp index 6e17801..2646334 100644 --- a/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp @@ -844,6 +844,7 @@ struct AddressSanitizer { bool maybeInsertAsanInitAtFunctionEntry(Function &F); bool maybeInsertDynamicShadowAtFunctionEntry(Function &F); void markEscapedLocalAllocas(Function &F); + void markCatchParametersAsUninteresting(Function &F); private: friend struct FunctionStackPoisoner; @@ -2997,6 +2998,22 @@ void AddressSanitizer::markEscapedLocalAllocas(Function &F) { } } } +// Mitigation for https://github.com/google/sanitizers/issues/749 +// We don't instrument Windows catch-block parameters to avoid +// interfering with exception handling assumptions. +void AddressSanitizer::markCatchParametersAsUninteresting(Function &F) { + for (BasicBlock &BB : F) { + for (Instruction &I : BB) { + if (auto *CatchPad = dyn_cast<CatchPadInst>(&I)) { + // Mark the parameters to a catch-block as uninteresting to avoid + // instrumenting them. + for (Value *Operand : CatchPad->arg_operands()) + if (auto *AI = dyn_cast<AllocaInst>(Operand)) + ProcessedAllocas[AI] = false; + } + } + } +} bool AddressSanitizer::suppressInstrumentationSiteForDebug(int &Instrumented) { bool ShouldInstrument = @@ -3041,6 +3058,9 @@ bool AddressSanitizer::instrumentFunction(Function &F, // can be passed to that intrinsic. markEscapedLocalAllocas(F); + if (TargetTriple.isOSWindows()) + markCatchParametersAsUninteresting(F); + // We want to instrument every address only once per basic block (unless there // are calls between uses). 
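A quick editorial check of the identity behind the new foldSelectToCmp special cases above (not part of the patch): for a signed 8-bit x and any constant C other than the signed minimum (matching the CanSubOne guard), the chain x == C ? 0 : (x > C-1 ? 1 : -1) is exactly the three-way signed comparison that llvm.scmp returns. A brute-force sketch, with scmpRef as a hypothetical reference implementation:

#include <cassert>
#include <cstdint>

// Reference three-way signed comparison: the -1/0/1 semantics of llvm.scmp.
static int scmpRef(int8_t X, int8_t C) { return X < C ? -1 : (X > C ? 1 : 0); }

int main() {
  for (int Xi = -128; Xi <= 127; ++Xi) {
    // Skip C == INT8_MIN: C - 1 would wrap, which the CanSubOne check rejects.
    for (int Ci = -127; Ci <= 127; ++Ci) {
      int8_t X = static_cast<int8_t>(Xi), C = static_cast<int8_t>(Ci);
      int Folded = (X == C) ? 0 : ((X > static_cast<int8_t>(C - 1)) ? 1 : -1);
      assert(Folded == scmpRef(X, C));
    }
  }
  return 0;
}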
SmallPtrSet<Value *, 16> TempsToInstrument; diff --git a/llvm/lib/Transforms/Scalar/LoopFuse.cpp b/llvm/lib/Transforms/Scalar/LoopFuse.cpp index 20733032..19eccb9 100644 --- a/llvm/lib/Transforms/Scalar/LoopFuse.cpp +++ b/llvm/lib/Transforms/Scalar/LoopFuse.cpp @@ -368,7 +368,7 @@ private: Valid = false; } - bool reportInvalidCandidate(llvm::Statistic &Stat) const { + bool reportInvalidCandidate(Statistic &Stat) const { using namespace ore; assert(L && Preheader && "Fusion candidate not initialized properly!"); #if LLVM_ENABLE_STATS @@ -445,6 +445,7 @@ struct FusionCandidateCompare { "No dominance relationship between these fusion candidates!"); } }; +} // namespace using LoopVector = SmallVector<Loop *, 4>; @@ -461,9 +462,15 @@ using LoopVector = SmallVector<Loop *, 4>; using FusionCandidateSet = std::set<FusionCandidate, FusionCandidateCompare>; using FusionCandidateCollection = SmallVector<FusionCandidateSet, 4>; -#if !defined(NDEBUG) -static llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, - const FusionCandidate &FC) { +#ifndef NDEBUG +static void printLoopVector(const LoopVector &LV) { + dbgs() << "****************************\n"; + for (const Loop *L : LV) + printLoop(*L, dbgs()); + dbgs() << "****************************\n"; +} + +static raw_ostream &operator<<(raw_ostream &OS, const FusionCandidate &FC) { if (FC.isValid()) OS << FC.Preheader->getName(); else @@ -472,8 +479,8 @@ static llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, return OS; } -static llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, - const FusionCandidateSet &CandSet) { +static raw_ostream &operator<<(raw_ostream &OS, + const FusionCandidateSet &CandSet) { for (const FusionCandidate &FC : CandSet) OS << FC << '\n'; @@ -489,7 +496,9 @@ printFusionCandidates(const FusionCandidateCollection &FusionCandidates) { dbgs() << "****************************\n"; } } -#endif +#endif // NDEBUG + +namespace { /// Collect all loops in function at the same nest level, starting at the /// outermost level. @@ -550,15 +559,6 @@ private: LoopsOnLevelTy LoopsOnLevel; }; -#ifndef NDEBUG -static void printLoopVector(const LoopVector &LV) { - dbgs() << "****************************\n"; - for (auto *L : LV) - printLoop(*L, dbgs()); - dbgs() << "****************************\n"; -} -#endif - struct LoopFuser { private: // Sets of control flow equivalent fusion candidates for a given nest level. @@ -1850,7 +1850,7 @@ private: /// <Cand1 Preheader> and <Cand2 Preheader>: <Stat Description> template <typename RemarkKind> void reportLoopFusion(const FusionCandidate &FC0, const FusionCandidate &FC1, - llvm::Statistic &Stat) { + Statistic &Stat) { assert(FC0.Preheader && FC1.Preheader && "Expecting valid fusion candidates"); using namespace ore; diff --git a/llvm/lib/Transforms/Scalar/LoopPassManager.cpp b/llvm/lib/Transforms/Scalar/LoopPassManager.cpp index 7da8586..d827e64 100644 --- a/llvm/lib/Transforms/Scalar/LoopPassManager.cpp +++ b/llvm/lib/Transforms/Scalar/LoopPassManager.cpp @@ -8,7 +8,6 @@ #include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Analysis/AssumptionCache.h" -#include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetLibraryInfo.h" @@ -217,9 +216,6 @@ PreservedAnalyses FunctionToLoopPassAdaptor::run(Function &F, // Get the analysis results needed by loop passes. MemorySSA *MSSA = UseMemorySSA ? 
(&AM.getResult<MemorySSAAnalysis>(F).getMSSA()) : nullptr; - BlockFrequencyInfo *BFI = UseBlockFrequencyInfo && F.hasProfileData() - ? (&AM.getResult<BlockFrequencyAnalysis>(F)) - : nullptr; LoopStandardAnalysisResults LAR = {AM.getResult<AAManager>(F), AM.getResult<AssumptionAnalysis>(F), AM.getResult<DominatorTreeAnalysis>(F), @@ -227,7 +223,6 @@ PreservedAnalyses FunctionToLoopPassAdaptor::run(Function &F, AM.getResult<ScalarEvolutionAnalysis>(F), AM.getResult<TargetLibraryAnalysis>(F), AM.getResult<TargetIRAnalysis>(F), - BFI, MSSA}; // Setup the loop analysis manager from its proxy. It is important that diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 7cae94eb..3487e81 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -97,6 +97,12 @@ static cl::opt<MatrixLayoutTy> MatrixLayout( static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt", cl::init(false)); +static cl::opt<unsigned> SplitMatmulRemainderOverThreshold( + "matrix-split-matmul-remainder-over-threshold", cl::Hidden, + cl::desc("Illegal remainder vectors over this size in bits should be split " + "in the inner loop of matmul"), + cl::init(0)); + /// Helper function to either return Scope, if it is a subprogram or the /// attached subprogram for a local scope. static DISubprogram *getSubprogram(DIScope *Scope) { @@ -115,18 +121,16 @@ static bool isSplat(Value *V) { /// Match any mul operation (fp or integer). template <typename LTy, typename RTy> -auto m_AnyMul(const LTy &L, const RTy &R) { +static auto m_AnyMul(const LTy &L, const RTy &R) { return m_CombineOr(m_Mul(L, R), m_FMul(L, R)); } /// Match any add operation (fp or integer). template <typename LTy, typename RTy> -auto m_AnyAdd(const LTy &L, const RTy &R) { +static auto m_AnyAdd(const LTy &L, const RTy &R) { return m_CombineOr(m_Add(L, R), m_FAdd(L, R)); } -namespace { - // Given an element pointer \p BasePtr to the start of a (sub) matrix, compute // the start address of vector \p VecIdx with type (\p EltType x \p NumElements) // assuming \p Stride elements between start two consecutive vectors. @@ -167,9 +171,9 @@ namespace { // v_2_0 |v_2_1 |v_2_2 |v_2_3 // v_3_0 {v_3_1 {v_3_2 v_3_3 // -Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride, - unsigned NumElements, Type *EltType, - IRBuilder<> &Builder) { +static Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride, + unsigned NumElements, Type *EltType, + IRBuilder<> &Builder) { assert((!isa<ConstantInt>(Stride) || cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) && @@ -338,6 +342,8 @@ computeShapeInfoForInst(Instruction *I, return std::nullopt; } +namespace { + /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics. /// /// Currently, the lowering for each matrix intrinsic is done as follows: @@ -371,7 +377,8 @@ class LowerMatrixIntrinsics { LoopInfo *LI = nullptr; OptimizationRemarkEmitter *ORE = nullptr; - /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation. + /// Contains estimates of the number of operations (loads, stores, compute) + /// required to lower a matrix operation. struct OpInfoTy { /// Number of stores emitted to generate this matrix. 
unsigned NumStores = 0; @@ -1719,6 +1726,31 @@ public: ToRemove.push_back(MatMul); } + /// Given \p Remainder iterations of the matmul inner loop, + /// potentially lower \p BlockSize, which is used for the underlying + /// vector. + unsigned capBlockSize(unsigned BlockSize, unsigned Remainder, Type *EltType) { + if (BlockSize <= Remainder) + return BlockSize; + + // If the remainder is also a legal type just use it. + auto *VecTy = FixedVectorType::get(EltType, Remainder); + if (TTI.isTypeLegal(VecTy)) + return Remainder; + + // Similarly, use the remainder if the vector is small enough that we + // don't want to split further. + if (VecTy->getPrimitiveSizeInBits() <= SplitMatmulRemainderOverThreshold) + return Remainder; + + // Gradually lower the vectorization factor to cover the + // remainder. + do { + BlockSize /= 2; + } while (BlockSize > Remainder); + return BlockSize; + } + /// Compute \p Result += \p A * \p B for input matrices with left-associating /// addition. /// @@ -1756,10 +1788,8 @@ public: bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J)); for (unsigned I = 0; I < R; I += BlockSize) { - // Gradually lower the vectorization factor to cover the remainder. - while (I + BlockSize > R) - BlockSize /= 2; - + // Lower block size to make sure we stay within bounds. + BlockSize = capBlockSize(BlockSize, R - I, Result.getElementType()); Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder) : nullptr; for (unsigned K = 0; K < M; ++K) { @@ -1784,9 +1814,8 @@ public: unsigned BlockSize = VF; bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I)); for (unsigned J = 0; J < C; J += BlockSize) { - // Gradually lower the vectorization factor to cover the remainder. - while (J + BlockSize > C) - BlockSize /= 2; + // Lower the vectorization factor to cover the remainder. + BlockSize = capBlockSize(BlockSize, C - J, Result.getElementType()); Value *Sum = nullptr; for (unsigned K = 0; K < M; ++K) { diff --git a/llvm/lib/Transforms/Scalar/Reg2Mem.cpp b/llvm/lib/Transforms/Scalar/Reg2Mem.cpp index 30b27cb..7646624 100644 --- a/llvm/lib/Transforms/Scalar/Reg2Mem.cpp +++ b/llvm/lib/Transforms/Scalar/Reg2Mem.cpp @@ -107,9 +107,7 @@ PreservedAnalyses RegToMemPass::run(Function &F, FunctionAnalysisManager &AM) { return PA; } -namespace llvm { - -void initializeRegToMemWrapperPassPass(PassRegistry &); +namespace { class RegToMemWrapperPass : public FunctionPass { public: @@ -136,7 +134,7 @@ public: return N != 0 || Changed; } }; -} // namespace llvm +} // namespace INITIALIZE_PASS_BEGIN(RegToMemWrapperPass, "reg2mem", "", true, true) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass); diff --git a/llvm/lib/Transforms/Scalar/SROA.cpp b/llvm/lib/Transforms/Scalar/SROA.cpp index a692009..5c60fad 100644 --- a/llvm/lib/Transforms/Scalar/SROA.cpp +++ b/llvm/lib/Transforms/Scalar/SROA.cpp @@ -344,6 +344,12 @@ static void migrateDebugInfo(AllocaInst *OldAlloca, bool IsSplit, uint64_t SliceSizeInBits, Instruction *OldInst, Instruction *Inst, Value *Dest, Value *Value, const DataLayout &DL) { + // If we want allocas to be migrated using this helper then we need to ensure + // that the BaseFragments map code still works. A simple solution would be + // to choose to always clone alloca dbg_assigns (rather than sometimes + // "stealing" them). + assert(!isa<AllocaInst>(Inst) && "Unexpected alloca"); + auto DVRAssignMarkerRange = at::getDVRAssignmentMarkers(OldInst); // Nothing to do if OldInst has no linked dbg.assign intrinsics.
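Returning to the capBlockSize helper added in the LowerMatrixIntrinsics hunk above, here is a standalone editorial sketch of its remainder handling (not part of the patch). The TTI legality test and the size-in-bits threshold are collapsed into a single assumed flag so the logic can be exercised directly; the sketch name and parameter are hypothetical.

#include <cassert>

// RemainderIsUsableAsIs stands in for "the remainder is a legal vector type"
// or "it is below the matrix-split-matmul-remainder-over-threshold size".
static unsigned capBlockSizeSketch(unsigned BlockSize, unsigned Remainder,
                                   bool RemainderIsUsableAsIs) {
  if (BlockSize <= Remainder)
    return BlockSize;
  if (RemainderIsUsableAsIs)
    return Remainder;
  // Otherwise halve until the block covers no more than the remainder.
  do {
    BlockSize /= 2;
  } while (BlockSize > Remainder);
  return BlockSize;
}

int main() {
  assert(capBlockSizeSketch(8, 3, false) == 2); // halve 8 -> 4 -> 2
  assert(capBlockSizeSketch(8, 3, true) == 3);  // take the remainder directly
  assert(capBlockSizeSketch(4, 6, false) == 4); // already within bounds
  return 0;
}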
if (DVRAssignMarkerRange.empty()) @@ -429,11 +435,22 @@ static void migrateDebugInfo(AllocaInst *OldAlloca, bool IsSplit, Inst->setMetadata(LLVMContext::MD_DIAssignID, NewID); } - ::Value *NewValue = Value ? Value : DbgAssign->getValue(); - DbgVariableRecord *NewAssign = cast<DbgVariableRecord>(cast<DbgRecord *>( - DIB.insertDbgAssign(Inst, NewValue, DbgAssign->getVariable(), Expr, - Dest, DIExpression::get(Expr->getContext(), {}), - DbgAssign->getDebugLoc()))); + DbgVariableRecord *NewAssign; + if (IsSplit) { + ::Value *NewValue = Value ? Value : DbgAssign->getValue(); + NewAssign = cast<DbgVariableRecord>(cast<DbgRecord *>( + DIB.insertDbgAssign(Inst, NewValue, DbgAssign->getVariable(), Expr, + Dest, DIExpression::get(Expr->getContext(), {}), + DbgAssign->getDebugLoc()))); + } else { + // The store is not split, simply steal the existing dbg_assign. + NewAssign = DbgAssign; + NewAssign->setAssignId(NewID); // FIXME: Can we avoid generating new IDs? + NewAssign->setAddress(Dest); + if (Value) + NewAssign->replaceVariableLocationOp(0u, Value); + assert(Expr == NewAssign->getExpression()); + } // If we've updated the value but the original dbg.assign has an arglist // then kill it now - we can't use the requested new value. @@ -464,9 +481,10 @@ static void migrateDebugInfo(AllocaInst *OldAlloca, bool IsSplit, // noted as slightly offset (in code) from the store. In practice this // should have little effect on the debugging experience due to the fact // that all the split stores should get the same line number. - NewAssign->moveBefore(DbgAssign->getIterator()); - - NewAssign->setDebugLoc(DbgAssign->getDebugLoc()); + if (NewAssign != DbgAssign) { + NewAssign->moveBefore(DbgAssign->getIterator()); + NewAssign->setDebugLoc(DbgAssign->getDebugLoc()); + } LLVM_DEBUG(dbgs() << "Created new assign: " << *NewAssign << "\n"); }; diff --git a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp index e4ba70d..5af6c96 100644 --- a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp +++ b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp @@ -27,7 +27,6 @@ #include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/MustExecute.h" -#include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" @@ -3611,8 +3610,7 @@ static bool unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, AAResults &AA, TargetTransformInfo &TTI, bool Trivial, bool NonTrivial, ScalarEvolution *SE, - MemorySSAUpdater *MSSAU, ProfileSummaryInfo *PSI, - BlockFrequencyInfo *BFI, LPMUpdater &LoopUpdater) { + MemorySSAUpdater *MSSAU, LPMUpdater &LoopUpdater) { assert(L.isRecursivelyLCSSAForm(DT, LI) && "Loops must be in LCSSA form before unswitching."); @@ -3652,35 +3650,6 @@ static bool unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, if (F->hasOptSize()) return false; - // Returns true if Loop L's loop nest is cold, i.e. if the headers of L, - // of the loops L is nested in, and of the loops nested in L are all cold. - auto IsLoopNestCold = [&](const Loop *L) { - // Check L and all of its parent loops. - auto *Parent = L; - while (Parent) { - if (!PSI->isColdBlock(Parent->getHeader(), BFI)) - return false; - Parent = Parent->getParentLoop(); - } - // Next check all loops nested within L. 
- SmallVector<const Loop *, 4> Worklist; - llvm::append_range(Worklist, L->getSubLoops()); - while (!Worklist.empty()) { - auto *CurLoop = Worklist.pop_back_val(); - if (!PSI->isColdBlock(CurLoop->getHeader(), BFI)) - return false; - llvm::append_range(Worklist, CurLoop->getSubLoops()); - } - return true; - }; - - // Skip cold loops in cold loop nests, as unswitching them brings little - // benefit but increases the code size - if (PSI && PSI->hasProfileSummary() && BFI && IsLoopNestCold(&L)) { - LLVM_DEBUG(dbgs() << " Skip cold loop: " << L << "\n"); - return false; - } - // Perform legality checks. if (!isSafeForNoNTrivialUnswitching(L, LI)) return false; @@ -3705,11 +3674,6 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM, LPMUpdater &U) { Function &F = *L.getHeader()->getParent(); (void)F; - ProfileSummaryInfo *PSI = nullptr; - if (auto OuterProxy = - AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR) - .getCachedResult<ModuleAnalysisManagerFunctionProxy>(F)) - PSI = OuterProxy->getCachedResult<ProfileSummaryAnalysis>(*F.getParent()); LLVM_DEBUG(dbgs() << "Unswitching loop in " << F.getName() << ": " << L << "\n"); @@ -3720,7 +3684,7 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM, AR.MSSA->verifyMemorySSA(); } if (!unswitchLoop(L, AR.DT, AR.LI, AR.AC, AR.AA, AR.TTI, Trivial, NonTrivial, - &AR.SE, MSSAU ? &*MSSAU : nullptr, PSI, AR.BFI, U)) + &AR.SE, MSSAU ? &*MSSAU : nullptr, U)) return PreservedAnalyses::all(); if (AR.MSSA && VerifyMemorySSA) diff --git a/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/llvm/lib/Transforms/Utils/SCCPSolver.cpp index 9693ae6..b80c3c9 100644 --- a/llvm/lib/Transforms/Utils/SCCPSolver.cpp +++ b/llvm/lib/Transforms/Utils/SCCPSolver.cpp @@ -634,18 +634,10 @@ private: /// Merge \p MergeWithV into \p IV and push \p V to the worklist, if \p IV /// changes. bool mergeInValue(ValueLatticeElement &IV, Value *V, - ValueLatticeElement MergeWithV, + const ValueLatticeElement &MergeWithV, ValueLatticeElement::MergeOptions Opts = { /*MayIncludeUndef=*/false, /*CheckWiden=*/false}); - bool mergeInValue(Value *V, ValueLatticeElement MergeWithV, - ValueLatticeElement::MergeOptions Opts = { - /*MayIncludeUndef=*/false, /*CheckWiden=*/false}) { - assert(!V->getType()->isStructTy() && - "non-structs should use markConstant"); - return mergeInValue(ValueState[V], V, MergeWithV, Opts); - } - /// getValueState - Return the ValueLatticeElement object that corresponds to /// the value. This function handles the case when the value hasn't been seen /// yet by properly seeding constants etc. 
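One reading of the SCCPSolver changes in this region: the convenience mergeInValue(Value *, ...) overload is dropped, lattice values are passed by const reference to avoid copies, and call sites index ValueState explicitly, taking care (as the new "Safety" comment below notes) that a reference obtained from the DenseMap is not invalidated by a later operator[]. A hedged sketch of that calling pattern, with LatticeSketch and mergeSketch as hypothetical stand-ins:

#include "llvm/ADT/DenseMap.h"

// Hypothetical stand-in for ValueLatticeElement, for illustration only.
struct LatticeSketch {
  int State = 0;
  bool mergeIn(const LatticeSketch &Other) {
    if (Other.State <= State)
      return false;
    State = Other.State;
    return true;
  }
};

// Copy the operand's state before touching the destination slot: operator[]
// may grow the DenseMap and invalidate references obtained earlier.
static bool mergeSketch(llvm::DenseMap<int, LatticeSketch> &ValueState,
                        int Dest, int Src) {
  LatticeSketch SrcState = ValueState.lookup(Src);
  return ValueState[Dest].mergeIn(SrcState);
}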
@@ -987,7 +979,7 @@ public: void trackValueOfArgument(Argument *A) { if (A->getType()->isStructTy()) return (void)markOverdefined(A); - mergeInValue(A, getArgAttributeVL(A)); + mergeInValue(ValueState[A], A, getArgAttributeVL(A)); } bool isStructLatticeConstant(Function *F, StructType *STy); @@ -1128,8 +1120,7 @@ bool SCCPInstVisitor::isStructLatticeConstant(Function *F, StructType *STy) { for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { const auto &It = TrackedMultipleRetVals.find(std::make_pair(F, i)); assert(It != TrackedMultipleRetVals.end()); - ValueLatticeElement LV = It->second; - if (!SCCPSolver::isConstant(LV)) + if (!SCCPSolver::isConstant(It->second)) return false; } return true; @@ -1160,7 +1151,7 @@ Constant *SCCPInstVisitor::getConstantOrNull(Value *V) const { std::vector<Constant *> ConstVals; auto *ST = cast<StructType>(V->getType()); for (unsigned I = 0, E = ST->getNumElements(); I != E; ++I) { - ValueLatticeElement LV = LVs[I]; + const ValueLatticeElement &LV = LVs[I]; ConstVals.push_back(SCCPSolver::isConstant(LV) ? getConstant(LV, ST->getElementType(I)) : UndefValue::get(ST->getElementType(I))); @@ -1225,7 +1216,7 @@ void SCCPInstVisitor::visitInstruction(Instruction &I) { } bool SCCPInstVisitor::mergeInValue(ValueLatticeElement &IV, Value *V, - ValueLatticeElement MergeWithV, + const ValueLatticeElement &MergeWithV, ValueLatticeElement::MergeOptions Opts) { if (IV.mergeIn(MergeWithV, Opts)) { pushUsersToWorkList(V); @@ -1264,7 +1255,7 @@ void SCCPInstVisitor::getFeasibleSuccessors(Instruction &TI, return; } - ValueLatticeElement BCValue = getValueState(BI->getCondition()); + const ValueLatticeElement &BCValue = getValueState(BI->getCondition()); ConstantInt *CI = getConstantInt(BCValue, BI->getCondition()->getType()); if (!CI) { // Overdefined condition variables, and branches on unfoldable constant @@ -1326,7 +1317,7 @@ void SCCPInstVisitor::getFeasibleSuccessors(Instruction &TI, // the target as executable. if (auto *IBR = dyn_cast<IndirectBrInst>(&TI)) { // Casts are folded by visitCastInst. - ValueLatticeElement IBRValue = getValueState(IBR->getAddress()); + const ValueLatticeElement &IBRValue = getValueState(IBR->getAddress()); BlockAddress *Addr = dyn_cast_or_null<BlockAddress>( getConstant(IBRValue, IBR->getAddress()->getType())); if (!Addr) { // Overdefined or unknown condition? @@ -1408,7 +1399,7 @@ void SCCPInstVisitor::visitPHINode(PHINode &PN) { if (!isEdgeFeasible(PN.getIncomingBlock(i), PN.getParent())) continue; - ValueLatticeElement IV = getValueState(PN.getIncomingValue(i)); + const ValueLatticeElement &IV = getValueState(PN.getIncomingValue(i)); PhiState.mergeIn(IV); NumActiveIncoming++; if (PhiState.isOverdefined()) @@ -1420,10 +1411,10 @@ void SCCPInstVisitor::visitPHINode(PHINode &PN) { // extensions to match the number of active incoming values. This helps to // limit multiple extensions caused by the same incoming value, if other // incoming values are equal. 
- mergeInValue(&PN, PhiState, + ValueLatticeElement &PhiStateRef = ValueState[&PN]; + mergeInValue(PhiStateRef, &PN, PhiState, ValueLatticeElement::MergeOptions().setMaxWidenSteps( NumActiveIncoming + 1)); - ValueLatticeElement &PhiStateRef = getValueState(&PN); PhiStateRef.setNumRangeExtensions( std::max(NumActiveIncoming, PhiStateRef.getNumRangeExtensions())); } @@ -1481,7 +1472,7 @@ void SCCPInstVisitor::visitCastInst(CastInst &I) { } } - ValueLatticeElement OpSt = getValueState(I.getOperand(0)); + const ValueLatticeElement &OpSt = getValueState(I.getOperand(0)); if (OpSt.isUnknownOrUndef()) return; @@ -1496,9 +1487,9 @@ void SCCPInstVisitor::visitCastInst(CastInst &I) { if (I.getDestTy()->isIntOrIntVectorTy() && I.getSrcTy()->isIntOrIntVectorTy() && I.getOpcode() != Instruction::BitCast) { - auto &LV = getValueState(&I); ConstantRange OpRange = OpSt.asConstantRange(I.getSrcTy(), /*UndefAllowed=*/false); + auto &LV = getValueState(&I); Type *DestTy = I.getDestTy(); ConstantRange Res = ConstantRange::getEmpty(DestTy->getScalarSizeInBits()); @@ -1516,19 +1507,24 @@ void SCCPInstVisitor::handleExtractOfWithOverflow(ExtractValueInst &EVI, const WithOverflowInst *WO, unsigned Idx) { Value *LHS = WO->getLHS(), *RHS = WO->getRHS(); - ValueLatticeElement L = getValueState(LHS); - ValueLatticeElement R = getValueState(RHS); + Type *Ty = LHS->getType(); + addAdditionalUser(LHS, &EVI); addAdditionalUser(RHS, &EVI); - if (L.isUnknownOrUndef() || R.isUnknownOrUndef()) - return; // Wait to resolve. - Type *Ty = LHS->getType(); + const ValueLatticeElement &L = getValueState(LHS); + if (L.isUnknownOrUndef()) + return; // Wait to resolve. ConstantRange LR = L.asConstantRange(Ty, /*UndefAllowed=*/false); + + const ValueLatticeElement &R = getValueState(RHS); + if (R.isUnknownOrUndef()) + return; // Wait to resolve. + ConstantRange RR = R.asConstantRange(Ty, /*UndefAllowed=*/false); if (Idx == 0) { ConstantRange Res = LR.binaryOp(WO->getBinaryOp(), RR); - mergeInValue(&EVI, ValueLatticeElement::getRange(Res)); + mergeInValue(ValueState[&EVI], &EVI, ValueLatticeElement::getRange(Res)); } else { assert(Idx == 1 && "Index can only be 0 or 1"); ConstantRange NWRegion = ConstantRange::makeGuaranteedNoWrapRegion( @@ -1560,7 +1556,7 @@ void SCCPInstVisitor::visitExtractValueInst(ExtractValueInst &EVI) { if (auto *WO = dyn_cast<WithOverflowInst>(AggVal)) return handleExtractOfWithOverflow(EVI, WO, i); ValueLatticeElement EltVal = getStructValueState(AggVal, i); - mergeInValue(getValueState(&EVI), &EVI, EltVal); + mergeInValue(ValueState[&EVI], &EVI, EltVal); } else { // Otherwise, must be extracting from an array. return (void)markOverdefined(&EVI); @@ -1616,14 +1612,18 @@ void SCCPInstVisitor::visitSelectInst(SelectInst &I) { if (ValueState[&I].isOverdefined()) return (void)markOverdefined(&I); - ValueLatticeElement CondValue = getValueState(I.getCondition()); + const ValueLatticeElement &CondValue = getValueState(I.getCondition()); if (CondValue.isUnknownOrUndef()) return; if (ConstantInt *CondCB = getConstantInt(CondValue, I.getCondition()->getType())) { Value *OpVal = CondCB->isZero() ? I.getFalseValue() : I.getTrueValue(); - mergeInValue(&I, getValueState(OpVal)); + const ValueLatticeElement &OpValState = getValueState(OpVal); + // Safety: ValueState[&I] doesn't invalidate OpValState since it is already + // in the map. 
+ assert(ValueState.contains(&I) && "&I is not in ValueState map."); + mergeInValue(ValueState[&I], &I, OpValState); return; } @@ -1721,7 +1721,7 @@ void SCCPInstVisitor::visitBinaryOperator(Instruction &I) { // being a special floating value. ValueLatticeElement NewV; NewV.markConstant(C, /*MayIncludeUndef=*/true); - return (void)mergeInValue(&I, NewV); + return (void)mergeInValue(ValueState[&I], &I, NewV); } } @@ -1741,7 +1741,7 @@ void SCCPInstVisitor::visitBinaryOperator(Instruction &I) { R = A.overflowingBinaryOp(BO->getOpcode(), B, OBO->getNoWrapKind()); else R = A.binaryOp(BO->getOpcode(), B); - mergeInValue(&I, ValueLatticeElement::getRange(R)); + mergeInValue(ValueState[&I], &I, ValueLatticeElement::getRange(R)); // TODO: Currently we do not exploit special values that produce something // better than overdefined with an overdefined operand for vector or floating @@ -1767,7 +1767,7 @@ void SCCPInstVisitor::visitCmpInst(CmpInst &I) { if (C) { ValueLatticeElement CV; CV.markConstant(C); - mergeInValue(&I, CV); + mergeInValue(ValueState[&I], &I, CV); return; } @@ -1802,7 +1802,7 @@ void SCCPInstVisitor::visitGetElementPtrInst(GetElementPtrInst &I) { Operands.reserve(I.getNumOperands()); for (unsigned i = 0, e = I.getNumOperands(); i != e; ++i) { - ValueLatticeElement State = getValueState(I.getOperand(i)); + const ValueLatticeElement &State = getValueState(I.getOperand(i)); if (State.isUnknownOrUndef()) return; // Operands are not resolved yet. @@ -1881,14 +1881,13 @@ void SCCPInstVisitor::visitLoadInst(LoadInst &I) { if (ValueState[&I].isOverdefined()) return (void)markOverdefined(&I); - ValueLatticeElement PtrVal = getValueState(I.getOperand(0)); + const ValueLatticeElement &PtrVal = getValueState(I.getOperand(0)); if (PtrVal.isUnknownOrUndef()) return; // The pointer is not resolved yet! - ValueLatticeElement &IV = ValueState[&I]; - if (SCCPSolver::isConstant(PtrVal)) { Constant *Ptr = getConstant(PtrVal, I.getOperand(0)->getType()); + ValueLatticeElement &IV = ValueState[&I]; // load null is undefined. if (isa<ConstantPointerNull>(Ptr)) { @@ -1916,7 +1915,7 @@ void SCCPInstVisitor::visitLoadInst(LoadInst &I) { } // Fall back to metadata. - mergeInValue(&I, getValueFromMetadata(&I)); + mergeInValue(ValueState[&I], &I, getValueFromMetadata(&I)); } void SCCPInstVisitor::visitCallBase(CallBase &CB) { @@ -1944,7 +1943,7 @@ void SCCPInstVisitor::handleCallOverdefined(CallBase &CB) { return markOverdefined(&CB); // Can't handle struct args. if (A.get()->getType()->isMetadataTy()) continue; // Carried in CB, not allowed in Operands. - ValueLatticeElement State = getValueState(A); + const ValueLatticeElement &State = getValueState(A); if (State.isUnknownOrUndef()) return; // Operands are not resolved yet. @@ -1964,7 +1963,7 @@ void SCCPInstVisitor::handleCallOverdefined(CallBase &CB) { } // Fall back to metadata. 
- mergeInValue(&CB, getValueFromMetadata(&CB)); + mergeInValue(ValueState[&CB], &CB, getValueFromMetadata(&CB)); } void SCCPInstVisitor::handleCallArguments(CallBase &CB) { @@ -1992,10 +1991,11 @@ void SCCPInstVisitor::handleCallArguments(CallBase &CB) { mergeInValue(getStructValueState(&*AI, i), &*AI, CallArg, getMaxWidenStepsOpts()); } - } else - mergeInValue(&*AI, - getValueState(*CAI).intersect(getArgAttributeVL(&*AI)), - getMaxWidenStepsOpts()); + } else { + ValueLatticeElement CallArg = + getValueState(*CAI).intersect(getArgAttributeVL(&*AI)); + mergeInValue(ValueState[&*AI], &*AI, CallArg, getMaxWidenStepsOpts()); + } } } } @@ -2076,7 +2076,8 @@ void SCCPInstVisitor::handleCallResult(CallBase &CB) { if (II->getIntrinsicID() == Intrinsic::vscale) { unsigned BitWidth = CB.getType()->getScalarSizeInBits(); const ConstantRange Result = getVScaleRange(II->getFunction(), BitWidth); - return (void)mergeInValue(II, ValueLatticeElement::getRange(Result)); + return (void)mergeInValue(ValueState[II], II, + ValueLatticeElement::getRange(Result)); } if (ConstantRange::isIntrinsicSupported(II->getIntrinsicID())) { @@ -2094,7 +2095,8 @@ void SCCPInstVisitor::handleCallResult(CallBase &CB) { ConstantRange Result = ConstantRange::intrinsic(II->getIntrinsicID(), OpRanges); - return (void)mergeInValue(II, ValueLatticeElement::getRange(Result)); + return (void)mergeInValue(ValueState[II], II, + ValueLatticeElement::getRange(Result)); } } @@ -2121,7 +2123,7 @@ void SCCPInstVisitor::handleCallResult(CallBase &CB) { return handleCallOverdefined(CB); // Not tracking this callee. // If so, propagate the return value of the callee into this call result. - mergeInValue(&CB, TFRVI->second, getMaxWidenStepsOpts()); + mergeInValue(ValueState[&CB], &CB, TFRVI->second, getMaxWidenStepsOpts()); } } diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index 88af2cf..9cd52da 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -2242,8 +2242,49 @@ public: /// may not be necessary. bool isLoadCombineCandidate(ArrayRef<Value *> Stores) const; bool isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy, - Align Alignment, const int64_t Diff, Value *Ptr0, - Value *PtrN, StridedPtrInfo &SPtrInfo) const; + Align Alignment, const int64_t Diff, + const size_t Sz) const; + + /// Return true if an array of scalar loads can be replaced with a strided + /// load (with constant stride). + /// + /// TODO: + /// It is possible that the load gets "widened". Suppose that originally each + /// load loads `k` bytes and `PointerOps` can be arranged as follows (`%s` is + /// constant): %b + 0 * %s + 0 %b + 0 * %s + 1 %b + 0 * %s + 2 + /// ... + /// %b + 0 * %s + (w - 1) + /// + /// %b + 1 * %s + 0 + /// %b + 1 * %s + 1 + /// %b + 1 * %s + 2 + /// ... + /// %b + 1 * %s + (w - 1) + /// ... + /// + /// %b + (n - 1) * %s + 0 + /// %b + (n - 1) * %s + 1 + /// %b + (n - 1) * %s + 2 + /// ... + /// %b + (n - 1) * %s + (w - 1) + /// + /// In this case we will generate a strided load of type `<n x (k * w)>`. + /// + /// \param PointerOps list of pointer arguments of loads. + /// \param ElemTy original scalar type of loads. + /// \param Alignment alignment of the first load. + /// \param SortedIndices is the order of PointerOps as returned by + /// `sortPtrAccesses`. + /// \param Diff Pointer difference between the lowest and the highest pointer + /// in `PointerOps` as returned by `getPointersDiff`.
+ /// \param Ptr0 first pointer in `PointerOps`. + /// \param PtrN last pointer in `PointerOps`. + /// \param SPtrInfo If the function returns `true`, it also sets all the fields + /// of `SPtrInfo` necessary to generate the strided load later. + bool analyzeConstantStrideCandidate( + const ArrayRef<Value *> PointerOps, Type *ElemTy, Align Alignment, + const SmallVectorImpl<unsigned> &SortedIndices, const int64_t Diff, + Value *Ptr0, Value *PtrN, StridedPtrInfo &SPtrInfo) const; /// Return true if an array of scalar loads can be replaced with a strided /// load (with run-time stride). @@ -6849,9 +6890,8 @@ isMaskedLoadCompress(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps, /// current graph (for masked gathers extra extractelement instructions /// might be required). bool BoUpSLP::isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy, - Align Alignment, const int64_t Diff, Value *Ptr0, - Value *PtrN, StridedPtrInfo &SPtrInfo) const { - const size_t Sz = PointerOps.size(); + Align Alignment, const int64_t Diff, + const size_t Sz) const { if (Diff % (Sz - 1) != 0) return false; @@ -6875,27 +6915,40 @@ bool BoUpSLP::isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy, return false; if (!TTI->isLegalStridedLoadStore(VecTy, Alignment)) return false; + return true; + } + return false; +} - // Iterate through all pointers and check if all distances are - // unique multiple of Dist. - SmallSet<int64_t, 4> Dists; - for (Value *Ptr : PointerOps) { - int64_t Dist = 0; - if (Ptr == PtrN) - Dist = Diff; - else if (Ptr != Ptr0) - Dist = *getPointersDiff(ScalarTy, Ptr0, ScalarTy, Ptr, *DL, *SE); - // If the strides are not the same or repeated, we can't - // vectorize. - if (((Dist / Stride) * Stride) != Dist || !Dists.insert(Dist).second) - break; - } - if (Dists.size() == Sz) { - Type *StrideTy = DL->getIndexType(Ptr0->getType()); - SPtrInfo.StrideVal = ConstantInt::get(StrideTy, Stride); - SPtrInfo.Ty = getWidenedType(ScalarTy, Sz); - return true; - } +bool BoUpSLP::analyzeConstantStrideCandidate( + const ArrayRef<Value *> PointerOps, Type *ScalarTy, Align Alignment, + const SmallVectorImpl<unsigned> &SortedIndices, const int64_t Diff, + Value *Ptr0, Value *PtrN, StridedPtrInfo &SPtrInfo) const { + const size_t Sz = PointerOps.size(); + if (!isStridedLoad(PointerOps, ScalarTy, Alignment, Diff, Sz)) + return false; + + int64_t Stride = Diff / static_cast<int64_t>(Sz - 1); + + // Iterate through all pointers and check if all distances are + // unique multiple of Dist. + SmallSet<int64_t, 4> Dists; + for (Value *Ptr : PointerOps) { + int64_t Dist = 0; + if (Ptr == PtrN) + Dist = Diff; + else if (Ptr != Ptr0) + Dist = *getPointersDiff(ScalarTy, Ptr0, ScalarTy, Ptr, *DL, *SE); + // If the strides are not the same or repeated, we can't + // vectorize. + if (((Dist / Stride) * Stride) != Dist || !Dists.insert(Dist).second) + break; + } + if (Dists.size() == Sz) { + Type *StrideTy = DL->getIndexType(Ptr0->getType()); + SPtrInfo.StrideVal = ConstantInt::get(StrideTy, Stride); + SPtrInfo.Ty = getWidenedType(ScalarTy, Sz); + return true; } return false; } @@ -6995,8 +7048,8 @@ BoUpSLP::LoadsState BoUpSLP::canVectorizeLoads( Align Alignment = cast<LoadInst>(Order.empty() ?
VL.front() : VL[Order.front()]) ->getAlign(); - if (isStridedLoad(PointerOps, ScalarTy, Alignment, *Diff, Ptr0, PtrN, - SPtrInfo)) + if (analyzeConstantStrideCandidate(PointerOps, ScalarTy, Alignment, Order, + *Diff, Ptr0, PtrN, SPtrInfo)) return LoadsState::StridedVectorize; } if (!TTI->isLegalMaskedGather(VecTy, CommonAlignment) || @@ -17632,7 +17685,9 @@ void BoUpSLP::setInsertPointAfterBundle(const TreeEntry *E) { } if (IsPHI || (!E->isGather() && E->State != TreeEntry::SplitVectorize && - E->doesNotNeedToSchedule()) || + (E->doesNotNeedToSchedule() || + (E->hasCopyableElements() && !E->isCopyableElement(LastInst) && + isUsedOutsideBlock(LastInst)))) || (GatheredLoadsEntriesFirst.has_value() && E->Idx >= *GatheredLoadsEntriesFirst && !E->isGather() && E->getOpcode() == Instruction::Load)) { |
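To close, an editorial sketch (not part of the patch) of the distance checks performed across the refactored isStridedLoad / analyzeConstantStrideCandidate pair above, reduced to plain integer offsets; the TTI legality and type-widening concerns are out of scope here, and the helper name is hypothetical.

#include <cassert>
#include <cstdint>
#include <set>
#include <vector>

// Offsets stand in for the pointer differences that getPointersDiff returns,
// assumed sorted ascending (the role SortedIndices plays in the real code).
static bool hasConstantStrideSketch(const std::vector<int64_t> &Offsets) {
  size_t N = Offsets.size();
  if (N < 2)
    return false;
  int64_t Diff = Offsets.back() - Offsets.front();
  if (Diff == 0 || Diff % static_cast<int64_t>(N - 1) != 0)
    return false;
  int64_t Stride = Diff / static_cast<int64_t>(N - 1);
  // Every distance from the first offset must be a distinct multiple of the
  // stride, mirroring the Dists set in analyzeConstantStrideCandidate.
  std::set<int64_t> Dists;
  for (int64_t O : Offsets) {
    int64_t Dist = O - Offsets.front();
    if (Dist % Stride != 0 || !Dists.insert(Dist).second)
      return false;
  }
  return true;
}

int main() {
  assert(hasConstantStrideSketch({0, 8, 16, 24}));  // constant stride of 8
  assert(!hasConstantStrideSketch({0, 8, 8, 24}));  // repeated offset
  assert(!hasConstantStrideSketch({0, 8, 12, 24})); // 12 is not a multiple of 8
  return 0;
}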