Diffstat (limited to 'llvm/lib')
104 files changed, 2471 insertions, 1781 deletions
diff --git a/llvm/lib/Analysis/CMakeLists.txt b/llvm/lib/Analysis/CMakeLists.txt index cfde787..16dd6f8 100644 --- a/llvm/lib/Analysis/CMakeLists.txt +++ b/llvm/lib/Analysis/CMakeLists.txt @@ -175,6 +175,7 @@ add_llvm_component_library(LLVMAnalysis LINK_COMPONENTS BinaryFormat Core + FrontendHLSL Object ProfileData Support diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp index 2d52f34..dd98b62 100644 --- a/llvm/lib/Analysis/ConstantFolding.cpp +++ b/llvm/lib/Analysis/ConstantFolding.cpp @@ -2679,11 +2679,12 @@ static Constant *ConstantFoldScalarCall1(StringRef Name, case Intrinsic::nvvm_round_ftz_f: case Intrinsic::nvvm_round_f: case Intrinsic::nvvm_round_d: { - // Use APFloat implementation instead of native libm call, as some - // implementations (e.g. on PPC) do not preserve the sign of negative 0. + // nvvm_round is lowered to PTX cvt.rni, which will round to nearest + // integer, choosing even integer if source is equidistant between two + // integers, so the semantics are closer to "rint" rather than "round". bool IsFTZ = nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID); auto V = IsFTZ ? FTZPreserveSign(APF) : APF; - V.roundToIntegral(APFloat::rmNearestTiesToAway); + V.roundToIntegral(APFloat::rmNearestTiesToEven); return ConstantFP::get(Ty->getContext(), V); } diff --git a/llvm/lib/Analysis/DXILResource.cpp b/llvm/lib/Analysis/DXILResource.cpp index 1959ab6..629fa7cd 100644 --- a/llvm/lib/Analysis/DXILResource.cpp +++ b/llvm/lib/Analysis/DXILResource.cpp @@ -995,18 +995,7 @@ SmallVector<dxil::ResourceInfo *> DXILResourceMap::findByUse(const Value *Key) { //===----------------------------------------------------------------------===// void DXILResourceBindingInfo::populate(Module &M, DXILResourceTypeMap &DRTM) { - struct Binding { - ResourceClass RC; - uint32_t Space; - uint32_t LowerBound; - uint32_t UpperBound; - Value *Name; - Binding(ResourceClass RC, uint32_t Space, uint32_t LowerBound, - uint32_t UpperBound, Value *Name) - : RC(RC), Space(Space), LowerBound(LowerBound), UpperBound(UpperBound), - Name(Name) {} - }; - SmallVector<Binding> Bindings; + hlsl::BindingInfoBuilder Builder; // collect all of the llvm.dx.resource.handlefrombinding calls; // make a note if there is llvm.dx.resource.handlefromimplicitbinding @@ -1036,133 +1025,20 @@ void DXILResourceBindingInfo::populate(Module &M, DXILResourceTypeMap &DRTM) { assert((Size < 0 || (unsigned)LowerBound + Size - 1 <= UINT32_MAX) && "upper bound register overflow"); uint32_t UpperBound = Size < 0 ? UINT32_MAX : LowerBound + Size - 1; - Bindings.emplace_back(RTI.getResourceClass(), Space, LowerBound, - UpperBound, Name); + Builder.trackBinding(RTI.getResourceClass(), Space, LowerBound, + UpperBound, Name); } break; } case Intrinsic::dx_resource_handlefromimplicitbinding: { - ImplicitBinding = true; + HasImplicitBinding = true; break; } } } - // sort all the collected bindings - llvm::stable_sort(Bindings, [](auto &LHS, auto &RHS) { - return std::tie(LHS.RC, LHS.Space, LHS.LowerBound) < - std::tie(RHS.RC, RHS.Space, RHS.LowerBound); - }); - - // remove duplicates - Binding *NewEnd = llvm::unique(Bindings, [](auto &LHS, auto &RHS) { - return std::tie(LHS.RC, LHS.Space, LHS.LowerBound, LHS.UpperBound, - LHS.Name) == std::tie(RHS.RC, RHS.Space, RHS.LowerBound, - RHS.UpperBound, RHS.Name); - }); - if (NewEnd != Bindings.end()) - Bindings.erase(NewEnd); - - // Go over the sorted bindings and build up lists of free register ranges - // for each binding type and used spaces. 
Bindings are sorted by resource - // class, space, and lower bound register slot. - BindingSpaces *BS = &SRVSpaces; - for (const Binding &B : Bindings) { - if (BS->RC != B.RC) - // move to the next resource class spaces - BS = &getBindingSpaces(B.RC); - - RegisterSpace *S = BS->Spaces.empty() ? &BS->Spaces.emplace_back(B.Space) - : &BS->Spaces.back(); - assert(S->Space <= B.Space && "bindings not sorted correctly?"); - if (B.Space != S->Space) - // add new space - S = &BS->Spaces.emplace_back(B.Space); - - // The space is full - there are no free slots left, or the rest of the - // slots are taken by an unbounded array. Set flag to report overlapping - // binding later. - if (S->FreeRanges.empty() || S->FreeRanges.back().UpperBound < UINT32_MAX) { - OverlappingBinding = true; - continue; - } - - // adjust the last free range lower bound, split it in two, or remove it - BindingRange &LastFreeRange = S->FreeRanges.back(); - if (LastFreeRange.LowerBound == B.LowerBound) { - if (B.UpperBound < UINT32_MAX) - LastFreeRange.LowerBound = B.UpperBound + 1; - else - S->FreeRanges.pop_back(); - } else if (LastFreeRange.LowerBound < B.LowerBound) { - LastFreeRange.UpperBound = B.LowerBound - 1; - if (B.UpperBound < UINT32_MAX) - S->FreeRanges.emplace_back(B.UpperBound + 1, UINT32_MAX); - } else { - OverlappingBinding = true; - if (B.UpperBound < UINT32_MAX) - LastFreeRange.LowerBound = - std::max(LastFreeRange.LowerBound, B.UpperBound + 1); - else - S->FreeRanges.pop_back(); - } - } -} - -// returns std::nulopt if binding could not be found in given space -std::optional<uint32_t> -DXILResourceBindingInfo::findAvailableBinding(dxil::ResourceClass RC, - uint32_t Space, int32_t Size) { - BindingSpaces &BS = getBindingSpaces(RC); - RegisterSpace &RS = BS.getOrInsertSpace(Space); - return RS.findAvailableBinding(Size); -} - -DXILResourceBindingInfo::RegisterSpace & -DXILResourceBindingInfo::BindingSpaces::getOrInsertSpace(uint32_t Space) { - for (auto *I = Spaces.begin(); I != Spaces.end(); ++I) { - if (I->Space == Space) - return *I; - if (I->Space < Space) - continue; - return *Spaces.insert(I, Space); - } - return Spaces.emplace_back(Space); -} - -std::optional<uint32_t> -DXILResourceBindingInfo::RegisterSpace::findAvailableBinding(int32_t Size) { - assert((Size == -1 || Size > 0) && "invalid size"); - - if (FreeRanges.empty()) - return std::nullopt; - - // unbounded array - if (Size == -1) { - BindingRange &Last = FreeRanges.back(); - if (Last.UpperBound != UINT32_MAX) - // this space is already occupied by an unbounded array - return std::nullopt; - uint32_t RegSlot = Last.LowerBound; - FreeRanges.pop_back(); - return RegSlot; - } - - // single resource or fixed-size array - for (BindingRange &R : FreeRanges) { - // compare the size as uint64_t to prevent overflow for range (0, - // UINT32_MAX) - if ((uint64_t)R.UpperBound - R.LowerBound + 1 < (uint64_t)Size) - continue; - uint32_t RegSlot = R.LowerBound; - // This might create a range where (LowerBound == UpperBound + 1). When - // that happens, the next time this function is called the range will - // skipped over by the check above (at this point Size is always > 0). 
- R.LowerBound += Size; - return RegSlot; - } - - return std::nullopt; + Bindings = Builder.calculateBindingInfo( + [this](auto, auto) { this->HasOverlappingBinding = true; }); } //===----------------------------------------------------------------------===// diff --git a/llvm/lib/Analysis/Loads.cpp b/llvm/lib/Analysis/Loads.cpp index 393f264..6fc81d787 100644 --- a/llvm/lib/Analysis/Loads.cpp +++ b/llvm/lib/Analysis/Loads.cpp @@ -342,7 +342,7 @@ bool llvm::isDereferenceableAndAlignedInLoop( : SE.getConstantMaxBackedgeTakenCount(L); } const auto &[AccessStart, AccessEnd] = getStartAndEndForAccess( - L, PtrScev, LI->getType(), BECount, MaxBECount, &SE, nullptr); + L, PtrScev, LI->getType(), BECount, MaxBECount, &SE, nullptr, &DT, AC); if (isa<SCEVCouldNotCompute>(AccessStart) || isa<SCEVCouldNotCompute>(AccessEnd)) return false; diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp index 14be385..a553533 100644 --- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp +++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp @@ -23,6 +23,8 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AliasSetTracker.h" +#include "llvm/Analysis/AssumeBundleQueries.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopIterator.h" @@ -208,28 +210,46 @@ static const SCEV *mulSCEVOverflow(const SCEV *A, const SCEV *B, /// Return true, if evaluating \p AR at \p MaxBTC cannot wrap, because \p AR at /// \p MaxBTC is guaranteed inbounds of the accessed object. -static bool evaluatePtrAddRecAtMaxBTCWillNotWrap(const SCEVAddRecExpr *AR, - const SCEV *MaxBTC, - const SCEV *EltSize, - ScalarEvolution &SE, - const DataLayout &DL) { +static bool +evaluatePtrAddRecAtMaxBTCWillNotWrap(const SCEVAddRecExpr *AR, + const SCEV *MaxBTC, const SCEV *EltSize, + ScalarEvolution &SE, const DataLayout &DL, + DominatorTree *DT, AssumptionCache *AC) { auto *PointerBase = SE.getPointerBase(AR->getStart()); auto *StartPtr = dyn_cast<SCEVUnknown>(PointerBase); if (!StartPtr) return false; + const Loop *L = AR->getLoop(); bool CheckForNonNull, CheckForFreed; - uint64_t DerefBytes = StartPtr->getValue()->getPointerDereferenceableBytes( + Value *StartPtrV = StartPtr->getValue(); + uint64_t DerefBytes = StartPtrV->getPointerDereferenceableBytes( DL, CheckForNonNull, CheckForFreed); - if (CheckForNonNull || CheckForFreed) + if (DerefBytes && (CheckForNonNull || CheckForFreed)) return false; const SCEV *Step = AR->getStepRecurrence(SE); + Type *WiderTy = SE.getWiderType(MaxBTC->getType(), Step->getType()); + const SCEV *DerefBytesSCEV = SE.getConstant(WiderTy, DerefBytes); + + // Check if we have a suitable dereferencable assumption we can use. 
+ if (!StartPtrV->canBeFreed()) { + RetainedKnowledge DerefRK = getKnowledgeValidInContext( + StartPtrV, {Attribute::Dereferenceable}, *AC, + L->getLoopPredecessor()->getTerminator(), DT); + if (DerefRK) { + DerefBytesSCEV = SE.getUMaxExpr( + DerefBytesSCEV, SE.getConstant(WiderTy, DerefRK.ArgValue)); + } + } + + if (DerefBytesSCEV->isZero()) + return false; + bool IsKnownNonNegative = SE.isKnownNonNegative(Step); if (!IsKnownNonNegative && !SE.isKnownNegative(Step)) return false; - Type *WiderTy = SE.getWiderType(MaxBTC->getType(), Step->getType()); Step = SE.getNoopOrSignExtend(Step, WiderTy); MaxBTC = SE.getNoopOrZeroExtend(MaxBTC, WiderTy); @@ -256,8 +276,7 @@ static bool evaluatePtrAddRecAtMaxBTCWillNotWrap(const SCEVAddRecExpr *AR, const SCEV *EndBytes = addSCEVNoOverflow(StartOffset, OffsetEndBytes, SE); if (!EndBytes) return false; - return SE.isKnownPredicate(CmpInst::ICMP_ULE, EndBytes, - SE.getConstant(WiderTy, DerefBytes)); + return SE.isKnownPredicate(CmpInst::ICMP_ULE, EndBytes, DerefBytesSCEV); } // For negative steps check if @@ -265,15 +284,15 @@ static bool evaluatePtrAddRecAtMaxBTCWillNotWrap(const SCEVAddRecExpr *AR, // * StartOffset <= DerefBytes. assert(SE.isKnownNegative(Step) && "must be known negative"); return SE.isKnownPredicate(CmpInst::ICMP_SGE, StartOffset, OffsetEndBytes) && - SE.isKnownPredicate(CmpInst::ICMP_ULE, StartOffset, - SE.getConstant(WiderTy, DerefBytes)); + SE.isKnownPredicate(CmpInst::ICMP_ULE, StartOffset, DerefBytesSCEV); } std::pair<const SCEV *, const SCEV *> llvm::getStartAndEndForAccess( const Loop *Lp, const SCEV *PtrExpr, Type *AccessTy, const SCEV *BTC, const SCEV *MaxBTC, ScalarEvolution *SE, DenseMap<std::pair<const SCEV *, Type *>, - std::pair<const SCEV *, const SCEV *>> *PointerBounds) { + std::pair<const SCEV *, const SCEV *>> *PointerBounds, + DominatorTree *DT, AssumptionCache *AC) { std::pair<const SCEV *, const SCEV *> *PtrBoundsPair; if (PointerBounds) { auto [Iter, Ins] = PointerBounds->insert( @@ -308,8 +327,8 @@ std::pair<const SCEV *, const SCEV *> llvm::getStartAndEndForAccess( // sets ScEnd to the maximum unsigned value for the type. Note that LAA // separately checks that accesses cannot not wrap, so unsigned max // represents an upper bound. 
- if (evaluatePtrAddRecAtMaxBTCWillNotWrap(AR, MaxBTC, EltSizeSCEV, *SE, - DL)) { + if (evaluatePtrAddRecAtMaxBTCWillNotWrap(AR, MaxBTC, EltSizeSCEV, *SE, DL, + DT, AC)) { ScEnd = AR->evaluateAtIteration(MaxBTC, *SE); } else { ScEnd = SE->getAddExpr( @@ -356,9 +375,9 @@ void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, const SCEV *PtrExpr, bool NeedsFreeze) { const SCEV *SymbolicMaxBTC = PSE.getSymbolicMaxBackedgeTakenCount(); const SCEV *BTC = PSE.getBackedgeTakenCount(); - const auto &[ScStart, ScEnd] = - getStartAndEndForAccess(Lp, PtrExpr, AccessTy, BTC, SymbolicMaxBTC, - PSE.getSE(), &DC.getPointerBounds()); + const auto &[ScStart, ScEnd] = getStartAndEndForAccess( + Lp, PtrExpr, AccessTy, BTC, SymbolicMaxBTC, PSE.getSE(), + &DC.getPointerBounds(), DC.getDT(), DC.getAC()); assert(!isa<SCEVCouldNotCompute>(ScStart) && !isa<SCEVCouldNotCompute>(ScEnd) && "must be able to compute both start and end expressions"); @@ -1961,13 +1980,15 @@ bool MemoryDepChecker::areAccessesCompletelyBeforeOrAfter(const SCEV *Src, const SCEV *BTC = PSE.getBackedgeTakenCount(); const SCEV *SymbolicMaxBTC = PSE.getSymbolicMaxBackedgeTakenCount(); ScalarEvolution &SE = *PSE.getSE(); - const auto &[SrcStart_, SrcEnd_] = getStartAndEndForAccess( - InnermostLoop, Src, SrcTy, BTC, SymbolicMaxBTC, &SE, &PointerBounds); + const auto &[SrcStart_, SrcEnd_] = + getStartAndEndForAccess(InnermostLoop, Src, SrcTy, BTC, SymbolicMaxBTC, + &SE, &PointerBounds, DT, AC); if (isa<SCEVCouldNotCompute>(SrcStart_) || isa<SCEVCouldNotCompute>(SrcEnd_)) return false; - const auto &[SinkStart_, SinkEnd_] = getStartAndEndForAccess( - InnermostLoop, Sink, SinkTy, BTC, SymbolicMaxBTC, &SE, &PointerBounds); + const auto &[SinkStart_, SinkEnd_] = + getStartAndEndForAccess(InnermostLoop, Sink, SinkTy, BTC, SymbolicMaxBTC, + &SE, &PointerBounds, DT, AC); if (isa<SCEVCouldNotCompute>(SinkStart_) || isa<SCEVCouldNotCompute>(SinkEnd_)) return false; @@ -3002,7 +3023,7 @@ LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE, const TargetTransformInfo *TTI, const TargetLibraryInfo *TLI, AAResults *AA, DominatorTree *DT, LoopInfo *LI, - bool AllowPartial) + AssumptionCache *AC, bool AllowPartial) : PSE(std::make_unique<PredicatedScalarEvolution>(*SE, *L)), PtrRtChecking(nullptr), TheLoop(L), AllowPartial(AllowPartial) { unsigned MaxTargetVectorWidthInBits = std::numeric_limits<unsigned>::max(); @@ -3012,8 +3033,8 @@ LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE, MaxTargetVectorWidthInBits = TTI->getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) * 2; - DepChecker = std::make_unique<MemoryDepChecker>(*PSE, L, SymbolicStrides, - MaxTargetVectorWidthInBits); + DepChecker = std::make_unique<MemoryDepChecker>( + *PSE, AC, DT, L, SymbolicStrides, MaxTargetVectorWidthInBits); PtrRtChecking = std::make_unique<RuntimePointerChecking>(*DepChecker, SE); if (canAnalyzeLoop()) CanVecMem = analyzeLoop(AA, LI, TLI, DT); @@ -3082,7 +3103,7 @@ const LoopAccessInfo &LoopAccessInfoManager::getInfo(Loop &L, // or if it was created with a different value of AllowPartial. 
if (Inserted || It->second->hasAllowPartial() != AllowPartial) It->second = std::make_unique<LoopAccessInfo>(&L, &SE, TTI, TLI, &AA, &DT, - &LI, AllowPartial); + &LI, AC, AllowPartial); return *It->second; } @@ -3125,7 +3146,8 @@ LoopAccessInfoManager LoopAccessAnalysis::run(Function &F, auto &LI = FAM.getResult<LoopAnalysis>(F); auto &TTI = FAM.getResult<TargetIRAnalysis>(F); auto &TLI = FAM.getResult<TargetLibraryAnalysis>(F); - return LoopAccessInfoManager(SE, AA, DT, LI, &TTI, &TLI); + auto &AC = FAM.getResult<AssumptionAnalysis>(F); + return LoopAccessInfoManager(SE, AA, DT, LI, &TTI, &TLI, &AC); } AnalysisKey LoopAccessAnalysis::Key; diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 61a575c..477e477 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -2685,16 +2685,15 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, // Try to push the constant operand into a ZExt: A + zext (-A + B) -> zext // (B), if trunc (A) + -A + B does not unsigned-wrap. - if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(Ops[1])) { - const SCEV *B = ZExt->getOperand(0); - const SCEV *NarrowA = getTruncateExpr(A, B->getType()); - if (isa<SCEVAddExpr>(B) && - NarrowA == getNegativeSCEV(cast<SCEVAddExpr>(B)->getOperand(0)) && - getZeroExtendExpr(NarrowA, ZExt->getType()) == A && - hasFlags(StrengthenNoWrapFlags(this, scAddExpr, {NarrowA, B}, + const SCEVAddExpr *InnerAdd; + if (match(B, m_scev_ZExt(m_scev_Add(InnerAdd)))) { + const SCEV *NarrowA = getTruncateExpr(A, InnerAdd->getType()); + if (NarrowA == getNegativeSCEV(InnerAdd->getOperand(0)) && + getZeroExtendExpr(NarrowA, B->getType()) == A && + hasFlags(StrengthenNoWrapFlags(this, scAddExpr, {NarrowA, InnerAdd}, SCEV::FlagAnyWrap), SCEV::FlagNUW)) { - return getZeroExtendExpr(getAddExpr(NarrowA, B), ZExt->getType()); + return getZeroExtendExpr(getAddExpr(NarrowA, InnerAdd), B->getType()); } } } diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp index 55ba52a..c7eb2ec 100644 --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -1486,6 +1486,10 @@ void TargetTransformInfo::collectKernelLaunchBounds( return TTIImpl->collectKernelLaunchBounds(F, LB); } +bool TargetTransformInfo::allowVectorElementIndexingUsingGEP() const { + return TTIImpl->allowVectorElementIndexingUsingGEP(); +} + TargetTransformInfoImplBase::~TargetTransformInfoImplBase() = default; TargetIRAnalysis::TargetIRAnalysis() : TTICallback(&getDefaultTTI) {} diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp index 425ea31..b3b4c37 100644 --- a/llvm/lib/Analysis/VectorUtils.cpp +++ b/llvm/lib/Analysis/VectorUtils.cpp @@ -81,7 +81,6 @@ bool llvm::isTriviallyVectorizable(Intrinsic::ID ID) { case Intrinsic::exp: case Intrinsic::exp10: case Intrinsic::exp2: - case Intrinsic::ldexp: case Intrinsic::log: case Intrinsic::log10: case Intrinsic::log2: @@ -109,8 +108,6 @@ bool llvm::isTriviallyVectorizable(Intrinsic::ID ID) { case Intrinsic::canonicalize: case Intrinsic::fptosi_sat: case Intrinsic::fptoui_sat: - case Intrinsic::lround: - case Intrinsic::llround: case Intrinsic::lrint: case Intrinsic::llrint: case Intrinsic::ucmp: @@ -192,8 +189,6 @@ bool llvm::isVectorIntrinsicWithOverloadTypeAtArg( switch (ID) { case Intrinsic::fptosi_sat: case Intrinsic::fptoui_sat: - case Intrinsic::lround: - case Intrinsic::llround: case Intrinsic::lrint: case Intrinsic::llrint: case 
Intrinsic::vp_lrint: @@ -208,7 +203,6 @@ bool llvm::isVectorIntrinsicWithOverloadTypeAtArg( case Intrinsic::vp_is_fpclass: return OpdIdx == 0; case Intrinsic::powi: - case Intrinsic::ldexp: return OpdIdx == -1 || OpdIdx == 1; default: return OpdIdx == -1; diff --git a/llvm/lib/CGData/StableFunctionMapRecord.cpp b/llvm/lib/CGData/StableFunctionMapRecord.cpp index 4e4fcef..423e068 100644 --- a/llvm/lib/CGData/StableFunctionMapRecord.cpp +++ b/llvm/lib/CGData/StableFunctionMapRecord.cpp @@ -160,14 +160,18 @@ void StableFunctionMapRecord::deserialize(const unsigned char *&Ptr, for (unsigned I = 0; I < NumFuncs; ++I) { auto Hash = endian::readNext<stable_hash, endianness::little, unaligned>(Ptr); - auto FunctionNameId = + [[maybe_unused]] auto FunctionNameId = endian::readNext<uint32_t, endianness::little, unaligned>(Ptr); - assert(FunctionMap->getNameForId(FunctionNameId) && - "FunctionNameId out of range"); - auto ModuleNameId = + [[maybe_unused]] auto ModuleNameId = endian::readNext<uint32_t, endianness::little, unaligned>(Ptr); - assert(FunctionMap->getNameForId(ModuleNameId) && - "ModuleNameId out of range"); + // Only validate IDs if we've read the names + if (ReadStableFunctionMapNames) { + assert(FunctionMap->getNameForId(FunctionNameId) && + "FunctionNameId out of range"); + assert(FunctionMap->getNameForId(ModuleNameId) && + "ModuleNameId out of range"); + } + auto InstCount = endian::readNext<uint32_t, endianness::little, unaligned>(Ptr); diff --git a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp index 6166271..1641c3e 100644 --- a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp +++ b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp @@ -1654,6 +1654,88 @@ void AsmPrinter::emitStackUsage(const MachineFunction &MF) { *StackUsageStream << "static\n"; } +/// Extracts a generalized numeric type identifier of a Function's type from +/// type metadata. Returns null if metadata cannot be found. +static ConstantInt *extractNumericCGTypeId(const Function &F) { + SmallVector<MDNode *, 2> Types; + F.getMetadata(LLVMContext::MD_type, Types); + for (const auto &Type : Types) { + if (Type->hasGeneralizedMDString()) { + MDString *MDGeneralizedTypeId = cast<MDString>(Type->getOperand(1)); + uint64_t TypeIdVal = llvm::MD5Hash(MDGeneralizedTypeId->getString()); + IntegerType *Int64Ty = Type::getInt64Ty(F.getContext()); + return ConstantInt::get(Int64Ty, TypeIdVal); + } + } + return nullptr; +} + +/// Emits .callgraph section. +void AsmPrinter::emitCallGraphSection(const MachineFunction &MF, + FunctionInfo &FuncInfo) { + if (!MF.getTarget().Options.EmitCallGraphSection) + return; + + // Switch to the call graph section for the function + MCSection *FuncCGSection = + getObjFileLowering().getCallGraphSection(*getCurrentSection()); + assert(FuncCGSection && "null callgraph section"); + OutStreamer->pushSection(); + OutStreamer->switchSection(FuncCGSection); + + // Emit format version number. + OutStreamer->emitInt64(CallGraphSectionFormatVersion::V_0); + + // Emit function's self information, which is composed of: + // 1) FunctionEntryPc + // 2) FunctionKind: Whether the function is indirect target, and if so, + // whether its type id is known. + // 3) FunctionTypeId: Emit only when the function is an indirect target + // and its type id is known. + + // Emit function entry pc. 
+ const MCSymbol *FunctionSymbol = getFunctionBegin(); + OutStreamer->emitSymbolValue(FunctionSymbol, TM.getProgramPointerSize()); + + // If this function has external linkage or has its address taken and + // it is not a callback, then anything could call it. + const Function &F = MF.getFunction(); + bool IsIndirectTarget = + !F.hasLocalLinkage() || F.hasAddressTaken(nullptr, + /*IgnoreCallbackUses=*/true, + /*IgnoreAssumeLikeCalls=*/true, + /*IgnoreLLVMUsed=*/false); + + // FIXME: FunctionKind takes a few values but emitted as a 64-bit value. + // Can be optimized to occupy 2 bits instead. + // Emit function kind, and type id if available. + if (!IsIndirectTarget) { + OutStreamer->emitInt64( + static_cast<uint64_t>(FunctionInfo::FunctionKind::NOT_INDIRECT_TARGET)); + } else { + if (const auto *TypeId = extractNumericCGTypeId(F)) { + OutStreamer->emitInt64(static_cast<uint64_t>( + FunctionInfo::FunctionKind::INDIRECT_TARGET_KNOWN_TID)); + OutStreamer->emitInt64(TypeId->getZExtValue()); + } else { + OutStreamer->emitInt64(static_cast<uint64_t>( + FunctionInfo::FunctionKind::INDIRECT_TARGET_UNKNOWN_TID)); + } + } + + // Emit callsite labels, where each element is a pair of type id and + // indirect callsite pc. + const auto &CallSiteLabels = FuncInfo.CallSiteLabels; + OutStreamer->emitInt64(CallSiteLabels.size()); + for (const auto &[TypeId, Label] : CallSiteLabels) { + OutStreamer->emitInt64(TypeId); + OutStreamer->emitSymbolValue(Label, TM.getProgramPointerSize()); + } + FuncInfo.CallSiteLabels.clear(); + + OutStreamer->popSection(); +} + void AsmPrinter::emitPCSectionsLabel(const MachineFunction &MF, const MDNode &MD) { MCSymbol *S = MF.getContext().createTempSymbol("pcsection"); @@ -1784,6 +1866,23 @@ static StringRef getMIMnemonic(const MachineInstr &MI, MCStreamer &Streamer) { return Name; } +void AsmPrinter::emitIndirectCalleeLabels( + FunctionInfo &FuncInfo, + const MachineFunction::CallSiteInfoMap &CallSitesInfoMap, + const MachineInstr &MI) { + // Only indirect calls have type identifiers set. + const auto &CallSiteInfo = CallSitesInfoMap.find(&MI); + if (CallSiteInfo == CallSitesInfoMap.end()) + return; + + for (ConstantInt *CalleeTypeId : CallSiteInfo->second.CalleeTypeIds) { + MCSymbol *S = MF->getContext().createTempSymbol(); + OutStreamer->emitLabel(S); + uint64_t CalleeTypeIdVal = CalleeTypeId->getZExtValue(); + FuncInfo.CallSiteLabels.emplace_back(CalleeTypeIdVal, S); + } +} + /// EmitFunctionBody - This method emits the body and trailer for a /// function. void AsmPrinter::emitFunctionBody() { @@ -1830,6 +1929,8 @@ void AsmPrinter::emitFunctionBody() { MBBSectionRanges[MF->front().getSectionID()] = MBBSectionRange{CurrentFnBegin, nullptr}; + FunctionInfo FuncInfo; + const auto &CallSitesInfoMap = MF->getCallSitesInfo(); for (auto &MBB : *MF) { // Print a label for the basic block. emitBasicBlockStart(MBB); @@ -1963,6 +2064,9 @@ void AsmPrinter::emitFunctionBody() { break; } + if (TM.Options.EmitCallGraphSection && MI.isCall()) + emitIndirectCalleeLabels(FuncInfo, CallSitesInfoMap, MI); + // If there is a post-instruction symbol, emit a label for it here. if (MCSymbol *S = MI.getPostInstrSymbol()) OutStreamer->emitLabel(S); @@ -2142,6 +2246,9 @@ void AsmPrinter::emitFunctionBody() { // Emit section containing stack size metadata. emitStackSizeSection(*MF); + // Emit section containing call graph metadata. + emitCallGraphSection(*MF, FuncInfo); + // Emit .su file containing function stack size information. 
emitStackUsage(*MF); @@ -2841,6 +2948,7 @@ void AsmPrinter::SetupMachineFunction(MachineFunction &MF) { F.hasFnAttribute("xray-instruction-threshold") || needFuncLabels(MF, *this) || NeedsLocalForSize || MF.getTarget().Options.EmitStackSizeSection || + MF.getTarget().Options.EmitCallGraphSection || MF.getTarget().Options.BBAddrMap) { CurrentFnBegin = createTempSymbol("func_begin"); if (NeedsLocalForSize) diff --git a/llvm/lib/CodeGen/MachineFunction.cpp b/llvm/lib/CodeGen/MachineFunction.cpp index 60d42e0..ec40f6a 100644 --- a/llvm/lib/CodeGen/MachineFunction.cpp +++ b/llvm/lib/CodeGen/MachineFunction.cpp @@ -698,6 +698,26 @@ bool MachineFunction::needsFrameMoves() const { !F.getParent()->debug_compile_units().empty(); } +MachineFunction::CallSiteInfo::CallSiteInfo(const CallBase &CB) { + // Numeric callee_type ids are only for indirect calls. + if (!CB.isIndirectCall()) + return; + + MDNode *CalleeTypeList = CB.getMetadata(LLVMContext::MD_callee_type); + if (!CalleeTypeList) + return; + + for (const MDOperand &Op : CalleeTypeList->operands()) { + MDNode *TypeMD = cast<MDNode>(Op); + MDString *TypeIdStr = cast<MDString>(TypeMD->getOperand(1)); + // Compute numeric type id from generalized type id string + uint64_t TypeIdVal = MD5Hash(TypeIdStr->getString()); + IntegerType *Int64Ty = Type::getInt64Ty(CB.getContext()); + CalleeTypeIds.push_back( + ConstantInt::get(Int64Ty, TypeIdVal, /*IsSigned=*/false)); + } +} + namespace llvm { template<> diff --git a/llvm/lib/CodeGen/MachineScheduler.cpp b/llvm/lib/CodeGen/MachineScheduler.cpp index 9d5c39c..c6fa8f4 100644 --- a/llvm/lib/CodeGen/MachineScheduler.cpp +++ b/llvm/lib/CodeGen/MachineScheduler.cpp @@ -3676,8 +3676,8 @@ void GenericScheduler::initialize(ScheduleDAGMI *dag) { TopCand.SU = nullptr; BotCand.SU = nullptr; - TopCluster = nullptr; - BotCluster = nullptr; + TopClusterID = InvalidClusterId; + BotClusterID = InvalidClusterId; } /// Initialize the per-region scheduling policy. @@ -3988,10 +3988,14 @@ bool GenericScheduler::tryCandidate(SchedCandidate &Cand, // This is a best effort to set things up for a post-RA pass. Optimizations // like generating loads of multiple registers should ideally be done within // the scheduler pass by combining the loads during DAG postprocessing. - const ClusterInfo *CandCluster = Cand.AtTop ? TopCluster : BotCluster; - const ClusterInfo *TryCandCluster = TryCand.AtTop ? TopCluster : BotCluster; - if (tryGreater(TryCandCluster && TryCandCluster->contains(TryCand.SU), - CandCluster && CandCluster->contains(Cand.SU), TryCand, Cand, + unsigned CandZoneCluster = Cand.AtTop ? TopClusterID : BotClusterID; + unsigned TryCandZoneCluster = TryCand.AtTop ? 
TopClusterID : BotClusterID; + bool CandIsClusterSucc = + isTheSameCluster(CandZoneCluster, Cand.SU->ParentClusterIdx); + bool TryCandIsClusterSucc = + isTheSameCluster(TryCandZoneCluster, TryCand.SU->ParentClusterIdx); + + if (tryGreater(TryCandIsClusterSucc, CandIsClusterSucc, TryCand, Cand, Cluster)) return TryCand.Reason != NoCand; @@ -4251,24 +4255,30 @@ void GenericScheduler::reschedulePhysReg(SUnit *SU, bool isTop) { void GenericScheduler::schedNode(SUnit *SU, bool IsTopNode) { if (IsTopNode) { SU->TopReadyCycle = std::max(SU->TopReadyCycle, Top.getCurrCycle()); - TopCluster = DAG->getCluster(SU->ParentClusterIdx); - LLVM_DEBUG(if (TopCluster) { - dbgs() << " Top Cluster: "; - for (auto *N : *TopCluster) - dbgs() << N->NodeNum << '\t'; - dbgs() << '\n'; + TopClusterID = SU->ParentClusterIdx; + LLVM_DEBUG({ + if (TopClusterID != InvalidClusterId) { + ClusterInfo *TopCluster = DAG->getCluster(TopClusterID); + dbgs() << " Top Cluster: "; + for (auto *N : *TopCluster) + dbgs() << N->NodeNum << '\t'; + dbgs() << '\n'; + } }); Top.bumpNode(SU); if (SU->hasPhysRegUses) reschedulePhysReg(SU, true); } else { SU->BotReadyCycle = std::max(SU->BotReadyCycle, Bot.getCurrCycle()); - BotCluster = DAG->getCluster(SU->ParentClusterIdx); - LLVM_DEBUG(if (BotCluster) { - dbgs() << " Bot Cluster: "; - for (auto *N : *BotCluster) - dbgs() << N->NodeNum << '\t'; - dbgs() << '\n'; + BotClusterID = SU->ParentClusterIdx; + LLVM_DEBUG({ + if (BotClusterID != InvalidClusterId) { + ClusterInfo *BotCluster = DAG->getCluster(BotClusterID); + dbgs() << " Bot Cluster: "; + for (auto *N : *BotCluster) + dbgs() << N->NodeNum << '\t'; + dbgs() << '\n'; + } }); Bot.bumpNode(SU); if (SU->hasPhysRegDefs) @@ -4306,8 +4316,8 @@ void PostGenericScheduler::initialize(ScheduleDAGMI *Dag) { if (!Bot.HazardRec) { Bot.HazardRec = DAG->TII->CreateTargetMIHazardRecognizer(Itin, DAG); } - TopCluster = nullptr; - BotCluster = nullptr; + TopClusterID = InvalidClusterId; + BotClusterID = InvalidClusterId; } void PostGenericScheduler::initPolicy(MachineBasicBlock::iterator Begin, @@ -4373,10 +4383,14 @@ bool PostGenericScheduler::tryCandidate(SchedCandidate &Cand, return TryCand.Reason != NoCand; // Keep clustered nodes together. - const ClusterInfo *CandCluster = Cand.AtTop ? TopCluster : BotCluster; - const ClusterInfo *TryCandCluster = TryCand.AtTop ? TopCluster : BotCluster; - if (tryGreater(TryCandCluster && TryCandCluster->contains(TryCand.SU), - CandCluster && CandCluster->contains(Cand.SU), TryCand, Cand, + unsigned CandZoneCluster = Cand.AtTop ? TopClusterID : BotClusterID; + unsigned TryCandZoneCluster = TryCand.AtTop ? TopClusterID : BotClusterID; + bool CandIsClusterSucc = + isTheSameCluster(CandZoneCluster, Cand.SU->ParentClusterIdx); + bool TryCandIsClusterSucc = + isTheSameCluster(TryCandZoneCluster, TryCand.SU->ParentClusterIdx); + + if (tryGreater(TryCandIsClusterSucc, CandIsClusterSucc, TryCand, Cand, Cluster)) return TryCand.Reason != NoCand; // Avoid critical resource consumption and balance the schedule. 
@@ -4575,11 +4589,11 @@ SUnit *PostGenericScheduler::pickNode(bool &IsTopNode) { void PostGenericScheduler::schedNode(SUnit *SU, bool IsTopNode) { if (IsTopNode) { SU->TopReadyCycle = std::max(SU->TopReadyCycle, Top.getCurrCycle()); - TopCluster = DAG->getCluster(SU->ParentClusterIdx); + TopClusterID = SU->ParentClusterIdx; Top.bumpNode(SU); } else { SU->BotReadyCycle = std::max(SU->BotReadyCycle, Bot.getCurrCycle()); - BotCluster = DAG->getCluster(SU->ParentClusterIdx); + BotClusterID = SU->ParentClusterIdx; Bot.bumpNode(SU); } } diff --git a/llvm/lib/CodeGen/RegAllocBase.cpp b/llvm/lib/CodeGen/RegAllocBase.cpp index 69b9291..2400a1f 100644 --- a/llvm/lib/CodeGen/RegAllocBase.cpp +++ b/llvm/lib/CodeGen/RegAllocBase.cpp @@ -178,10 +178,8 @@ void RegAllocBase::cleanupFailedVReg(Register FailedReg, MCRegister PhysReg, for (MCRegAliasIterator Aliases(PhysReg, TRI, true); Aliases.isValid(); ++Aliases) { for (MachineOperand &MO : MRI->reg_operands(*Aliases)) { - if (MO.readsReg()) { + if (MO.readsReg()) MO.setIsUndef(true); - LIS->removeAllRegUnitsForPhysReg(MO.getReg()); - } } } } diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index a43020e..5989c1d 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -331,6 +331,11 @@ namespace { return CombineTo(N, To, 2, AddTo); } + SDValue CombineTo(SDNode *N, SmallVectorImpl<SDValue> *To, + bool AddTo = true) { + return CombineTo(N, To->data(), To->size(), AddTo); + } + void CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO); private: @@ -541,6 +546,7 @@ namespace { SDValue visitEXTRACT_VECTOR_ELT(SDNode *N); SDValue visitBUILD_VECTOR(SDNode *N); SDValue visitCONCAT_VECTORS(SDNode *N); + SDValue visitVECTOR_INTERLEAVE(SDNode *N); SDValue visitEXTRACT_SUBVECTOR(SDNode *N); SDValue visitVECTOR_SHUFFLE(SDNode *N); SDValue visitSCALAR_TO_VECTOR(SDNode *N); @@ -2021,6 +2027,7 @@ SDValue DAGCombiner::visit(SDNode *N) { case ISD::EXTRACT_VECTOR_ELT: return visitEXTRACT_VECTOR_ELT(N); case ISD::BUILD_VECTOR: return visitBUILD_VECTOR(N); case ISD::CONCAT_VECTORS: return visitCONCAT_VECTORS(N); + case ISD::VECTOR_INTERLEAVE: return visitVECTOR_INTERLEAVE(N); case ISD::EXTRACT_SUBVECTOR: return visitEXTRACT_SUBVECTOR(N); case ISD::VECTOR_SHUFFLE: return visitVECTOR_SHUFFLE(N); case ISD::SCALAR_TO_VECTOR: return visitSCALAR_TO_VECTOR(N); @@ -4100,18 +4107,17 @@ SDValue DAGCombiner::visitSUB(SDNode *N) { // (sub x, ([v]select (uge x, y), y, 0)) -> (umin x, (sub x, y)) if (N1.hasOneUse() && hasUMin(VT)) { SDValue Y; - if (sd_match(N1, m_Select(m_SetCC(m_Specific(N0), m_Value(Y), - m_SpecificCondCode(ISD::SETULT)), - m_Zero(), m_Deferred(Y))) || - sd_match(N1, m_Select(m_SetCC(m_Specific(N0), m_Value(Y), - m_SpecificCondCode(ISD::SETUGE)), - m_Deferred(Y), m_Zero())) || - sd_match(N1, m_VSelect(m_SetCC(m_Specific(N0), m_Value(Y), - m_SpecificCondCode(ISD::SETULT)), - m_Zero(), m_Deferred(Y))) || - sd_match(N1, m_VSelect(m_SetCC(m_Specific(N0), m_Value(Y), - m_SpecificCondCode(ISD::SETUGE)), - m_Deferred(Y), m_Zero()))) + auto MS0 = m_Specific(N0); + auto MVY = m_Value(Y); + auto MZ = m_Zero(); + auto MCC1 = m_SpecificCondCode(ISD::SETULT); + auto MCC2 = m_SpecificCondCode(ISD::SETUGE); + + if (sd_match(N1, m_SelectCCLike(MS0, MVY, MZ, m_Deferred(Y), MCC1)) || + sd_match(N1, m_SelectCCLike(MS0, MVY, m_Deferred(Y), MZ, MCC2)) || + sd_match(N1, m_VSelect(m_SetCC(MS0, MVY, MCC1), MZ, m_Deferred(Y))) || + sd_match(N1, 
m_VSelect(m_SetCC(MS0, MVY, MCC2), m_Deferred(Y), MZ))) + return DAG.getNode(ISD::UMIN, DL, VT, N0, DAG.getNode(ISD::SUB, DL, VT, N0, Y)); } @@ -10616,6 +10622,19 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { return DAG.getVScale(DL, VT, C0 << C1); } + SDValue X; + APInt VS0; + + // fold (shl (X * vscale(VS0)), C1) -> (X * vscale(VS0 << C1)) + if (N1C && sd_match(N0, m_Mul(m_Value(X), m_VScale(m_ConstInt(VS0))))) { + SDNodeFlags Flags; + Flags.setNoUnsignedWrap(N->getFlags().hasNoUnsignedWrap() && + N0->getFlags().hasNoUnsignedWrap()); + + SDValue VScale = DAG.getVScale(DL, VT, VS0 << N1C->getAPIntValue()); + return DAG.getNode(ISD::MUL, DL, VT, X, VScale, Flags); + } + // Fold (shl step_vector(C0), C1) to (step_vector(C0 << C1)). APInt ShlVal; if (N0.getOpcode() == ISD::STEP_VECTOR && @@ -25282,6 +25301,28 @@ static SDValue combineConcatVectorOfShuffleAndItsOperands( return DAG.getVectorShuffle(VT, dl, ShufOps[0], ShufOps[1], Mask); } +static SDValue combineConcatVectorOfSplats(SDNode *N, SelectionDAG &DAG, + const TargetLowering &TLI, + bool LegalTypes, + bool LegalOperations) { + EVT VT = N->getValueType(0); + + // Post-legalization we can only create wider SPLAT_VECTOR operations if both + // the type and operation is legal. The Hexagon target has custom + // legalization for SPLAT_VECTOR that splits the operation into two parts and + // concatenates them. Therefore, custom lowering must also be rejected in + // order to avoid an infinite loop. + if ((LegalTypes && !TLI.isTypeLegal(VT)) || + (LegalOperations && !TLI.isOperationLegal(ISD::SPLAT_VECTOR, VT))) + return SDValue(); + + SDValue Op0 = N->getOperand(0); + if (!llvm::all_equal(N->op_values()) || Op0.getOpcode() != ISD::SPLAT_VECTOR) + return SDValue(); + + return DAG.getNode(ISD::SPLAT_VECTOR, SDLoc(N), VT, Op0.getOperand(0)); +} + SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) { // If we only have one input vector, we don't need to do any concatenation. if (N->getNumOperands() == 1) @@ -25405,6 +25446,10 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) { return DAG.getBuildVector(VT, SDLoc(N), Opnds); } + if (SDValue V = + combineConcatVectorOfSplats(N, DAG, TLI, LegalTypes, LegalOperations)) + return V; + // Fold CONCAT_VECTORS of only bitcast scalars (or undef) to BUILD_VECTOR. // FIXME: Add support for concat_vectors(bitcast(vec0),bitcast(vec1),...). if (SDValue V = combineConcatVectorOfScalars(N, DAG)) @@ -25473,6 +25518,21 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) { return SDValue(); } +SDValue DAGCombiner::visitVECTOR_INTERLEAVE(SDNode *N) { + // Check to see if all operands are identical. + if (!llvm::all_equal(N->op_values())) + return SDValue(); + + // Check to see if the identical operand is a splat. + if (!DAG.isSplatValue(N->getOperand(0))) + return SDValue(); + + // interleave splat(X), splat(X).... --> splat(X), splat(X).... + SmallVector<SDValue, 4> Ops; + Ops.append(N->op_values().begin(), N->op_values().end()); + return CombineTo(N, &Ops); +} + // Helper that peeks through INSERT_SUBVECTOR/CONCAT_VECTORS to find // if the subvector can be sourced for free. 
static SDValue getSubVectorSrc(SDValue V, unsigned Index, EVT SubVT) { diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index 02d1100..f41b6eb 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -12782,7 +12782,7 @@ bool SDNode::areOnlyUsersOf(ArrayRef<const SDNode *> Nodes, const SDNode *N) { return Seen; } -/// isOperand - Return true if this node is an operand of N. +/// Return true if the referenced return value is an operand of N. bool SDValue::isOperandOf(const SDNode *N) const { return is_contained(N->op_values(), *this); } diff --git a/llvm/lib/CodeGen/TailDuplicator.cpp b/llvm/lib/CodeGen/TailDuplicator.cpp index a88c57f..5d720fb 100644 --- a/llvm/lib/CodeGen/TailDuplicator.cpp +++ b/llvm/lib/CodeGen/TailDuplicator.cpp @@ -604,12 +604,21 @@ bool TailDuplicator::shouldTailDuplicate(bool IsSimple, bool HasComputedGoto = false; if (!TailBB.empty()) { HasIndirectbr = TailBB.back().isIndirectBranch(); - HasComputedGoto = TailBB.terminatorIsComputedGoto(); + HasComputedGoto = TailBB.terminatorIsComputedGotoWithSuccessors(); } if (HasIndirectbr && PreRegAlloc) MaxDuplicateCount = TailDupIndirectBranchSize; + // Allow higher limits when the block has computed-gotos and running after + // register allocation. NB. This basically unfactors computed gotos that were + // factored early on in the compilation process to speed up edge based data + // flow. If we do not unfactor them again, it can seriously pessimize code + // with many computed jumps in the source code, such as interpreters. + // Therefore we do not restrict the computed gotos. + if (HasComputedGoto && !PreRegAlloc) + MaxDuplicateCount = std::max(MaxDuplicateCount, 10u); + // Check the instructions in the block to determine whether tail-duplication // is invalid or unlikely to be profitable. unsigned InstrCount = 0; @@ -663,12 +672,7 @@ bool TailDuplicator::shouldTailDuplicate(bool IsSimple, // Duplicating a BB which has both multiple predecessors and successors will // may cause huge amount of PHI nodes. If we want to remove this limitation, // we have to address https://github.com/llvm/llvm-project/issues/78578. - // NB. This basically unfactors computed gotos that were factored early on in - // the compilation process to speed up edge based data flow. If we do not - // unfactor them again, it can seriously pessimize code with many computed - // jumps in the source code, such as interpreters. Therefore we do not - // restrict the computed gotos. - if (!HasComputedGoto && TailBB.pred_size() > TailDupPredSize && + if (PreRegAlloc && TailBB.pred_size() > TailDupPredSize && TailBB.succ_size() > TailDupSuccSize) { // If TailBB or any of its successors contains a phi, we may have to add a // large number of additional phis with additional incoming values. 
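The TailDuplicator change above raises the post-RA duplication limit for blocks ending in computed gotos, which effectively un-factors the shared dispatch block of interpreter-style code. Below is a minimal sketch of that dispatch pattern, not taken from the patch: it assumes the GNU labels-as-values extension (Clang/GCC), and the run/op_* names are purely illustrative. Once the dispatch tail is duplicated into each handler, every handler ends in its own indirect branch, which is the shape the relaxed limit preserves after register allocation.

#include <cstdint>

// Tiny threaded interpreter: every handler jumps directly to the next
// handler instead of branching back to one shared dispatch block.
int run(const uint8_t *pc, int acc) {
  static void *const handlers[] = {&&op_inc, &&op_dec, &&op_halt};
  goto *handlers[*pc++];            // initial dispatch

op_inc:
  ++acc;
  goto *handlers[*pc++];            // per-handler dispatch: the duplicated tail
op_dec:
  --acc;
  goto *handlers[*pc++];
op_halt:
  return acc;
}
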
diff --git a/llvm/lib/CodeGen/TargetInstrInfo.cpp b/llvm/lib/CodeGen/TargetInstrInfo.cpp index 18d6bbc..705e046e 100644 --- a/llvm/lib/CodeGen/TargetInstrInfo.cpp +++ b/llvm/lib/CodeGen/TargetInstrInfo.cpp @@ -1406,7 +1406,7 @@ void TargetInstrInfo::reassociateOps( const MCInstrDesc &MCID, Register DestReg) { return MachineInstrBuilder( MF, MF.CreateMachineInstr(MCID, MIMD.getDL(), /*NoImpl=*/true)) - .setPCSections(MIMD.getPCSections()) + .copyMIMetadata(MIMD) .addReg(DestReg, RegState::Define); }; diff --git a/llvm/lib/DWARFLinker/Classic/DWARFLinker.cpp b/llvm/lib/DWARFLinker/Classic/DWARFLinker.cpp index 222dc88..6ddb12b 100644 --- a/llvm/lib/DWARFLinker/Classic/DWARFLinker.cpp +++ b/llvm/lib/DWARFLinker/Classic/DWARFLinker.cpp @@ -413,6 +413,117 @@ static bool isTlsAddressCode(uint8_t DW_OP_Code) { DW_OP_Code == dwarf::DW_OP_GNU_push_tls_address; } +static void constructSeqOffsettoOrigRowMapping( + CompileUnit &Unit, const DWARFDebugLine::LineTable <, + DenseMap<uint64_t, unsigned> &SeqOffToOrigRow) { + + // Use std::map for ordered iteration. + std::map<uint64_t, unsigned> LineTableMapping; + + // First, trust the sequences that the DWARF parser did identify. + for (const DWARFDebugLine::Sequence &Seq : LT.Sequences) + LineTableMapping[Seq.StmtSeqOffset] = Seq.FirstRowIndex; + + // Second, manually find sequence boundaries and match them to the + // sorted attributes to handle sequences the parser might have missed. + auto StmtAttrs = Unit.getStmtSeqListAttributes(); + llvm::sort(StmtAttrs, [](const PatchLocation &A, const PatchLocation &B) { + return A.get() < B.get(); + }); + + std::vector<unsigned> SeqStartRows; + SeqStartRows.push_back(0); + for (auto [I, Row] : llvm::enumerate(ArrayRef(LT.Rows).drop_back())) + if (Row.EndSequence) + SeqStartRows.push_back(I + 1); + + // While SeqOffToOrigRow parsed from CU could be the ground truth, + // e.g. + // + // SeqOff Row + // 0x08 9 + // 0x14 15 + // + // The StmtAttrs and SeqStartRows may not match perfectly, e.g. + // + // StmtAttrs SeqStartRows + // 0x04 3 + // 0x08 5 + // 0x10 9 + // 0x12 11 + // 0x14 15 + // + // In this case, we don't want to assign 5 to 0x08, since we know 0x08 + // maps to 9. If we do a dummy 1:1 mapping 0x10 will be mapped to 9 + // which is incorrect. The expected behavior is ignore 5, realign the + // table based on the result from the line table: + // + // StmtAttrs SeqStartRows + // 0x04 3 + // -- 5 + // 0x08 9 <- LineTableMapping ground truth + // 0x10 11 + // 0x12 -- + // 0x14 15 <- LineTableMapping ground truth + + ArrayRef StmtAttrsRef(StmtAttrs); + ArrayRef SeqStartRowsRef(SeqStartRows); + + // Dummy last element to make sure StmtAttrsRef and SeqStartRowsRef always + // run out first. + constexpr uint64_t DummyKey = UINT64_MAX; + constexpr unsigned DummyVal = UINT32_MAX; + LineTableMapping[DummyKey] = DummyVal; + + for (auto [NextSeqOff, NextRow] : LineTableMapping) { + // Explict capture to avoid capturing structured bindings and make C++17 + // happy. + auto StmtAttrSmallerThanNext = [N = NextSeqOff](const PatchLocation &SA) { + return SA.get() < N; + }; + auto SeqStartSmallerThanNext = [N = NextRow](const unsigned &Row) { + return Row < N; + }; + // If both StmtAttrs and SeqStartRows points to value not in + // the LineTableMapping yet, we do a dummy one to one mapping and + // move the pointer. 
+ while (!StmtAttrsRef.empty() && !SeqStartRowsRef.empty() && + StmtAttrSmallerThanNext(StmtAttrsRef.front()) && + SeqStartSmallerThanNext(SeqStartRowsRef.front())) { + SeqOffToOrigRow[StmtAttrsRef.consume_front().get()] = + SeqStartRowsRef.consume_front(); + } + // One of the pointer points to the value at or past Next in the + // LineTableMapping, We move the pointer to re-align with the + // LineTableMapping + StmtAttrsRef = StmtAttrsRef.drop_while(StmtAttrSmallerThanNext); + SeqStartRowsRef = SeqStartRowsRef.drop_while(SeqStartSmallerThanNext); + // Use the LineTableMapping's result as the ground truth and move + // on. + if (NextSeqOff != DummyKey) { + SeqOffToOrigRow[NextSeqOff] = NextRow; + } + // Move the pointers if they are pointed at Next. + // It is possible that they point to later entries in LineTableMapping. + // Therefore we only increment the pointers after we validate they are + // pointing to the `Next` entry. e.g. + // + // LineTableMapping + // SeqOff Row + // 0x08 9 <- NextSeqOff/NextRow + // 0x14 15 + // + // StmtAttrs SeqStartRows + // 0x14 13 <- StmtAttrsRef.front() / SeqStartRowsRef.front() + // 0x16 15 + // -- 17 + if (!StmtAttrsRef.empty() && StmtAttrsRef.front().get() == NextSeqOff) + StmtAttrsRef.consume_front(); + if (!SeqStartRowsRef.empty() && SeqStartRowsRef.front() == NextRow) + SeqStartRowsRef.consume_front(); + } +} + std::pair<bool, std::optional<int64_t>> DWARFLinker::getVariableRelocAdjustment(AddressesMap &RelocMgr, const DWARFDie &DIE) { @@ -2297,8 +2408,12 @@ void DWARFLinker::DIECloner::generateLineTableForUnit(CompileUnit &Unit) { // Create a map of stmt sequence offsets to original row indices. DenseMap<uint64_t, unsigned> SeqOffToOrigRow; - for (const DWARFDebugLine::Sequence &Seq : LT->Sequences) - SeqOffToOrigRow[Seq.StmtSeqOffset] = Seq.FirstRowIndex; + // The DWARF parser's discovery of sequences can be incomplete. To + // ensure all DW_AT_LLVM_stmt_sequence attributes can be patched, we + // build a map from both the parser's results and a manual + // reconstruction. + if (!LT->Rows.empty()) + constructSeqOffsettoOrigRowMapping(Unit, *LT, SeqOffToOrigRow); // Create a map of original row indices to new row indices. DenseMap<size_t, size_t> OrigRowToNewRow; diff --git a/llvm/lib/Frontend/HLSL/CMakeLists.txt b/llvm/lib/Frontend/HLSL/CMakeLists.txt index 5343469..3d22577 100644 --- a/llvm/lib/Frontend/HLSL/CMakeLists.txt +++ b/llvm/lib/Frontend/HLSL/CMakeLists.txt @@ -1,5 +1,6 @@ add_llvm_component_library(LLVMFrontendHLSL CBuffer.cpp + HLSLBinding.cpp HLSLResource.cpp HLSLRootSignature.cpp RootSignatureMetadata.cpp diff --git a/llvm/lib/Frontend/HLSL/HLSLBinding.cpp b/llvm/lib/Frontend/HLSL/HLSLBinding.cpp new file mode 100644 index 0000000..d581311 --- /dev/null +++ b/llvm/lib/Frontend/HLSL/HLSLBinding.cpp @@ -0,0 +1,142 @@ +//===- HLSLBinding.cpp - Representation for resource bindings in HLSL -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Frontend/HLSL/HLSLBinding.h" +#include "llvm/ADT/STLExtras.h" + +using namespace llvm; +using namespace hlsl; + +std::optional<uint32_t> +BindingInfo::findAvailableBinding(dxil::ResourceClass RC, uint32_t Space, + int32_t Size) { + BindingSpaces &BS = getBindingSpaces(RC); + RegisterSpace &RS = BS.getOrInsertSpace(Space); + return RS.findAvailableBinding(Size); +} + +BindingInfo::RegisterSpace & +BindingInfo::BindingSpaces::getOrInsertSpace(uint32_t Space) { + for (auto It = Spaces.begin(), End = Spaces.end(); It != End; ++It) { + if (It->Space == Space) + return *It; + if (It->Space < Space) + continue; + return *Spaces.insert(It, Space); + } + return Spaces.emplace_back(Space); +} + +std::optional<uint32_t> +BindingInfo::RegisterSpace::findAvailableBinding(int32_t Size) { + assert((Size == -1 || Size > 0) && "invalid size"); + + if (FreeRanges.empty()) + return std::nullopt; + + // unbounded array + if (Size == -1) { + BindingRange &Last = FreeRanges.back(); + if (Last.UpperBound != ~0u) + // this space is already occupied by an unbounded array + return std::nullopt; + uint32_t RegSlot = Last.LowerBound; + FreeRanges.pop_back(); + return RegSlot; + } + + // single resource or fixed-size array + for (BindingRange &R : FreeRanges) { + // compare the size as uint64_t to prevent overflow for range (0, ~0u) + if ((uint64_t)R.UpperBound - R.LowerBound + 1 < (uint64_t)Size) + continue; + uint32_t RegSlot = R.LowerBound; + // This might create a range where (LowerBound == UpperBound + 1). When + // that happens, the next time this function is called the range will + // skipped over by the check above (at this point Size is always > 0). + R.LowerBound += Size; + return RegSlot; + } + + return std::nullopt; +} + +BindingInfo BindingInfoBuilder::calculateBindingInfo( + llvm::function_ref<void(const BindingInfoBuilder &Builder, + const Binding &Overlapping)> + ReportOverlap) { + // sort all the collected bindings + llvm::stable_sort(Bindings); + + // remove duplicates + Binding *NewEnd = llvm::unique(Bindings); + if (NewEnd != Bindings.end()) + Bindings.erase(NewEnd); + + BindingInfo Info; + + // Go over the sorted bindings and build up lists of free register ranges + // for each binding type and used spaces. Bindings are sorted by resource + // class, space, and lower bound register slot. + BindingInfo::BindingSpaces *BS = + &Info.getBindingSpaces(dxil::ResourceClass::SRV); + for (const Binding &B : Bindings) { + if (BS->RC != B.RC) + // move to the next resource class spaces + BS = &Info.getBindingSpaces(B.RC); + + BindingInfo::RegisterSpace *S = BS->Spaces.empty() + ? &BS->Spaces.emplace_back(B.Space) + : &BS->Spaces.back(); + assert(S->Space <= B.Space && "bindings not sorted correctly?"); + if (B.Space != S->Space) + // add new space + S = &BS->Spaces.emplace_back(B.Space); + + // The space is full - there are no free slots left, or the rest of the + // slots are taken by an unbounded array. Report the overlapping to the + // caller. 
+ if (S->FreeRanges.empty() || S->FreeRanges.back().UpperBound < ~0u) { + ReportOverlap(*this, B); + continue; + } + // adjust the last free range lower bound, split it in two, or remove it + BindingInfo::BindingRange &LastFreeRange = S->FreeRanges.back(); + if (LastFreeRange.LowerBound == B.LowerBound) { + if (B.UpperBound < ~0u) + LastFreeRange.LowerBound = B.UpperBound + 1; + else + S->FreeRanges.pop_back(); + } else if (LastFreeRange.LowerBound < B.LowerBound) { + LastFreeRange.UpperBound = B.LowerBound - 1; + if (B.UpperBound < ~0u) + S->FreeRanges.emplace_back(B.UpperBound + 1, ~0u); + } else { + // We don't have room here. Report the overlapping binding to the caller + // and mark any extra space this binding would use as unavailable. + ReportOverlap(*this, B); + if (B.UpperBound < ~0u) + LastFreeRange.LowerBound = + std::max(LastFreeRange.LowerBound, B.UpperBound + 1); + else + S->FreeRanges.pop_back(); + } + } + + return Info; +} + +const BindingInfoBuilder::Binding &BindingInfoBuilder::findOverlapping( + const BindingInfoBuilder::Binding &ReportedBinding) const { + for (const BindingInfoBuilder::Binding &Other : Bindings) + if (ReportedBinding.LowerBound <= Other.UpperBound && + Other.LowerBound <= ReportedBinding.UpperBound) + return Other; + + llvm_unreachable("Searching for overlap for binding that does not overlap"); +} diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp index 53f5934..48ff1ca 100644 --- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp +++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp @@ -13,15 +13,21 @@ #include "llvm/Frontend/HLSL/RootSignatureMetadata.h" #include "llvm/Frontend/HLSL/RootSignatureValidations.h" -#include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Metadata.h" #include "llvm/Support/ScopedPrinter.h" +using namespace llvm; + namespace llvm { namespace hlsl { namespace rootsig { +char GenericRSMetadataError::ID; +char InvalidRSMetadataFormat::ID; +char InvalidRSMetadataValue::ID; +template <typename T> char RootSignatureValidationError<T>::ID; + static std::optional<uint32_t> extractMdIntValue(MDNode *Node, unsigned int OpId) { if (auto *CI = @@ -45,19 +51,6 @@ static std::optional<StringRef> extractMdStringValue(MDNode *Node, return NodeText->getString(); } -static bool reportError(LLVMContext *Ctx, Twine Message, - DiagnosticSeverity Severity = DS_Error) { - Ctx->diagnose(DiagnosticInfoGeneric(Message, Severity)); - return true; -} - -static bool reportValueError(LLVMContext *Ctx, Twine ParamName, - uint32_t Value) { - Ctx->diagnose(DiagnosticInfoGeneric( - "Invalid value for " + ParamName + ": " + Twine(Value), DS_Error)); - return true; -} - static const EnumEntry<dxil::ResourceClass> ResourceClassNames[] = { {"CBV", dxil::ResourceClass::CBuffer}, {"SRV", dxil::ResourceClass::SRV}, @@ -120,7 +113,7 @@ MDNode *MetadataBuilder::BuildRootFlags(const dxbc::RootFlags &Flags) { IRBuilder<> Builder(Ctx); Metadata *Operands[] = { MDString::get(Ctx, "RootFlags"), - ConstantAsMetadata::get(Builder.getInt32(llvm::to_underlying(Flags))), + ConstantAsMetadata::get(Builder.getInt32(to_underlying(Flags))), }; return MDNode::get(Ctx, Operands); } @@ -130,7 +123,7 @@ MDNode *MetadataBuilder::BuildRootConstants(const RootConstants &Constants) { Metadata *Operands[] = { MDString::get(Ctx, "RootConstants"), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Constants.Visibility))), + Builder.getInt32(to_underlying(Constants.Visibility))), 
ConstantAsMetadata::get(Builder.getInt32(Constants.Reg.Number)), ConstantAsMetadata::get(Builder.getInt32(Constants.Space)), ConstantAsMetadata::get(Builder.getInt32(Constants.Num32BitConstants)), @@ -140,18 +133,18 @@ MDNode *MetadataBuilder::BuildRootConstants(const RootConstants &Constants) { MDNode *MetadataBuilder::BuildRootDescriptor(const RootDescriptor &Descriptor) { IRBuilder<> Builder(Ctx); - std::optional<StringRef> ResName = getResourceName( - dxil::ResourceClass(llvm::to_underlying(Descriptor.Type))); + std::optional<StringRef> ResName = + getResourceName(dxil::ResourceClass(to_underlying(Descriptor.Type))); assert(ResName && "Provided an invalid Resource Class"); - llvm::SmallString<7> Name({"Root", *ResName}); + SmallString<7> Name({"Root", *ResName}); Metadata *Operands[] = { MDString::get(Ctx, Name), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Descriptor.Visibility))), + Builder.getInt32(to_underlying(Descriptor.Visibility))), ConstantAsMetadata::get(Builder.getInt32(Descriptor.Reg.Number)), ConstantAsMetadata::get(Builder.getInt32(Descriptor.Space)), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Descriptor.Flags))), + Builder.getInt32(to_underlying(Descriptor.Flags))), }; return MDNode::get(Ctx, Operands); } @@ -162,7 +155,7 @@ MDNode *MetadataBuilder::BuildDescriptorTable(const DescriptorTable &Table) { // Set the mandatory arguments TableOperands.push_back(MDString::get(Ctx, "DescriptorTable")); TableOperands.push_back(ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Table.Visibility)))); + Builder.getInt32(to_underlying(Table.Visibility)))); // Remaining operands are references to the table's clauses. The in-memory // representation of the Root Elements created from parsing will ensure that @@ -182,7 +175,7 @@ MDNode *MetadataBuilder::BuildDescriptorTableClause( const DescriptorTableClause &Clause) { IRBuilder<> Builder(Ctx); std::optional<StringRef> ResName = - getResourceName(dxil::ResourceClass(llvm::to_underlying(Clause.Type))); + getResourceName(dxil::ResourceClass(to_underlying(Clause.Type))); assert(ResName && "Provided an invalid Resource Class"); Metadata *Operands[] = { MDString::get(Ctx, *ResName), @@ -190,8 +183,7 @@ MDNode *MetadataBuilder::BuildDescriptorTableClause( ConstantAsMetadata::get(Builder.getInt32(Clause.Reg.Number)), ConstantAsMetadata::get(Builder.getInt32(Clause.Space)), ConstantAsMetadata::get(Builder.getInt32(Clause.Offset)), - ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Clause.Flags))), + ConstantAsMetadata::get(Builder.getInt32(to_underlying(Clause.Flags))), }; return MDNode::get(Ctx, Operands); } @@ -200,108 +192,102 @@ MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) { IRBuilder<> Builder(Ctx); Metadata *Operands[] = { MDString::get(Ctx, "StaticSampler"), + ConstantAsMetadata::get(Builder.getInt32(to_underlying(Sampler.Filter))), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Sampler.Filter))), + Builder.getInt32(to_underlying(Sampler.AddressU))), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Sampler.AddressU))), + Builder.getInt32(to_underlying(Sampler.AddressV))), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Sampler.AddressV))), + Builder.getInt32(to_underlying(Sampler.AddressW))), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Sampler.AddressW))), - ConstantAsMetadata::get(llvm::ConstantFP::get(llvm::Type::getFloatTy(Ctx), - Sampler.MipLODBias)), + 
ConstantFP::get(Type::getFloatTy(Ctx), Sampler.MipLODBias)), ConstantAsMetadata::get(Builder.getInt32(Sampler.MaxAnisotropy)), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Sampler.CompFunc))), + Builder.getInt32(to_underlying(Sampler.CompFunc))), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Sampler.BorderColor))), + Builder.getInt32(to_underlying(Sampler.BorderColor))), ConstantAsMetadata::get( - llvm::ConstantFP::get(llvm::Type::getFloatTy(Ctx), Sampler.MinLOD)), + ConstantFP::get(Type::getFloatTy(Ctx), Sampler.MinLOD)), ConstantAsMetadata::get( - llvm::ConstantFP::get(llvm::Type::getFloatTy(Ctx), Sampler.MaxLOD)), + ConstantFP::get(Type::getFloatTy(Ctx), Sampler.MaxLOD)), ConstantAsMetadata::get(Builder.getInt32(Sampler.Reg.Number)), ConstantAsMetadata::get(Builder.getInt32(Sampler.Space)), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Sampler.Visibility))), + Builder.getInt32(to_underlying(Sampler.Visibility))), }; return MDNode::get(Ctx, Operands); } -bool MetadataParser::parseRootFlags(LLVMContext *Ctx, - mcdxbc::RootSignatureDesc &RSD, - MDNode *RootFlagNode) { - +Error MetadataParser::parseRootFlags(mcdxbc::RootSignatureDesc &RSD, + MDNode *RootFlagNode) { if (RootFlagNode->getNumOperands() != 2) - return reportError(Ctx, "Invalid format for RootFlag Element"); + return make_error<InvalidRSMetadataFormat>("RootFlag Element"); if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1)) RSD.Flags = *Val; else - return reportError(Ctx, "Invalid value for RootFlag"); + return make_error<InvalidRSMetadataValue>("RootFlag"); - return false; + return Error::success(); } -bool MetadataParser::parseRootConstants(LLVMContext *Ctx, - mcdxbc::RootSignatureDesc &RSD, - MDNode *RootConstantNode) { - +Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD, + MDNode *RootConstantNode) { if (RootConstantNode->getNumOperands() != 5) - return reportError(Ctx, "Invalid format for RootConstants Element"); + return make_error<InvalidRSMetadataFormat>("RootConstants Element"); dxbc::RTS0::v1::RootParameterHeader Header; // The parameter offset doesn't matter here - we recalculate it during // serialization Header.ParameterOffset = 0; - Header.ParameterType = - llvm::to_underlying(dxbc::RootParameterType::Constants32Bit); + Header.ParameterType = to_underlying(dxbc::RootParameterType::Constants32Bit); if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1)) Header.ShaderVisibility = *Val; else - return reportError(Ctx, "Invalid value for ShaderVisibility"); + return make_error<InvalidRSMetadataValue>("ShaderVisibility"); dxbc::RTS0::v1::RootConstants Constants; if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2)) Constants.ShaderRegister = *Val; else - return reportError(Ctx, "Invalid value for ShaderRegister"); + return make_error<InvalidRSMetadataValue>("ShaderRegister"); if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3)) Constants.RegisterSpace = *Val; else - return reportError(Ctx, "Invalid value for RegisterSpace"); + return make_error<InvalidRSMetadataValue>("RegisterSpace"); if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4)) Constants.Num32BitValues = *Val; else - return reportError(Ctx, "Invalid value for Num32BitValues"); + return make_error<InvalidRSMetadataValue>("Num32BitValues"); RSD.ParametersContainer.addParameter(Header, Constants); - return false; + return Error::success(); } -bool MetadataParser::parseRootDescriptors( - 
LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, - MDNode *RootDescriptorNode, RootSignatureElementKind ElementKind) { +Error MetadataParser::parseRootDescriptors( + mcdxbc::RootSignatureDesc &RSD, MDNode *RootDescriptorNode, + RootSignatureElementKind ElementKind) { assert(ElementKind == RootSignatureElementKind::SRV || ElementKind == RootSignatureElementKind::UAV || ElementKind == RootSignatureElementKind::CBV && "parseRootDescriptors should only be called with RootDescriptor " "element kind."); if (RootDescriptorNode->getNumOperands() != 5) - return reportError(Ctx, "Invalid format for Root Descriptor Element"); + return make_error<InvalidRSMetadataFormat>("Root Descriptor Element"); dxbc::RTS0::v1::RootParameterHeader Header; switch (ElementKind) { case RootSignatureElementKind::SRV: - Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::SRV); + Header.ParameterType = to_underlying(dxbc::RootParameterType::SRV); break; case RootSignatureElementKind::UAV: - Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::UAV); + Header.ParameterType = to_underlying(dxbc::RootParameterType::UAV); break; case RootSignatureElementKind::CBV: - Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::CBV); + Header.ParameterType = to_underlying(dxbc::RootParameterType::CBV); break; default: llvm_unreachable("invalid Root Descriptor kind"); @@ -311,40 +297,38 @@ bool MetadataParser::parseRootDescriptors( if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1)) Header.ShaderVisibility = *Val; else - return reportError(Ctx, "Invalid value for ShaderVisibility"); + return make_error<InvalidRSMetadataValue>("ShaderVisibility"); dxbc::RTS0::v2::RootDescriptor Descriptor; if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2)) Descriptor.ShaderRegister = *Val; else - return reportError(Ctx, "Invalid value for ShaderRegister"); + return make_error<InvalidRSMetadataValue>("ShaderRegister"); if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 3)) Descriptor.RegisterSpace = *Val; else - return reportError(Ctx, "Invalid value for RegisterSpace"); + return make_error<InvalidRSMetadataValue>("RegisterSpace"); if (RSD.Version == 1) { RSD.ParametersContainer.addParameter(Header, Descriptor); - return false; + return Error::success(); } assert(RSD.Version > 1); if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 4)) Descriptor.Flags = *Val; else - return reportError(Ctx, "Invalid value for Root Descriptor Flags"); + return make_error<InvalidRSMetadataValue>("Root Descriptor Flags"); RSD.ParametersContainer.addParameter(Header, Descriptor); - return false; + return Error::success(); } -bool MetadataParser::parseDescriptorRange(LLVMContext *Ctx, - mcdxbc::DescriptorTable &Table, - MDNode *RangeDescriptorNode) { - +Error MetadataParser::parseDescriptorRange(mcdxbc::DescriptorTable &Table, + MDNode *RangeDescriptorNode) { if (RangeDescriptorNode->getNumOperands() != 6) - return reportError(Ctx, "Invalid format for Descriptor Range"); + return make_error<InvalidRSMetadataFormat>("Descriptor Range"); dxbc::RTS0::v2::DescriptorRange Range; @@ -352,162 +336,161 @@ bool MetadataParser::parseDescriptorRange(LLVMContext *Ctx, extractMdStringValue(RangeDescriptorNode, 0); if (!ElementText.has_value()) - return reportError(Ctx, "Descriptor Range, first element is not a string."); + return make_error<InvalidRSMetadataFormat>("Descriptor Range"); Range.RangeType = StringSwitch<uint32_t>(*ElementText) - .Case("CBV", 
llvm::to_underlying(dxbc::DescriptorRangeType::CBV)) - .Case("SRV", llvm::to_underlying(dxbc::DescriptorRangeType::SRV)) - .Case("UAV", llvm::to_underlying(dxbc::DescriptorRangeType::UAV)) - .Case("Sampler", - llvm::to_underlying(dxbc::DescriptorRangeType::Sampler)) + .Case("CBV", to_underlying(dxbc::DescriptorRangeType::CBV)) + .Case("SRV", to_underlying(dxbc::DescriptorRangeType::SRV)) + .Case("UAV", to_underlying(dxbc::DescriptorRangeType::UAV)) + .Case("Sampler", to_underlying(dxbc::DescriptorRangeType::Sampler)) .Default(~0U); if (Range.RangeType == ~0U) - return reportError(Ctx, "Invalid Descriptor Range type: " + *ElementText); + return make_error<GenericRSMetadataError>("Invalid Descriptor Range type.", + RangeDescriptorNode); if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 1)) Range.NumDescriptors = *Val; else - return reportError(Ctx, "Invalid value for Number of Descriptor in Range"); + return make_error<GenericRSMetadataError>("Number of Descriptor in Range", + RangeDescriptorNode); if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 2)) Range.BaseShaderRegister = *Val; else - return reportError(Ctx, "Invalid value for BaseShaderRegister"); + return make_error<InvalidRSMetadataValue>("BaseShaderRegister"); if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 3)) Range.RegisterSpace = *Val; else - return reportError(Ctx, "Invalid value for RegisterSpace"); + return make_error<InvalidRSMetadataValue>("RegisterSpace"); if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 4)) Range.OffsetInDescriptorsFromTableStart = *Val; else - return reportError(Ctx, - "Invalid value for OffsetInDescriptorsFromTableStart"); + return make_error<InvalidRSMetadataValue>( + "OffsetInDescriptorsFromTableStart"); if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 5)) Range.Flags = *Val; else - return reportError(Ctx, "Invalid value for Descriptor Range Flags"); + return make_error<InvalidRSMetadataValue>("Descriptor Range Flags"); Table.Ranges.push_back(Range); - return false; + return Error::success(); } -bool MetadataParser::parseDescriptorTable(LLVMContext *Ctx, - mcdxbc::RootSignatureDesc &RSD, - MDNode *DescriptorTableNode) { +Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD, + MDNode *DescriptorTableNode) { const unsigned int NumOperands = DescriptorTableNode->getNumOperands(); if (NumOperands < 2) - return reportError(Ctx, "Invalid format for Descriptor Table"); + return make_error<InvalidRSMetadataFormat>("Descriptor Table"); dxbc::RTS0::v1::RootParameterHeader Header; if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1)) Header.ShaderVisibility = *Val; else - return reportError(Ctx, "Invalid value for ShaderVisibility"); + return make_error<InvalidRSMetadataValue>("ShaderVisibility"); mcdxbc::DescriptorTable Table; Header.ParameterType = - llvm::to_underlying(dxbc::RootParameterType::DescriptorTable); + to_underlying(dxbc::RootParameterType::DescriptorTable); for (unsigned int I = 2; I < NumOperands; I++) { MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I)); if (Element == nullptr) - return reportError(Ctx, "Missing Root Element Metadata Node."); + return make_error<GenericRSMetadataError>( + "Missing Root Element Metadata Node.", DescriptorTableNode); - if (parseDescriptorRange(Ctx, Table, Element)) - return true; + if (auto Err = parseDescriptorRange(Table, Element)) + return Err; } 
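Aside, not part of the patch: the parsers in this file are being migrated from bool-plus-diagnostic reporting to llvm::Error, with callers accumulating failures via joinErrors so several malformed metadata nodes can be reported in one pass. A minimal standalone sketch of that pattern, assuming a hypothetical parseOne/parseAll pair and plain StringError in place of the custom error classes used above:

    // Sketch of the deferred-error pattern (StringError stands in for the
    // ErrorInfo subclasses defined in RootSignatureMetadata).
    #include "llvm/ADT/ArrayRef.h"
    #include "llvm/ADT/Twine.h"
    #include "llvm/Support/Error.h"
    #include "llvm/Support/raw_ostream.h"

    using namespace llvm;

    static Error parseOne(int Node) {
      if (Node % 2 != 0) // pretend odd "nodes" are malformed
        return make_error<StringError>("invalid node " + Twine(Node),
                                       inconvertibleErrorCode());
      return Error::success();
    }

    static Expected<int> parseAll(ArrayRef<int> Nodes) {
      Error DeferredErrs = Error::success();
      int Parsed = 0;
      for (int N : Nodes) {
        if (Error Err = parseOne(N))
          DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err));
        else
          ++Parsed;
      }
      if (DeferredErrs)
        return std::move(DeferredErrs); // surface every failure at once
      return Parsed;
    }

    int main() {
      if (Expected<int> R = parseAll({2, 3, 4, 5}))
        outs() << "parsed " << *R << " nodes\n";
      else
        handleAllErrors(R.takeError(), [](const ErrorInfoBase &EI) {
          errs() << EI.message() << "\n";
        });
    }

The same shape appears below in validateRootSignature and ParseRootSignature: every check joins into one deferred Error, and the Expected<RootSignatureDesc> result carries either the parsed description or the combined diagnostics.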
RSD.ParametersContainer.addParameter(Header, Table); - return false; + return Error::success(); } -bool MetadataParser::parseStaticSampler(LLVMContext *Ctx, - mcdxbc::RootSignatureDesc &RSD, - MDNode *StaticSamplerNode) { +Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD, + MDNode *StaticSamplerNode) { if (StaticSamplerNode->getNumOperands() != 14) - return reportError(Ctx, "Invalid format for Static Sampler"); + return make_error<InvalidRSMetadataFormat>("Static Sampler"); dxbc::RTS0::v1::StaticSampler Sampler; if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 1)) Sampler.Filter = *Val; else - return reportError(Ctx, "Invalid value for Filter"); + return make_error<InvalidRSMetadataValue>("Filter"); if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 2)) Sampler.AddressU = *Val; else - return reportError(Ctx, "Invalid value for AddressU"); + return make_error<InvalidRSMetadataValue>("AddressU"); if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 3)) Sampler.AddressV = *Val; else - return reportError(Ctx, "Invalid value for AddressV"); + return make_error<InvalidRSMetadataValue>("AddressV"); if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 4)) Sampler.AddressW = *Val; else - return reportError(Ctx, "Invalid value for AddressW"); + return make_error<InvalidRSMetadataValue>("AddressW"); if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 5)) Sampler.MipLODBias = *Val; else - return reportError(Ctx, "Invalid value for MipLODBias"); + return make_error<InvalidRSMetadataValue>("MipLODBias"); if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 6)) Sampler.MaxAnisotropy = *Val; else - return reportError(Ctx, "Invalid value for MaxAnisotropy"); + return make_error<InvalidRSMetadataValue>("MaxAnisotropy"); if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 7)) Sampler.ComparisonFunc = *Val; else - return reportError(Ctx, "Invalid value for ComparisonFunc "); + return make_error<InvalidRSMetadataValue>("ComparisonFunc"); if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 8)) Sampler.BorderColor = *Val; else - return reportError(Ctx, "Invalid value for ComparisonFunc "); + return make_error<InvalidRSMetadataValue>("ComparisonFunc"); if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 9)) Sampler.MinLOD = *Val; else - return reportError(Ctx, "Invalid value for MinLOD"); + return make_error<InvalidRSMetadataValue>("MinLOD"); if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 10)) Sampler.MaxLOD = *Val; else - return reportError(Ctx, "Invalid value for MaxLOD"); + return make_error<InvalidRSMetadataValue>("MaxLOD"); if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 11)) Sampler.ShaderRegister = *Val; else - return reportError(Ctx, "Invalid value for ShaderRegister"); + return make_error<InvalidRSMetadataValue>("ShaderRegister"); if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 12)) Sampler.RegisterSpace = *Val; else - return reportError(Ctx, "Invalid value for RegisterSpace"); + return make_error<InvalidRSMetadataValue>("RegisterSpace"); if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 13)) Sampler.ShaderVisibility = *Val; else - return reportError(Ctx, "Invalid value for ShaderVisibility"); + return make_error<InvalidRSMetadataValue>("ShaderVisibility"); RSD.StaticSamplers.push_back(Sampler); - return false; + return 
Error::success(); } -bool MetadataParser::parseRootSignatureElement(LLVMContext *Ctx, - mcdxbc::RootSignatureDesc &RSD, - MDNode *Element) { +Error MetadataParser::parseRootSignatureElement(mcdxbc::RootSignatureDesc &RSD, + MDNode *Element) { std::optional<StringRef> ElementText = extractMdStringValue(Element, 0); if (!ElementText.has_value()) - return reportError(Ctx, "Invalid format for Root Element"); + return make_error<InvalidRSMetadataFormat>("Root Element"); RootSignatureElementKind ElementKind = StringSwitch<RootSignatureElementKind>(*ElementText) @@ -523,79 +506,109 @@ bool MetadataParser::parseRootSignatureElement(LLVMContext *Ctx, switch (ElementKind) { case RootSignatureElementKind::RootFlags: - return parseRootFlags(Ctx, RSD, Element); + return parseRootFlags(RSD, Element); case RootSignatureElementKind::RootConstants: - return parseRootConstants(Ctx, RSD, Element); + return parseRootConstants(RSD, Element); case RootSignatureElementKind::CBV: case RootSignatureElementKind::SRV: case RootSignatureElementKind::UAV: - return parseRootDescriptors(Ctx, RSD, Element, ElementKind); + return parseRootDescriptors(RSD, Element, ElementKind); case RootSignatureElementKind::DescriptorTable: - return parseDescriptorTable(Ctx, RSD, Element); + return parseDescriptorTable(RSD, Element); case RootSignatureElementKind::StaticSamplers: - return parseStaticSampler(Ctx, RSD, Element); + return parseStaticSampler(RSD, Element); case RootSignatureElementKind::Error: - return reportError(Ctx, "Invalid Root Signature Element: " + *ElementText); + return make_error<GenericRSMetadataError>("Invalid Root Signature Element", + Element); } llvm_unreachable("Unhandled RootSignatureElementKind enum."); } -bool MetadataParser::validateRootSignature( - LLVMContext *Ctx, const llvm::mcdxbc::RootSignatureDesc &RSD) { - if (!llvm::hlsl::rootsig::verifyVersion(RSD.Version)) { - return reportValueError(Ctx, "Version", RSD.Version); +Error MetadataParser::validateRootSignature( + const mcdxbc::RootSignatureDesc &RSD) { + Error DeferredErrs = Error::success(); + if (!hlsl::rootsig::verifyVersion(RSD.Version)) { + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "Version", RSD.Version)); } - if (!llvm::hlsl::rootsig::verifyRootFlag(RSD.Flags)) { - return reportValueError(Ctx, "RootFlags", RSD.Flags); + if (!hlsl::rootsig::verifyRootFlag(RSD.Flags)) { + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "RootFlags", RSD.Flags)); } for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) { if (!dxbc::isValidShaderVisibility(Info.Header.ShaderVisibility)) - return reportValueError(Ctx, "ShaderVisibility", - Info.Header.ShaderVisibility); + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "ShaderVisibility", Info.Header.ShaderVisibility)); assert(dxbc::isValidParameterType(Info.Header.ParameterType) && "Invalid value for ParameterType"); switch (Info.Header.ParameterType) { - case llvm::to_underlying(dxbc::RootParameterType::CBV): - case llvm::to_underlying(dxbc::RootParameterType::UAV): - case llvm::to_underlying(dxbc::RootParameterType::SRV): { + case to_underlying(dxbc::RootParameterType::CBV): + case to_underlying(dxbc::RootParameterType::UAV): + case to_underlying(dxbc::RootParameterType::SRV): { const dxbc::RTS0::v2::RootDescriptor &Descriptor = RSD.ParametersContainer.getRootDescriptor(Info.Location); - if 
(!llvm::hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister)) - return reportValueError(Ctx, "ShaderRegister", - Descriptor.ShaderRegister); - - if (!llvm::hlsl::rootsig::verifyRegisterSpace(Descriptor.RegisterSpace)) - return reportValueError(Ctx, "RegisterSpace", Descriptor.RegisterSpace); + if (!hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "ShaderRegister", Descriptor.ShaderRegister)); + + if (!hlsl::rootsig::verifyRegisterSpace(Descriptor.RegisterSpace)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "RegisterSpace", Descriptor.RegisterSpace)); if (RSD.Version > 1) { - if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(RSD.Version, - Descriptor.Flags)) - return reportValueError(Ctx, "RootDescriptorFlag", Descriptor.Flags); + if (!hlsl::rootsig::verifyRootDescriptorFlag(RSD.Version, + Descriptor.Flags)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "RootDescriptorFlag", Descriptor.Flags)); } break; } - case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): { + case to_underlying(dxbc::RootParameterType::DescriptorTable): { const mcdxbc::DescriptorTable &Table = RSD.ParametersContainer.getDescriptorTable(Info.Location); for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) { - if (!llvm::hlsl::rootsig::verifyRangeType(Range.RangeType)) - return reportValueError(Ctx, "RangeType", Range.RangeType); - - if (!llvm::hlsl::rootsig::verifyRegisterSpace(Range.RegisterSpace)) - return reportValueError(Ctx, "RegisterSpace", Range.RegisterSpace); - - if (!llvm::hlsl::rootsig::verifyNumDescriptors(Range.NumDescriptors)) - return reportValueError(Ctx, "NumDescriptors", Range.NumDescriptors); - - if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag( + if (!hlsl::rootsig::verifyRangeType(Range.RangeType)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "RangeType", Range.RangeType)); + + if (!hlsl::rootsig::verifyRegisterSpace(Range.RegisterSpace)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "RegisterSpace", Range.RegisterSpace)); + + if (!hlsl::rootsig::verifyNumDescriptors(Range.NumDescriptors)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "NumDescriptors", Range.NumDescriptors)); + + if (!hlsl::rootsig::verifyDescriptorRangeFlag( RSD.Version, Range.RangeType, Range.Flags)) - return reportValueError(Ctx, "DescriptorFlag", Range.Flags); + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "DescriptorFlag", Range.Flags)); } break; } @@ -603,65 +616,108 @@ bool MetadataParser::validateRootSignature( } for (const dxbc::RTS0::v1::StaticSampler &Sampler : RSD.StaticSamplers) { - if (!llvm::hlsl::rootsig::verifySamplerFilter(Sampler.Filter)) - return reportValueError(Ctx, "Filter", Sampler.Filter); - - if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressU)) - return reportValueError(Ctx, "AddressU", Sampler.AddressU); - - if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressV)) - return reportValueError(Ctx, "AddressV", Sampler.AddressV); - - if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressW)) - return reportValueError(Ctx, "AddressW", Sampler.AddressW); - - if 
(!llvm::hlsl::rootsig::verifyMipLODBias(Sampler.MipLODBias)) - return reportValueError(Ctx, "MipLODBias", Sampler.MipLODBias); - - if (!llvm::hlsl::rootsig::verifyMaxAnisotropy(Sampler.MaxAnisotropy)) - return reportValueError(Ctx, "MaxAnisotropy", Sampler.MaxAnisotropy); - - if (!llvm::hlsl::rootsig::verifyComparisonFunc(Sampler.ComparisonFunc)) - return reportValueError(Ctx, "ComparisonFunc", Sampler.ComparisonFunc); - - if (!llvm::hlsl::rootsig::verifyBorderColor(Sampler.BorderColor)) - return reportValueError(Ctx, "BorderColor", Sampler.BorderColor); - - if (!llvm::hlsl::rootsig::verifyLOD(Sampler.MinLOD)) - return reportValueError(Ctx, "MinLOD", Sampler.MinLOD); - - if (!llvm::hlsl::rootsig::verifyLOD(Sampler.MaxLOD)) - return reportValueError(Ctx, "MaxLOD", Sampler.MaxLOD); - - if (!llvm::hlsl::rootsig::verifyRegisterValue(Sampler.ShaderRegister)) - return reportValueError(Ctx, "ShaderRegister", Sampler.ShaderRegister); - - if (!llvm::hlsl::rootsig::verifyRegisterSpace(Sampler.RegisterSpace)) - return reportValueError(Ctx, "RegisterSpace", Sampler.RegisterSpace); + if (!hlsl::rootsig::verifySamplerFilter(Sampler.Filter)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "Filter", Sampler.Filter)); + + if (!hlsl::rootsig::verifyAddress(Sampler.AddressU)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "AddressU", Sampler.AddressU)); + + if (!hlsl::rootsig::verifyAddress(Sampler.AddressV)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "AddressV", Sampler.AddressV)); + + if (!hlsl::rootsig::verifyAddress(Sampler.AddressW)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "AddressW", Sampler.AddressW)); + + if (!hlsl::rootsig::verifyMipLODBias(Sampler.MipLODBias)) + DeferredErrs = joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<float>>( + "MipLODBias", Sampler.MipLODBias)); + + if (!hlsl::rootsig::verifyMaxAnisotropy(Sampler.MaxAnisotropy)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "MaxAnisotropy", Sampler.MaxAnisotropy)); + + if (!hlsl::rootsig::verifyComparisonFunc(Sampler.ComparisonFunc)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "ComparisonFunc", Sampler.ComparisonFunc)); + + if (!hlsl::rootsig::verifyBorderColor(Sampler.BorderColor)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "BorderColor", Sampler.BorderColor)); + + if (!hlsl::rootsig::verifyLOD(Sampler.MinLOD)) + DeferredErrs = joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<float>>( + "MinLOD", Sampler.MinLOD)); + + if (!hlsl::rootsig::verifyLOD(Sampler.MaxLOD)) + DeferredErrs = joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<float>>( + "MaxLOD", Sampler.MaxLOD)); + + if (!hlsl::rootsig::verifyRegisterValue(Sampler.ShaderRegister)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "ShaderRegister", Sampler.ShaderRegister)); + + if (!hlsl::rootsig::verifyRegisterSpace(Sampler.RegisterSpace)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "RegisterSpace", 
Sampler.RegisterSpace)); if (!dxbc::isValidShaderVisibility(Sampler.ShaderVisibility)) - return reportValueError(Ctx, "ShaderVisibility", - Sampler.ShaderVisibility); + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "ShaderVisibility", Sampler.ShaderVisibility)); } - return false; + return DeferredErrs; } -bool MetadataParser::ParseRootSignature(LLVMContext *Ctx, - mcdxbc::RootSignatureDesc &RSD) { - bool HasError = false; - - // Loop through the Root Elements of the root signature. +Expected<mcdxbc::RootSignatureDesc> +MetadataParser::ParseRootSignature(uint32_t Version) { + Error DeferredErrs = Error::success(); + mcdxbc::RootSignatureDesc RSD; + RSD.Version = Version; for (const auto &Operand : Root->operands()) { MDNode *Element = dyn_cast<MDNode>(Operand); if (Element == nullptr) - return reportError(Ctx, "Missing Root Element Metadata Node."); + return joinErrors(std::move(DeferredErrs), + make_error<GenericRSMetadataError>( + "Missing Root Element Metadata Node.", nullptr)); - HasError = HasError || parseRootSignatureElement(Ctx, RSD, Element) || - validateRootSignature(Ctx, RSD); + if (auto Err = parseRootSignatureElement(RSD, Element)) + DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err)); } - return HasError; + if (auto Err = validateRootSignature(RSD)) + DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err)); + + if (DeferredErrs) + return std::move(DeferredErrs); + + return std::move(RSD); } } // namespace rootsig } // namespace hlsl diff --git a/llvm/lib/IR/DebugInfoMetadata.cpp b/llvm/lib/IR/DebugInfoMetadata.cpp index f16963d..f1d4549 100644 --- a/llvm/lib/IR/DebugInfoMetadata.cpp +++ b/llvm/lib/IR/DebugInfoMetadata.cpp @@ -1012,7 +1012,7 @@ DIDerivedType *DIDerivedType::getImpl( std::optional<DIDerivedType::PtrAuthData> DIDerivedType::getPtrAuthData() const { return getTag() == dwarf::DW_TAG_LLVM_ptrauth_type - ? std::optional<PtrAuthData>(PtrAuthData(SubclassData32)) + ? 
std::make_optional<PtrAuthData>(SubclassData32) : std::nullopt; } diff --git a/llvm/lib/MC/MCObjectFileInfo.cpp b/llvm/lib/MC/MCObjectFileInfo.cpp index 0069d12..393eed1 100644 --- a/llvm/lib/MC/MCObjectFileInfo.cpp +++ b/llvm/lib/MC/MCObjectFileInfo.cpp @@ -537,6 +537,8 @@ void MCObjectFileInfo::initELFMCObjectFileInfo(const Triple &T, bool Large) { EHFrameSection = Ctx->getELFSection(".eh_frame", EHSectionType, EHSectionFlags); + CallGraphSection = Ctx->getELFSection(".callgraph", ELF::SHT_PROGBITS, 0); + StackSizesSection = Ctx->getELFSection(".stack_sizes", ELF::SHT_PROGBITS, 0); PseudoProbeSection = Ctx->getELFSection(".pseudo_probe", DebugSecType, 0); @@ -1121,6 +1123,24 @@ MCSection *MCObjectFileInfo::getDwarfComdatSection(const char *Name, } MCSection * +MCObjectFileInfo::getCallGraphSection(const MCSection &TextSec) const { + if (Ctx->getObjectFileType() != MCContext::IsELF) + return CallGraphSection; + + const MCSectionELF &ElfSec = static_cast<const MCSectionELF &>(TextSec); + unsigned Flags = ELF::SHF_LINK_ORDER; + StringRef GroupName; + if (const MCSymbol *Group = ElfSec.getGroup()) { + GroupName = Group->getName(); + Flags |= ELF::SHF_GROUP; + } + + return Ctx->getELFSection(".callgraph", ELF::SHT_PROGBITS, Flags, 0, + GroupName, true, ElfSec.getUniqueID(), + cast<MCSymbolELF>(TextSec.getBeginSymbol())); +} + +MCSection * MCObjectFileInfo::getStackSizesSection(const MCSection &TextSec) const { if ((Ctx->getObjectFileType() != MCContext::IsELF) || Ctx->getTargetTriple().isPS4()) diff --git a/llvm/lib/Object/ELFObjectFile.cpp b/llvm/lib/Object/ELFObjectFile.cpp index 0919c6a..aff047c 100644 --- a/llvm/lib/Object/ELFObjectFile.cpp +++ b/llvm/lib/Object/ELFObjectFile.cpp @@ -688,11 +688,20 @@ StringRef ELFObjectFileBase::getNVPTXCPUName() const { case ELF::EF_CUDA_SM100: return getPlatformFlags() & ELF::EF_CUDA_ACCELERATORS ? "sm_100a" : "sm_100"; + case ELF::EF_CUDA_SM101: + return getPlatformFlags() & ELF::EF_CUDA_ACCELERATORS ? "sm_101a" + : "sm_101"; + case ELF::EF_CUDA_SM103: + return getPlatformFlags() & ELF::EF_CUDA_ACCELERATORS ? "sm_103a" + : "sm_103"; // Rubin architecture. case ELF::EF_CUDA_SM120: return getPlatformFlags() & ELF::EF_CUDA_ACCELERATORS ? "sm_120a" : "sm_120"; + case ELF::EF_CUDA_SM121: + return getPlatformFlags() & ELF::EF_CUDA_ACCELERATORS ? 
"sm_121a" + : "sm_121"; default: llvm_unreachable("Unknown EF_CUDA_SM value"); } diff --git a/llvm/lib/ObjectYAML/ELFEmitter.cpp b/llvm/lib/ObjectYAML/ELFEmitter.cpp index 6de87a8..bc5c68d 100644 --- a/llvm/lib/ObjectYAML/ELFEmitter.cpp +++ b/llvm/lib/ObjectYAML/ELFEmitter.cpp @@ -481,7 +481,11 @@ void ELFState<ELFT>::writeELFHeader(raw_ostream &OS) { Header.e_version = EV_CURRENT; Header.e_entry = Doc.Header.Entry; - Header.e_flags = Doc.Header.Flags; + if (Doc.Header.Flags) + Header.e_flags = *Doc.Header.Flags; + else + Header.e_flags = 0; + Header.e_ehsize = sizeof(Elf_Ehdr); if (Doc.Header.EPhOff) diff --git a/llvm/lib/ObjectYAML/ELFYAML.cpp b/llvm/lib/ObjectYAML/ELFYAML.cpp index 7fcabb68..c27339d 100644 --- a/llvm/lib/ObjectYAML/ELFYAML.cpp +++ b/llvm/lib/ObjectYAML/ELFYAML.cpp @@ -1160,7 +1160,7 @@ void MappingTraits<ELFYAML::FileHeader>::mapping(IO &IO, IO.mapOptional("ABIVersion", FileHdr.ABIVersion, Hex8(0)); IO.mapRequired("Type", FileHdr.Type); IO.mapOptional("Machine", FileHdr.Machine); - IO.mapOptional("Flags", FileHdr.Flags, ELFYAML::ELF_EF(0)); + IO.mapOptional("Flags", FileHdr.Flags); IO.mapOptional("Entry", FileHdr.Entry, Hex64(0)); IO.mapOptional("SectionHeaderStringTable", FileHdr.SectionHeaderStringTable); diff --git a/llvm/lib/Remarks/RemarkLinker.cpp b/llvm/lib/Remarks/RemarkLinker.cpp index 0ca6217..b00419b 100644 --- a/llvm/lib/Remarks/RemarkLinker.cpp +++ b/llvm/lib/Remarks/RemarkLinker.cpp @@ -70,8 +70,8 @@ Error RemarkLinker::link(StringRef Buffer, Format RemarkFormat) { Expected<std::unique_ptr<RemarkParser>> MaybeParser = createRemarkParserFromMeta( RemarkFormat, Buffer, - PrependPath ? std::optional<StringRef>(StringRef(*PrependPath)) - : std::optional<StringRef>()); + PrependPath ? std::make_optional<StringRef>(*PrependPath) + : std::nullopt); if (!MaybeParser) return MaybeParser.takeError(); diff --git a/llvm/lib/Support/BLAKE3/CMakeLists.txt b/llvm/lib/Support/BLAKE3/CMakeLists.txt index eae2b02..90311ae 100644 --- a/llvm/lib/Support/BLAKE3/CMakeLists.txt +++ b/llvm/lib/Support/BLAKE3/CMakeLists.txt @@ -26,7 +26,8 @@ endmacro() if (CAN_USE_ASSEMBLER) if (MSVC) check_symbol_exists(_M_X64 "" IS_X64) - if (IS_X64) + check_symbol_exists(_M_ARM64EC "" IS_ARM64EC) + if (IS_X64 AND NOT IS_ARM64EC) enable_language(ASM_MASM) set(LLVM_BLAKE3_ASM_FILES blake3_sse2_x86-64_windows_msvc.asm diff --git a/llvm/lib/Support/FileCollector.cpp b/llvm/lib/Support/FileCollector.cpp index 29436f8..edb5313 100644 --- a/llvm/lib/Support/FileCollector.cpp +++ b/llvm/lib/Support/FileCollector.cpp @@ -313,5 +313,6 @@ private: IntrusiveRefCntPtr<vfs::FileSystem> FileCollector::createCollectorVFS(IntrusiveRefCntPtr<vfs::FileSystem> BaseFS, std::shared_ptr<FileCollector> Collector) { - return new FileCollectorFileSystem(std::move(BaseFS), std::move(Collector)); + return makeIntrusiveRefCnt<FileCollectorFileSystem>(std::move(BaseFS), + std::move(Collector)); } diff --git a/llvm/lib/Support/VirtualFileSystem.cpp b/llvm/lib/Support/VirtualFileSystem.cpp index e489282..5d42488 100644 --- a/llvm/lib/Support/VirtualFileSystem.cpp +++ b/llvm/lib/Support/VirtualFileSystem.cpp @@ -397,7 +397,8 @@ void RealFileSystem::printImpl(raw_ostream &OS, PrintType Type, } IntrusiveRefCntPtr<FileSystem> vfs::getRealFileSystem() { - static IntrusiveRefCntPtr<FileSystem> FS(new RealFileSystem(true)); + static IntrusiveRefCntPtr<FileSystem> FS = + makeIntrusiveRefCnt<RealFileSystem>(true); return FS; } @@ -2217,9 +2218,9 @@ RedirectingFileSystem::create(std::unique_ptr<MemoryBuffer> Buffer, 
std::unique_ptr<RedirectingFileSystem> RedirectingFileSystem::create( ArrayRef<std::pair<std::string, std::string>> RemappedFiles, - bool UseExternalNames, FileSystem &ExternalFS) { + bool UseExternalNames, llvm::IntrusiveRefCntPtr<FileSystem> ExternalFS) { std::unique_ptr<RedirectingFileSystem> FS( - new RedirectingFileSystem(&ExternalFS)); + new RedirectingFileSystem(ExternalFS)); FS->UseExternalNames = UseExternalNames; StringMap<RedirectingFileSystem::Entry *> Entries; @@ -2228,7 +2229,7 @@ std::unique_ptr<RedirectingFileSystem> RedirectingFileSystem::create( SmallString<128> From = StringRef(Mapping.first); SmallString<128> To = StringRef(Mapping.second); { - auto EC = ExternalFS.makeAbsolute(From); + auto EC = ExternalFS->makeAbsolute(From); (void)EC; assert(!EC && "Could not make absolute path"); } @@ -2250,7 +2251,7 @@ std::unique_ptr<RedirectingFileSystem> RedirectingFileSystem::create( } assert(Parent && "File without a directory?"); { - auto EC = ExternalFS.makeAbsolute(To); + auto EC = ExternalFS->makeAbsolute(To); (void)EC; assert(!EC && "Could not make absolute path"); } diff --git a/llvm/lib/Support/Windows/Threading.inc b/llvm/lib/Support/Windows/Threading.inc index 8dd7c88..b11f216 100644 --- a/llvm/lib/Support/Windows/Threading.inc +++ b/llvm/lib/Support/Windows/Threading.inc @@ -136,6 +136,7 @@ HMODULE loadSystemModuleSecure(LPCWSTR lpModuleName) { } // namespace llvm::sys::windows SetThreadPriorityResult llvm::set_thread_priority(ThreadPriority Priority) { +#ifdef THREAD_POWER_THROTTLING_CURRENT_VERSION HMODULE kernelM = llvm::sys::windows::loadSystemModuleSecure(L"kernel32.dll"); if (kernelM) { // SetThreadInformation is only available on Windows 8 and later. Since we @@ -166,6 +167,7 @@ SetThreadPriorityResult llvm::set_thread_priority(ThreadPriority Priority) { : 0); } } +#endif // https://docs.microsoft.com/en-us/windows/desktop/api/processthreadsapi/nf-processthreadsapi-setthreadpriority // Begin background processing mode. The system lowers the resource scheduling diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 4f6e3dd..8312b04 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -162,10 +162,10 @@ static cl::opt<bool> UseFEATCPACodegen( cl::init(false)); /// Value type used for condition codes. -static const MVT MVT_CC = MVT::i32; +constexpr MVT CondCodeVT = MVT::i32; /// Value type used for NZCV flags. -static constexpr MVT FlagsVT = MVT::i32; +constexpr MVT FlagsVT = MVT::i32; static const MCPhysReg GPRArgRegs[] = {AArch64::X0, AArch64::X1, AArch64::X2, AArch64::X3, AArch64::X4, AArch64::X5, @@ -3472,6 +3472,12 @@ static void changeVectorFPCCToAArch64CC(ISD::CondCode CC, } } +/// Like SelectionDAG::getCondCode(), but for AArch64 condition codes. +static SDValue getCondCode(SelectionDAG &DAG, AArch64CC::CondCode CC) { + // TODO: Should be TargetConstant (need to s/imm/timm in patterns). + return DAG.getConstant(CC, SDLoc(), CondCodeVT); +} + static bool isLegalArithImmed(uint64_t C) { // Matches AArch64DAGToDAGISel::SelectArithImmed(). 
bool IsLegal = (C >> 12 == 0) || ((C & 0xFFFULL) == 0 && C >> 24 == 0); @@ -3678,7 +3684,7 @@ static SDValue emitConditionalComparison(SDValue LHS, SDValue RHS, if (Opcode == 0) Opcode = AArch64ISD::CCMP; - SDValue Condition = DAG.getConstant(Predicate, DL, MVT_CC); + SDValue Condition = getCondCode(DAG, Predicate); AArch64CC::CondCode InvOutCC = AArch64CC::getInvertedCondCode(OutCC); unsigned NZCV = AArch64CC::getNZCVToSatisfyCondCode(InvOutCC); SDValue NZCVOp = DAG.getConstant(NZCV, DL, MVT::i32); @@ -4075,7 +4081,7 @@ static SDValue getAArch64Cmp(SDValue LHS, SDValue RHS, ISD::CondCode CC, Cmp = emitComparison(LHS, RHS, CC, DL, DAG); AArch64CC = changeIntCCToAArch64CC(CC); } - AArch64cc = DAG.getConstant(AArch64CC, DL, MVT_CC); + AArch64cc = getCondCode(DAG, AArch64CC); return Cmp; } @@ -4195,7 +4201,7 @@ SDValue AArch64TargetLowering::LowerXOR(SDValue Op, SelectionDAG &DAG) const { AArch64CC::CondCode CC; SDValue Value, Overflow; std::tie(Value, Overflow) = getAArch64XALUOOp(CC, Sel.getValue(0), DAG); - SDValue CCVal = DAG.getConstant(getInvertedCondCode(CC), DL, MVT::i32); + SDValue CCVal = getCondCode(DAG, getInvertedCondCode(CC)); return DAG.getNode(AArch64ISD::CSEL, DL, Op.getValueType(), TVal, FVal, CCVal, Overflow); } @@ -4274,8 +4280,8 @@ static SDValue carryFlagToValue(SDValue Glue, EVT VT, SelectionDAG &DAG, SDLoc DL(Glue); SDValue Zero = DAG.getConstant(0, DL, VT); SDValue One = DAG.getConstant(1, DL, VT); - unsigned Cond = Invert ? AArch64CC::LO : AArch64CC::HS; - SDValue CC = DAG.getConstant(Cond, DL, MVT::i32); + AArch64CC::CondCode Cond = Invert ? AArch64CC::LO : AArch64CC::HS; + SDValue CC = getCondCode(DAG, Cond); return DAG.getNode(AArch64ISD::CSEL, DL, VT, One, Zero, CC, Glue); } @@ -4285,7 +4291,7 @@ static SDValue overflowFlagToValue(SDValue Glue, EVT VT, SelectionDAG &DAG) { SDLoc DL(Glue); SDValue Zero = DAG.getConstant(0, DL, VT); SDValue One = DAG.getConstant(1, DL, VT); - SDValue CC = DAG.getConstant(AArch64CC::VS, DL, MVT::i32); + SDValue CC = getCondCode(DAG, AArch64CC::VS); return DAG.getNode(AArch64ISD::CSEL, DL, VT, One, Zero, CC, Glue); } @@ -4334,7 +4340,7 @@ static SDValue LowerXALUO(SDValue Op, SelectionDAG &DAG) { // We use an inverted condition, because the conditional select is inverted // too. This will allow it to be selected to a single instruction: // CSINC Wd, WZR, WZR, invert(cond). 
- SDValue CCVal = DAG.getConstant(getInvertedCondCode(CC), DL, MVT::i32); + SDValue CCVal = getCondCode(DAG, getInvertedCondCode(CC)); Overflow = DAG.getNode(AArch64ISD::CSEL, DL, MVT::i32, FVal, TVal, CCVal, Overflow); @@ -7124,8 +7130,7 @@ SDValue AArch64TargetLowering::LowerABS(SDValue Op, SelectionDAG &DAG) const { SDValue Cmp = DAG.getNode(AArch64ISD::SUBS, DL, DAG.getVTList(VT, FlagsVT), Op.getOperand(0), DAG.getConstant(0, DL, VT)); return DAG.getNode(AArch64ISD::CSEL, DL, VT, Op.getOperand(0), Neg, - DAG.getConstant(AArch64CC::PL, DL, MVT::i32), - Cmp.getValue(1)); + getCondCode(DAG, AArch64CC::PL), Cmp.getValue(1)); } static SDValue LowerBRCOND(SDValue Op, SelectionDAG &DAG) { @@ -7136,7 +7141,7 @@ static SDValue LowerBRCOND(SDValue Op, SelectionDAG &DAG) { AArch64CC::CondCode CC; if (SDValue Cmp = emitConjunction(DAG, Cond, CC)) { SDLoc DL(Op); - SDValue CCVal = DAG.getConstant(CC, DL, MVT::i32); + SDValue CCVal = getCondCode(DAG, CC); return DAG.getNode(AArch64ISD::BRCOND, DL, MVT::Other, Chain, Dest, CCVal, Cmp); } @@ -10575,7 +10580,7 @@ SDValue AArch64TargetLowering::LowerBR_CC(SDValue Op, SelectionDAG &DAG) const { if (CC == ISD::SETNE) OFCC = getInvertedCondCode(OFCC); - SDValue CCVal = DAG.getConstant(OFCC, DL, MVT::i32); + SDValue CCVal = getCondCode(DAG, OFCC); return DAG.getNode(AArch64ISD::BRCOND, DL, MVT::Other, Chain, Dest, CCVal, Overflow); @@ -10648,7 +10653,7 @@ SDValue AArch64TargetLowering::LowerBR_CC(SDValue Op, SelectionDAG &DAG) const { AArch64CC::isValidCBCond(changeIntCCToAArch64CC(CC)) && ProduceNonFlagSettingCondBr) { SDValue Cond = - DAG.getTargetConstant(changeIntCCToAArch64CC(CC), DL, MVT::i32); + DAG.getTargetConstant(changeIntCCToAArch64CC(CC), DL, CondCodeVT); return DAG.getNode(AArch64ISD::CB, DL, MVT::Other, Chain, Cond, LHS, RHS, Dest); } @@ -10667,11 +10672,11 @@ SDValue AArch64TargetLowering::LowerBR_CC(SDValue Op, SelectionDAG &DAG) const { SDValue Cmp = emitComparison(LHS, RHS, CC, DL, DAG); AArch64CC::CondCode CC1, CC2; changeFPCCToAArch64CC(CC, CC1, CC2); - SDValue CC1Val = DAG.getConstant(CC1, DL, MVT::i32); + SDValue CC1Val = getCondCode(DAG, CC1); SDValue BR1 = DAG.getNode(AArch64ISD::BRCOND, DL, MVT::Other, Chain, Dest, CC1Val, Cmp); if (CC2 != AArch64CC::AL) { - SDValue CC2Val = DAG.getConstant(CC2, DL, MVT::i32); + SDValue CC2Val = getCondCode(DAG, CC2); return DAG.getNode(AArch64ISD::BRCOND, DL, MVT::Other, BR1, Dest, CC2Val, Cmp); } @@ -11160,7 +11165,7 @@ SDValue AArch64TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const { if (CC2 == AArch64CC::AL) { changeFPCCToAArch64CC(ISD::getSetCCInverse(CC, LHS.getValueType()), CC1, CC2); - SDValue CC1Val = DAG.getConstant(CC1, DL, MVT::i32); + SDValue CC1Val = getCondCode(DAG, CC1); // Note that we inverted the condition above, so we reverse the order of // the true and false operands here. This will allow the setcc to be @@ -11173,11 +11178,11 @@ SDValue AArch64TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const { // of the first as the RHS. We're effectively OR'ing the two CC's together. // FIXME: It would be nice if we could match the two CSELs to two CSINCs. - SDValue CC1Val = DAG.getConstant(CC1, DL, MVT::i32); + SDValue CC1Val = getCondCode(DAG, CC1); SDValue CS1 = DAG.getNode(AArch64ISD::CSEL, DL, VT, TVal, FVal, CC1Val, Cmp); - SDValue CC2Val = DAG.getConstant(CC2, DL, MVT::i32); + SDValue CC2Val = getCondCode(DAG, CC2); Res = DAG.getNode(AArch64ISD::CSEL, DL, VT, TVal, CS1, CC2Val, Cmp); } return IsStrict ? 
DAG.getMergeValues({Res, Cmp.getValue(1)}, DL) : Res; @@ -11205,8 +11210,7 @@ SDValue AArch64TargetLowering::LowerSETCCCARRY(SDValue Op, ISD::CondCode Cond = cast<CondCodeSDNode>(Op.getOperand(3))->get(); ISD::CondCode CondInv = ISD::getSetCCInverse(Cond, VT); - SDValue CCVal = - DAG.getConstant(changeIntCCToAArch64CC(CondInv), DL, MVT::i32); + SDValue CCVal = getCondCode(DAG, changeIntCCToAArch64CC(CondInv)); // Inputs are swapped because the condition is inverted. This will allow // matching with a single CSINC instruction. return DAG.getNode(AArch64ISD::CSEL, DL, OpVT, FVal, TVal, CCVal, @@ -11577,13 +11581,13 @@ SDValue AArch64TargetLowering::LowerSELECT_CC( } // Emit first, and possibly only, CSEL. - SDValue CC1Val = DAG.getConstant(CC1, DL, MVT::i32); + SDValue CC1Val = getCondCode(DAG, CC1); SDValue CS1 = DAG.getNode(AArch64ISD::CSEL, DL, VT, TVal, FVal, CC1Val, Cmp); // If we need a second CSEL, emit it, using the output of the first as the // RHS. We're effectively OR'ing the two CC's together. if (CC2 != AArch64CC::AL) { - SDValue CC2Val = DAG.getConstant(CC2, DL, MVT::i32); + SDValue CC2Val = getCondCode(DAG, CC2); return DAG.getNode(AArch64ISD::CSEL, DL, VT, TVal, CS1, CC2Val, Cmp); } @@ -11685,7 +11689,7 @@ SDValue AArch64TargetLowering::LowerSELECT(SDValue Op, AArch64CC::CondCode OFCC; SDValue Value, Overflow; std::tie(Value, Overflow) = getAArch64XALUOOp(OFCC, CCVal.getValue(0), DAG); - SDValue CCVal = DAG.getConstant(OFCC, DL, MVT::i32); + SDValue CCVal = getCondCode(DAG, OFCC); return DAG.getNode(AArch64ISD::CSEL, DL, Op.getValueType(), TVal, FVal, CCVal, Overflow); @@ -12525,10 +12529,10 @@ static AArch64CC::CondCode parseConstraintCode(llvm::StringRef Constraint) { /// WZR, invert(<cond>)'. static SDValue getSETCC(AArch64CC::CondCode CC, SDValue NZCV, const SDLoc &DL, SelectionDAG &DAG) { - return DAG.getNode( - AArch64ISD::CSINC, DL, MVT::i32, DAG.getConstant(0, DL, MVT::i32), - DAG.getConstant(0, DL, MVT::i32), - DAG.getConstant(getInvertedCondCode(CC), DL, MVT::i32), NZCV); + return DAG.getNode(AArch64ISD::CSINC, DL, MVT::i32, + DAG.getConstant(0, DL, MVT::i32), + DAG.getConstant(0, DL, MVT::i32), + getCondCode(DAG, getInvertedCondCode(CC)), NZCV); } // Lower @cc flag output via getSETCC. 
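Aside, not part of the patch: getSETCC above materializes a boolean from NZCV with "CSINC Wd, WZR, WZR, invert(cond)". A rough standalone model (the csinc helper below is hypothetical, not LLVM code) of why that single instruction computes cond ? 1 : 0:

    // AArch64 CSINC semantics: Rd = cond ? Rn : Rm + 1.
    // With Rn = Rm = WZR (always 0) and the *inverted* condition, the
    // result is 1 exactly when the original condition holds.
    #include <cassert>
    #include <cstdint>

    static uint32_t csinc(bool Cond, uint32_t Rn, uint32_t Rm) {
      return Cond ? Rn : Rm + 1;
    }

    int main() {
      const uint32_t WZR = 0;
      for (bool OrigCond : {false, true}) {
        // CSINC Wd, WZR, WZR, invert(cond)  ==>  OrigCond ? 1 : 0
        uint32_t Wd = csinc(/*Cond=*/!OrigCond, WZR, WZR);
        assert(Wd == (OrigCond ? 1u : 0u));
      }
      return 0;
    }

This is also why several combines in this file invert the condition first: feeding the inverted code to CSINC/CSEL lets the boolean or select collapse into one instruction.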
@@ -18699,7 +18703,7 @@ AArch64TargetLowering::BuildSREMPow2(SDNode *N, const APInt &Divisor, Created.push_back(Cmp.getNode()); Created.push_back(And.getNode()); } else { - SDValue CCVal = DAG.getConstant(AArch64CC::MI, DL, MVT_CC); + SDValue CCVal = getCondCode(DAG, AArch64CC::MI); SDVTList VTs = DAG.getVTList(VT, FlagsVT); SDValue Negs = DAG.getNode(AArch64ISD::SUBS, DL, VTs, Zero, N0); @@ -19571,11 +19575,11 @@ static SDValue performANDORCSELCombine(SDNode *N, SelectionDAG &DAG) { if (N->getOpcode() == ISD::AND) { AArch64CC::CondCode InvCC0 = AArch64CC::getInvertedCondCode(CC0); - Condition = DAG.getConstant(InvCC0, DL, MVT_CC); + Condition = getCondCode(DAG, InvCC0); NZCV = AArch64CC::getNZCVToSatisfyCondCode(CC1); } else { AArch64CC::CondCode InvCC1 = AArch64CC::getInvertedCondCode(CC1); - Condition = DAG.getConstant(CC0, DL, MVT_CC); + Condition = getCondCode(DAG, CC0); NZCV = AArch64CC::getNZCVToSatisfyCondCode(InvCC1); } @@ -19596,8 +19600,7 @@ static SDValue performANDORCSELCombine(SDNode *N, SelectionDAG &DAG) { Cmp1.getOperand(1), NZCVOp, Condition, Cmp0); } return DAG.getNode(AArch64ISD::CSEL, DL, VT, CSel0.getOperand(0), - CSel0.getOperand(1), DAG.getConstant(CC1, DL, MVT::i32), - CCmp); + CSel0.getOperand(1), getCondCode(DAG, CC1), CCmp); } static SDValue performORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, @@ -19802,7 +19805,7 @@ static SDValue performANDSETCCCombine(SDNode *N, SDLoc DL(N); return DAG.getNode(AArch64ISD::CSINC, DL, VT, DAG.getConstant(0, DL, VT), DAG.getConstant(0, DL, VT), - DAG.getConstant(InvertedCC, DL, MVT::i32), Cmp); + getCondCode(DAG, InvertedCC), Cmp); } } return SDValue(); @@ -20793,7 +20796,7 @@ static SDValue performAddCSelIntoCSinc(SDNode *N, SelectionDAG &DAG) { "Unexpected constant value"); SDValue NewNode = DAG.getNode(ISD::ADD, DL, VT, RHS, SDValue(CTVal, 0)); - SDValue CCVal = DAG.getConstant(AArch64CC, DL, MVT::i32); + SDValue CCVal = getCondCode(DAG, AArch64CC); SDValue Cmp = LHS.getOperand(3); return DAG.getNode(AArch64ISD::CSINC, DL, VT, NewNode, RHS, CCVal, Cmp); @@ -20979,7 +20982,7 @@ static SDValue foldADCToCINC(SDNode *N, SelectionDAG &DAG) { SDLoc DL(N); // (CINC x cc cond) <=> (CSINC x x !cc cond) - SDValue CC = DAG.getConstant(AArch64CC::LO, DL, MVT::i32); + SDValue CC = getCondCode(DAG, AArch64CC::LO); return DAG.getNode(AArch64ISD::CSINC, DL, VT, LHS, LHS, CC, Cond); } @@ -22052,7 +22055,7 @@ static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op, // Convert CC to integer based on requested condition. // NOTE: Cond is inverted to promote CSEL's removal when it feeds a compare. 
- SDValue CC = DAG.getConstant(getInvertedCondCode(Cond), DL, MVT::i32); + SDValue CC = getCondCode(DAG, getInvertedCondCode(Cond)); SDValue Res = DAG.getNode(AArch64ISD::CSEL, DL, OutVT, FVal, TVal, CC, Test); return DAG.getZExtOrTrunc(Res, DL, VT); } @@ -25093,10 +25096,9 @@ static SDValue performBRCONDCombine(SDNode *N, auto CSelCC = getCSETCondCode(CSel); if (CSelCC) { SDLoc DL(N); - return DAG.getNode( - N->getOpcode(), DL, N->getVTList(), Chain, Dest, - DAG.getConstant(getInvertedCondCode(*CSelCC), DL, MVT::i32), - CSel.getOperand(3)); + return DAG.getNode(N->getOpcode(), DL, N->getVTList(), Chain, Dest, + getCondCode(DAG, getInvertedCondCode(*CSelCC)), + CSel.getOperand(3)); } } @@ -25237,7 +25239,7 @@ static SDValue foldCSELOfCSEL(SDNode *Op, SelectionDAG &DAG) { SDLoc DL(Op); EVT VT = Op->getValueType(0); - SDValue CCValue = DAG.getConstant(CC, DL, MVT::i32); + SDValue CCValue = getCondCode(DAG, CC); return DAG.getNode(AArch64ISD::CSEL, DL, VT, L, R, CCValue, Cond); } @@ -25314,8 +25316,7 @@ static SDValue reassociateCSELOperandsForCSE(SDNode *N, SelectionDAG &DAG) { SDValue TValReassoc = Reassociate(TReassocOp, 0); SDValue FValReassoc = Reassociate(FReassocOp, 1); return DAG.getNode(AArch64ISD::CSEL, SDLoc(N), VT, TValReassoc, FValReassoc, - DAG.getConstant(NewCC, SDLoc(N->getOperand(2)), MVT_CC), - NewCmp.getValue(1)); + getCondCode(DAG, NewCC), NewCmp.getValue(1)); }; auto CC = static_cast<AArch64CC::CondCode>(N->getConstantOperandVal(2)); @@ -25456,8 +25457,7 @@ static SDValue performCSELCombine(SDNode *N, SDValue Sub = DAG.getNode(AArch64ISD::SUBS, DL, Cond->getVTList(), Cond.getOperand(1), Cond.getOperand(0)); return DAG.getNode(AArch64ISD::CSEL, DL, N->getVTList(), N->getOperand(0), - N->getOperand(1), - DAG.getConstant(NewCond, DL, MVT::i32), + N->getOperand(1), getCondCode(DAG, NewCond), Sub.getValue(1)); } } @@ -25557,10 +25557,9 @@ static SDValue performSETCCCombine(SDNode *N, auto NewCond = getInvertedCondCode(OldCond); // csel 0, 1, !cond, X - SDValue CSEL = - DAG.getNode(AArch64ISD::CSEL, DL, LHS.getValueType(), LHS.getOperand(0), - LHS.getOperand(1), DAG.getConstant(NewCond, DL, MVT::i32), - LHS.getOperand(3)); + SDValue CSEL = DAG.getNode(AArch64ISD::CSEL, DL, LHS.getValueType(), + LHS.getOperand(0), LHS.getOperand(1), + getCondCode(DAG, NewCond), LHS.getOperand(3)); return DAG.getZExtOrTrunc(CSEL, DL, VT); } @@ -25630,8 +25629,7 @@ static SDValue performFlagSettingCombine(SDNode *N, // If the flag result isn't used, convert back to a generic opcode. if (!N->hasAnyUseOfValue(1)) { SDValue Res = DCI.DAG.getNode(GenericOpcode, DL, VT, N->ops()); - return DCI.DAG.getMergeValues({Res, DCI.DAG.getConstant(0, DL, MVT::i32)}, - DL); + return DCI.CombineTo(N, Res, SDValue(N, 1)); } // Combine identical generic nodes into this node, re-using the result. 
@@ -27013,10 +27011,10 @@ static SDValue performRNDRCombine(SDNode *N, SelectionDAG &DAG) { SDValue A = DAG.getNode( AArch64ISD::MRS, DL, DAG.getVTList(MVT::i64, FlagsVT, MVT::Other), N->getOperand(0), DAG.getConstant(Register, DL, MVT::i32)); - SDValue B = DAG.getNode( - AArch64ISD::CSINC, DL, MVT::i32, DAG.getConstant(0, DL, MVT::i32), - DAG.getConstant(0, DL, MVT::i32), - DAG.getConstant(AArch64CC::NE, DL, MVT::i32), A.getValue(1)); + SDValue B = DAG.getNode(AArch64ISD::CSINC, DL, MVT::i32, + DAG.getConstant(0, DL, MVT::i32), + DAG.getConstant(0, DL, MVT::i32), + getCondCode(DAG, AArch64CC::NE), A.getValue(1)); return DAG.getMergeValues( {A, DAG.getZExtOrTrunc(B, DL, MVT::i1), A.getValue(2)}, DL); } diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td index 251fd44..ac31236 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td @@ -448,8 +448,13 @@ def SDTBinaryArithWithFlagsInOut : SDTypeProfile<2, 3, SDTCisVT<1, FlagsVT>, SDTCisVT<4, FlagsVT>]>; +// Value type used for condition codes. +// Should be kept in sync with its C++ counterpart. +defvar CondCodeVT = i32; + def SDT_AArch64Brcond : SDTypeProfile<0, 3, - [SDTCisVT<0, OtherVT>, SDTCisVT<1, i32>, + [SDTCisVT<0, OtherVT>, + SDTCisVT<1, CondCodeVT>, SDTCisVT<2, FlagsVT>]>; def SDT_AArch64cbz : SDTypeProfile<0, 2, [SDTCisInt<0>, SDTCisVT<1, OtherVT>]>; def SDT_AArch64tbz : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisInt<1>, @@ -458,22 +463,22 @@ def SDT_AArch64tbz : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisInt<1>, def SDT_AArch64CSel : SDTypeProfile<1, 4, [SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, - SDTCisInt<3>, + SDTCisVT<3, CondCodeVT>, SDTCisVT<4, FlagsVT>]>; def SDT_AArch64CCMP : SDTypeProfile<1, 5, [SDTCisVT<0, FlagsVT>, SDTCisInt<1>, SDTCisSameAs<1, 2>, SDTCisInt<3>, - SDTCisInt<4>, - SDTCisVT<5, i32>]>; + SDTCisVT<4, CondCodeVT>, + SDTCisVT<5, FlagsVT>]>; def SDT_AArch64FCCMP : SDTypeProfile<1, 5, [SDTCisVT<0, FlagsVT>, SDTCisFP<1>, SDTCisSameAs<1, 2>, SDTCisInt<3>, - SDTCisInt<4>, - SDTCisVT<5, i32>]>; + SDTCisVT<4, CondCodeVT>, + SDTCisVT<5, FlagsVT>]>; def SDT_AArch64FCmp : SDTypeProfile<1, 2, [SDTCisVT<0, FlagsVT>, SDTCisFP<1>, SDTCisSameAs<2, 1>]>; @@ -546,7 +551,8 @@ def SDT_AArch64TBL : SDTypeProfile<1, 2, [ ]>; def SDT_AArch64cb : SDTypeProfile<0, 4, - [SDTCisVT<0, i32>, SDTCisInt<1>, SDTCisInt<2>, + [SDTCisVT<0, CondCodeVT>, + SDTCisInt<1>, SDTCisInt<2>, SDTCisVT<3, OtherVT>]>; // non-extending masked load fragment. diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index 18ca22f..e1adc0b 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -270,6 +270,13 @@ bool AArch64TTIImpl::areInlineCompatible(const Function *Caller, const Function *Callee) const { SMECallAttrs CallAttrs(*Caller, *Callee); + // Never inline a function explicitly marked as being streaming, + // into a non-streaming function. Assume it was marked as streaming + // for a reason. + if (CallAttrs.caller().hasNonStreamingInterfaceAndBody() && + CallAttrs.callee().hasStreamingInterfaceOrBody()) + return false; + // When inlining, we should consider the body of the function, not the // interface. 
if (CallAttrs.callee().hasStreamingBody()) { diff --git a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64AsmBackend.cpp b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64AsmBackend.cpp index b9d3e1b..6912caf 100644 --- a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64AsmBackend.cpp +++ b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64AsmBackend.cpp @@ -461,7 +461,7 @@ void AArch64AsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, Value <<= Info.TargetOffset; unsigned Offset = Fixup.getOffset(); - assert(Offset + NumBytes <= Data.size() && "Invalid fixup offset!"); + assert(Offset + NumBytes <= F.getSize() && "Invalid fixup offset!"); // Used to point to big endian bytes. unsigned FulleSizeInBytes = getFixupKindContainereSizeInBytes(Fixup.getKind()); diff --git a/llvm/lib/Target/AMDGPU/AMDGPUHSAMetadataStreamer.cpp b/llvm/lib/Target/AMDGPU/AMDGPUHSAMetadataStreamer.cpp index 2991778..19b8757 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUHSAMetadataStreamer.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUHSAMetadataStreamer.cpp @@ -204,7 +204,7 @@ MetadataStreamerMsgPackV4::getWorkGroupDimensions(MDNode *Node) const { for (auto &Op : Node->operands()) Dims.push_back(Dims.getDocument()->getNode( - uint64_t(mdconst::extract<ConstantInt>(Op)->getZExtValue()))); + mdconst::extract<ConstantInt>(Op)->getZExtValue())); return Dims; } diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp index 6bca2fe..c8e45d4 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp @@ -4574,6 +4574,7 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { case Intrinsic::amdgcn_cvt_pknorm_u16: case Intrinsic::amdgcn_cvt_pk_i16: case Intrinsic::amdgcn_cvt_pk_u16: + case Intrinsic::amdgcn_cvt_sr_pk_f16_f32: case Intrinsic::amdgcn_cvt_sr_pk_bf16_f32: case Intrinsic::amdgcn_cvt_pk_f16_fp8: case Intrinsic::amdgcn_cvt_pk_f16_bf8: @@ -4581,6 +4582,15 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { case Intrinsic::amdgcn_cvt_pk_bf8_f16: case Intrinsic::amdgcn_cvt_sr_fp8_f16: case Intrinsic::amdgcn_cvt_sr_bf8_f16: + case Intrinsic::amdgcn_cvt_scale_pk8_f16_fp8: + case Intrinsic::amdgcn_cvt_scale_pk8_bf16_fp8: + case Intrinsic::amdgcn_cvt_scale_pk8_f16_bf8: + case Intrinsic::amdgcn_cvt_scale_pk8_bf16_bf8: + case Intrinsic::amdgcn_cvt_scale_pk8_f16_fp4: + case Intrinsic::amdgcn_cvt_scale_pk8_bf16_fp4: + case Intrinsic::amdgcn_cvt_scale_pk8_f32_fp8: + case Intrinsic::amdgcn_cvt_scale_pk8_f32_bf8: + case Intrinsic::amdgcn_cvt_scale_pk8_f32_fp4: case Intrinsic::amdgcn_sat_pk4_i4_i8: case Intrinsic::amdgcn_sat_pk4_u4_u8: case Intrinsic::amdgcn_fmed3: @@ -4632,8 +4642,10 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { case Intrinsic::amdgcn_cvt_pk_f32_fp8: case Intrinsic::amdgcn_cvt_pk_f32_bf8: case Intrinsic::amdgcn_cvt_pk_fp8_f32: + case Intrinsic::amdgcn_cvt_pk_fp8_f32_e5m3: case Intrinsic::amdgcn_cvt_pk_bf8_f32: case Intrinsic::amdgcn_cvt_sr_fp8_f32: + case Intrinsic::amdgcn_cvt_sr_fp8_f32_e5m3: case Intrinsic::amdgcn_cvt_sr_bf8_f32: case Intrinsic::amdgcn_cvt_sr_bf16_f32: case Intrinsic::amdgcn_cvt_sr_f16_f32: diff --git a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp index a4ea8cf..a83caa0 100644 --- a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp +++ b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp @@ -180,6 +180,7 @@ public: ImmTyMatrixBFMT, ImmTyMatrixAReuse, 
ImmTyMatrixBReuse, + ImmTyScaleSel, ImmTyByteSel, }; @@ -1184,6 +1185,7 @@ public: case ImmTyMatrixBFMT: OS << "ImmTyMatrixBFMT"; break; case ImmTyMatrixAReuse: OS << "ImmTyMatrixAReuse"; break; case ImmTyMatrixBReuse: OS << "ImmTyMatrixBReuse"; break; + case ImmTyScaleSel: OS << "ScaleSel" ; break; case ImmTyByteSel: OS << "ByteSel" ; break; } // clang-format on @@ -9366,6 +9368,14 @@ void AMDGPUAsmParser::cvtVOP3(MCInst &Inst, const OperandVector &Operands, } } + if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::scale_sel)) + addOptionalImmOperand(Inst, Operands, OptionalIdx, + AMDGPUOperand::ImmTyScaleSel); + + if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::clamp)) + addOptionalImmOperand(Inst, Operands, OptionalIdx, + AMDGPUOperand::ImmTyClamp); + if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::byte_sel)) { if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::vdst_in)) Inst.addOperand(Inst.getOperand(0)); @@ -9373,10 +9383,6 @@ void AMDGPUAsmParser::cvtVOP3(MCInst &Inst, const OperandVector &Operands, AMDGPUOperand::ImmTyByteSel); } - if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::clamp)) - addOptionalImmOperand(Inst, Operands, OptionalIdx, - AMDGPUOperand::ImmTyClamp); - if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::omod)) addOptionalImmOperand(Inst, Operands, OptionalIdx, AMDGPUOperand::ImmTyOModSI); @@ -9430,6 +9436,8 @@ void AMDGPUAsmParser::cvtVOP3P(MCInst &Inst, const OperandVector &Operands, Opc == AMDGPU::V_CVT_PK_FP8_F32_fake16_e64_dpp8_gfx12 || Opc == AMDGPU::V_CVT_SR_FP8_F32_gfx12_e64_dpp_gfx12 || Opc == AMDGPU::V_CVT_SR_FP8_F32_gfx12_e64_dpp8_gfx12 || + Opc == AMDGPU::V_CVT_SR_FP8_F32_gfx1250_e64_dpp_gfx1250 || + Opc == AMDGPU::V_CVT_SR_FP8_F32_gfx1250_e64_dpp8_gfx1250 || Opc == AMDGPU::V_CVT_SR_BF8_F32_gfx12_e64_dpp_gfx12 || Opc == AMDGPU::V_CVT_SR_BF8_F32_gfx12_e64_dpp8_gfx12 || Opc == AMDGPU::V_CVT_SR_FP8_F16_t16_e64_dpp_gfx1250 || @@ -10038,9 +10046,12 @@ void AMDGPUAsmParser::cvtVOP3DPP(MCInst &Inst, const OperandVector &Operands, addOptionalImmOperand(Inst, Operands, OptionalIdx, AMDGPUOperand::ImmTyClamp); - if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::byte_sel)) + if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::byte_sel)) { + if (VdstInIdx == static_cast<int>(Inst.getNumOperands())) + Inst.addOperand(Inst.getOperand(0)); addOptionalImmOperand(Inst, Operands, OptionalIdx, AMDGPUOperand::ImmTyByteSel); + } if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::omod)) addOptionalImmOperand(Inst, Operands, OptionalIdx, AMDGPUOperand::ImmTyOModSI); diff --git a/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp b/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp index ce1ce68..96d5668 100644 --- a/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp +++ b/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp @@ -592,10 +592,13 @@ bool GCNMaxILPSchedStrategy::tryCandidate(SchedCandidate &Cand, // This is a best effort to set things up for a post-RA pass. Optimizations // like generating loads of multiple registers should ideally be done within // the scheduler pass by combining the loads during DAG postprocessing. - const ClusterInfo *CandCluster = Cand.AtTop ? TopCluster : BotCluster; - const ClusterInfo *TryCandCluster = TryCand.AtTop ? TopCluster : BotCluster; - if (tryGreater(TryCandCluster && TryCandCluster->contains(TryCand.SU), - CandCluster && CandCluster->contains(Cand.SU), TryCand, Cand, + unsigned CandZoneCluster = Cand.AtTop ? TopClusterID : BotClusterID; + unsigned TryCandZoneCluster = TryCand.AtTop ? 
TopClusterID : BotClusterID; + bool CandIsClusterSucc = + isTheSameCluster(CandZoneCluster, Cand.SU->ParentClusterIdx); + bool TryCandIsClusterSucc = + isTheSameCluster(TryCandZoneCluster, TryCand.SU->ParentClusterIdx); + if (tryGreater(TryCandIsClusterSucc, CandIsClusterSucc, TryCand, Cand, Cluster)) return TryCand.Reason != NoCand; @@ -666,10 +669,13 @@ bool GCNMaxMemoryClauseSchedStrategy::tryCandidate(SchedCandidate &Cand, // MaxMemoryClause-specific: We prioritize clustered instructions as we would // get more benefit from clausing these memory instructions. - const ClusterInfo *CandCluster = Cand.AtTop ? TopCluster : BotCluster; - const ClusterInfo *TryCandCluster = TryCand.AtTop ? TopCluster : BotCluster; - if (tryGreater(TryCandCluster && TryCandCluster->contains(TryCand.SU), - CandCluster && CandCluster->contains(Cand.SU), TryCand, Cand, + unsigned CandZoneCluster = Cand.AtTop ? TopClusterID : BotClusterID; + unsigned TryCandZoneCluster = TryCand.AtTop ? TopClusterID : BotClusterID; + bool CandIsClusterSucc = + isTheSameCluster(CandZoneCluster, Cand.SU->ParentClusterIdx); + bool TryCandIsClusterSucc = + isTheSameCluster(TryCandZoneCluster, TryCand.SU->ParentClusterIdx); + if (tryGreater(TryCandIsClusterSucc, CandIsClusterSucc, TryCand, Cand, Cluster)) return TryCand.Reason != NoCand; @@ -896,15 +902,10 @@ GCNScheduleDAGMILive::getRegionLiveInMap() const { assert(!Regions.empty()); std::vector<MachineInstr *> RegionFirstMIs; RegionFirstMIs.reserve(Regions.size()); - auto I = Regions.rbegin(), E = Regions.rend(); - do { - const MachineBasicBlock *MBB = I->first->getParent(); - auto *MI = &*skipDebugInstructionsForward(I->first, I->second); - RegionFirstMIs.push_back(MI); - do { - ++I; - } while (I != E && I->first->getParent() == MBB); - } while (I != E); + for (auto &[RegionBegin, RegionEnd] : reverse(Regions)) + RegionFirstMIs.push_back( + &*skipDebugInstructionsForward(RegionBegin, RegionEnd)); + return getLiveRegMap(RegionFirstMIs, /*After=*/false, *LIS); } @@ -941,11 +942,9 @@ void GCNScheduleDAGMILive::finalizeSchedule() { Pressure.resize(Regions.size()); RegionsWithHighRP.resize(Regions.size()); RegionsWithExcessRP.resize(Regions.size()); - RegionsWithMinOcc.resize(Regions.size()); RegionsWithIGLPInstrs.resize(Regions.size()); RegionsWithHighRP.reset(); RegionsWithExcessRP.reset(); - RegionsWithMinOcc.reset(); RegionsWithIGLPInstrs.reset(); runSchedStages(); @@ -1095,8 +1094,7 @@ bool PreRARematStage::initGCNSchedStage() { // fixed if there is another pass after this pass. assert(!S.hasNextStage()); - if (!GCNSchedStage::initGCNSchedStage() || DAG.RegionsWithMinOcc.none() || - DAG.Regions.size() == 1) + if (!GCNSchedStage::initGCNSchedStage() || DAG.Regions.size() == 1) return false; // Before performing any IR modification record the parent region of each MI @@ -1138,11 +1136,6 @@ void UnclusteredHighRPStage::finalizeGCNSchedStage() { SavedMutations.swap(DAG.Mutations); S.SGPRLimitBias = S.VGPRLimitBias = 0; if (DAG.MinOccupancy > InitialOccupancy) { - for (unsigned IDX = 0; IDX < DAG.Pressure.size(); ++IDX) - DAG.RegionsWithMinOcc[IDX] = - DAG.Pressure[IDX].getOccupancy( - DAG.ST, DAG.MFI.getDynamicVGPRBlockSize()) == DAG.MinOccupancy; - LLVM_DEBUG(dbgs() << StageID << " stage successfully increased occupancy to " << DAG.MinOccupancy << '\n'); @@ -1214,11 +1207,15 @@ bool GCNSchedStage::initGCNRegion() { } bool UnclusteredHighRPStage::initGCNRegion() { - // Only reschedule regions with the minimum occupancy or regions that may have - // spilling (excess register pressure). 
- if ((!DAG.RegionsWithMinOcc[RegionIdx] || - DAG.MinOccupancy <= InitialOccupancy) && - !DAG.RegionsWithExcessRP[RegionIdx]) + // Only reschedule regions that have excess register pressure (i.e. spilling) + // or had minimum occupancy at the beginning of the stage (as long as + // rescheduling of previous regions did not make occupancy drop back down to + // the initial minimum). + unsigned DynamicVGPRBlockSize = DAG.MFI.getDynamicVGPRBlockSize(); + if (!DAG.RegionsWithExcessRP[RegionIdx] && + (DAG.MinOccupancy <= InitialOccupancy || + DAG.Pressure[RegionIdx].getOccupancy(ST, DynamicVGPRBlockSize) != + InitialOccupancy)) return false; return GCNSchedStage::initGCNRegion(); @@ -1283,9 +1280,6 @@ void GCNSchedStage::checkScheduling() { if (PressureAfter.getSGPRNum() <= S.SGPRCriticalLimit && PressureAfter.getVGPRNum(ST.hasGFX90AInsts()) <= S.VGPRCriticalLimit) { DAG.Pressure[RegionIdx] = PressureAfter; - DAG.RegionsWithMinOcc[RegionIdx] = - PressureAfter.getOccupancy(ST, DynamicVGPRBlockSize) == - DAG.MinOccupancy; // Early out if we have achieved the occupancy target. LLVM_DEBUG(dbgs() << "Pressure in desired limits, done.\n"); @@ -1319,7 +1313,6 @@ void GCNSchedStage::checkScheduling() { if (NewOccupancy < DAG.MinOccupancy) { DAG.MinOccupancy = NewOccupancy; MFI.limitOccupancy(DAG.MinOccupancy); - DAG.RegionsWithMinOcc.reset(); LLVM_DEBUG(dbgs() << "Occupancy lowered for the function to " << DAG.MinOccupancy << ".\n"); } @@ -1341,14 +1334,10 @@ void GCNSchedStage::checkScheduling() { // Revert if this region's schedule would cause a drop in occupancy or // spilling. - if (shouldRevertScheduling(WavesAfter)) { + if (shouldRevertScheduling(WavesAfter)) revertScheduling(); - } else { + else DAG.Pressure[RegionIdx] = PressureAfter; - DAG.RegionsWithMinOcc[RegionIdx] = - PressureAfter.getOccupancy(ST, DynamicVGPRBlockSize) == - DAG.MinOccupancy; - } } unsigned @@ -1578,9 +1567,6 @@ bool GCNSchedStage::mayCauseSpilling(unsigned WavesAfter) { } void GCNSchedStage::revertScheduling() { - DAG.RegionsWithMinOcc[RegionIdx] = - PressureBefore.getOccupancy(ST, DAG.MFI.getDynamicVGPRBlockSize()) == - DAG.MinOccupancy; LLVM_DEBUG(dbgs() << "Attempting to revert scheduling.\n"); DAG.RegionEnd = DAG.RegionBegin; int SkippedDebugInstr = 0; diff --git a/llvm/lib/Target/AMDGPU/GCNSchedStrategy.h b/llvm/lib/Target/AMDGPU/GCNSchedStrategy.h index 94cd795..32139a9 100644 --- a/llvm/lib/Target/AMDGPU/GCNSchedStrategy.h +++ b/llvm/lib/Target/AMDGPU/GCNSchedStrategy.h @@ -250,9 +250,6 @@ class GCNScheduleDAGMILive final : public ScheduleDAGMILive { // limit. Register pressure in these regions usually will result in spilling. BitVector RegionsWithExcessRP; - // Regions that has the same occupancy as the latest MinOccupancy - BitVector RegionsWithMinOcc; - // Regions that have IGLP instructions (SCHED_GROUP_BARRIER or IGLP_OPT). 
BitVector RegionsWithIGLPInstrs; diff --git a/llvm/lib/Target/AMDGPU/GCNSubtarget.cpp b/llvm/lib/Target/AMDGPU/GCNSubtarget.cpp index 0a0a107..0237a60 100644 --- a/llvm/lib/Target/AMDGPU/GCNSubtarget.cpp +++ b/llvm/lib/Target/AMDGPU/GCNSubtarget.cpp @@ -340,6 +340,43 @@ void GCNSubtarget::overrideSchedPolicy(MachineSchedPolicy &Policy, Policy.ShouldTrackLaneMasks = true; } +void GCNSubtarget::overridePostRASchedPolicy(MachineSchedPolicy &Policy, + const SchedRegion &Region) const { + const Function &F = Region.RegionBegin->getMF()->getFunction(); + Attribute PostRADirectionAttr = F.getFnAttribute("amdgpu-post-ra-direction"); + if (!PostRADirectionAttr.isValid()) + return; + + StringRef PostRADirectionStr = PostRADirectionAttr.getValueAsString(); + if (PostRADirectionStr == "topdown") { + Policy.OnlyTopDown = true; + Policy.OnlyBottomUp = false; + } else if (PostRADirectionStr == "bottomup") { + Policy.OnlyTopDown = false; + Policy.OnlyBottomUp = true; + } else if (PostRADirectionStr == "bidirectional") { + Policy.OnlyTopDown = false; + Policy.OnlyBottomUp = false; + } else { + DiagnosticInfoOptimizationFailure Diag( + F, F.getSubprogram(), "invalid value for postRA direction attribute"); + F.getContext().diagnose(Diag); + } + + LLVM_DEBUG({ + const char *DirStr = "default"; + if (Policy.OnlyTopDown && !Policy.OnlyBottomUp) + DirStr = "topdown"; + else if (!Policy.OnlyTopDown && Policy.OnlyBottomUp) + DirStr = "bottomup"; + else if (!Policy.OnlyTopDown && !Policy.OnlyBottomUp) + DirStr = "bidirectional"; + + dbgs() << "Post-MI-sched direction (" << F.getName() << "): " << DirStr + << '\n'; + }); +} + void GCNSubtarget::mirFileLoaded(MachineFunction &MF) const { if (isWave32()) { // Fix implicit $vcc operands after MIParser has verified that they match diff --git a/llvm/lib/Target/AMDGPU/GCNSubtarget.h b/llvm/lib/Target/AMDGPU/GCNSubtarget.h index bdd900d..6fe3abc 100644 --- a/llvm/lib/Target/AMDGPU/GCNSubtarget.h +++ b/llvm/lib/Target/AMDGPU/GCNSubtarget.h @@ -1041,6 +1041,9 @@ public: void overrideSchedPolicy(MachineSchedPolicy &Policy, const SchedRegion &Region) const override; + void overridePostRASchedPolicy(MachineSchedPolicy &Policy, + const SchedRegion &Region) const override; + void mirFileLoaded(MachineFunction &MF) const override; unsigned getMaxNumUserSGPRs() const { diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUAsmBackend.cpp b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUAsmBackend.cpp index 2a920f6..86d56855 100644 --- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUAsmBackend.cpp +++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUAsmBackend.cpp @@ -149,7 +149,7 @@ void AMDGPUAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, unsigned NumBytes = getFixupKindNumBytes(Fixup.getKind()); uint32_t Offset = Fixup.getOffset(); - assert(Offset + NumBytes <= Data.size() && "Invalid fixup offset!"); + assert(Offset + NumBytes <= F.getSize() && "Invalid fixup offset!"); // For each byte of the fragment that the fixup touches, mask in the bits from // the fixup value. 
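For reference, the applyFixup hunks above (AArch64, AMDGPU, and the other backends further down) all change the bounds assert to F.getSize() ahead of the same byte-masking loop their trailing comment describes. A minimal C++ sketch of that idiom, not part of the patch, assuming a little-endian fixup value and an Offset/NumBytes pair that the assert has already validated (all names here are illustrative):

#include <cstdint>

// Mask each byte of the (shifted, masked) fixup value into the fragment data
// that the fixup touches; existing bits outside the fixup are left untouched.
static void maskInFixup(uint8_t *Data, unsigned Offset, unsigned NumBytes,
                        uint64_t Value) {
  for (unsigned I = 0; I != NumBytes; ++I)
    Data[Offset + I] |= uint8_t((Value >> (I * 8)) & 0xff);
}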
diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp index 15088ac..42c4d8b 100644 --- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp +++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp @@ -1793,4 +1793,14 @@ void AMDGPUInstPrinter::printBitOp3(const MCInst *MI, unsigned OpNo, O << formatHex(static_cast<uint64_t>(Imm)); } +void AMDGPUInstPrinter::printScaleSel(const MCInst *MI, unsigned OpNo, + const MCSubtargetInfo &STI, + raw_ostream &O) { + uint8_t Imm = MI->getOperand(OpNo).getImm(); + if (!Imm) + return; + + O << " scale_sel:" << formatDec(Imm); +} + #include "AMDGPUGenAsmWriter.inc" diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h index e0b7aa5..f6739b14 100644 --- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h +++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h @@ -173,6 +173,8 @@ private: const MCSubtargetInfo &STI, raw_ostream &O, StringRef Prefix, bool PrintInHex, bool AlwaysPrint); + void printScaleSel(const MCInst *MI, unsigned OpNo, + const MCSubtargetInfo &STI, raw_ostream &O); void printBitOp3(const MCInst *MI, unsigned OpNo, const MCSubtargetInfo &STI, raw_ostream &O); diff --git a/llvm/lib/Target/AMDGPU/SIFrameLowering.cpp b/llvm/lib/Target/AMDGPU/SIFrameLowering.cpp index 11552b3..9b348d4 100644 --- a/llvm/lib/Target/AMDGPU/SIFrameLowering.cpp +++ b/llvm/lib/Target/AMDGPU/SIFrameLowering.cpp @@ -983,6 +983,7 @@ void SIFrameLowering::emitCSRSpillStores( const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>(); const SIInstrInfo *TII = ST.getInstrInfo(); const SIRegisterInfo &TRI = TII->getRegisterInfo(); + MachineRegisterInfo &MRI = MF.getRegInfo(); // Spill Whole-Wave Mode VGPRs. Save only the inactive lanes of the scratch // registers. However, save all lanes of callee-saved VGPRs. Due to this, we @@ -1005,6 +1006,12 @@ void SIFrameLowering::emitCSRSpillStores( } }; + for (const Register Reg : make_first_range(WWMScratchRegs)) { + if (!MRI.isReserved(Reg)) { + MRI.addLiveIn(Reg); + MBB.addLiveIn(Reg); + } + } StoreWWMRegisters(WWMScratchRegs); auto EnableAllLanes = [&]() { diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp index ad26757..4d67e4a 100644 --- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp +++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp @@ -16825,56 +16825,51 @@ SITargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI_, return std::pair(0U, RC); } - if (Constraint.starts_with("{") && Constraint.ends_with("}")) { - StringRef RegName(Constraint.data() + 1, Constraint.size() - 2); - if (RegName.consume_front("v")) { + auto [Kind, Idx, NumRegs] = AMDGPU::parseAsmConstraintPhysReg(Constraint); + if (Kind != '\0') { + if (Kind == 'v') { RC = &AMDGPU::VGPR_32RegClass; - } else if (RegName.consume_front("s")) { + } else if (Kind == 's') { RC = &AMDGPU::SGPR_32RegClass; - } else if (RegName.consume_front("a")) { + } else if (Kind == 'a') { RC = &AMDGPU::AGPR_32RegClass; } if (RC) { - uint32_t Idx; - if (RegName.consume_front("[")) { - uint32_t End; - bool Failed = RegName.consumeInteger(10, Idx); - Failed |= !RegName.consume_front(":"); - Failed |= RegName.consumeInteger(10, End); - Failed |= !RegName.consume_back("]"); - if (!Failed) { - uint32_t Width = (End - Idx + 1) * 32; - // Prohibit constraints for register ranges with a width that does not - // match the required type. 
- if (VT.SimpleTy != MVT::Other && Width != VT.getSizeInBits()) + if (NumRegs > 1) { + if (Idx >= RC->getNumRegs() || Idx + NumRegs - 1 > RC->getNumRegs()) + return std::pair(0U, nullptr); + + uint32_t Width = NumRegs * 32; + // Prohibit constraints for register ranges with a width that does not + // match the required type. + if (VT.SimpleTy != MVT::Other && Width != VT.getSizeInBits()) + return std::pair(0U, nullptr); + + MCRegister Reg = RC->getRegister(Idx); + if (SIRegisterInfo::isVGPRClass(RC)) + RC = TRI->getVGPRClassForBitWidth(Width); + else if (SIRegisterInfo::isSGPRClass(RC)) + RC = TRI->getSGPRClassForBitWidth(Width); + else if (SIRegisterInfo::isAGPRClass(RC)) + RC = TRI->getAGPRClassForBitWidth(Width); + if (RC) { + Reg = TRI->getMatchingSuperReg(Reg, AMDGPU::sub0, RC); + if (!Reg) { + // The register class does not contain the requested register, + // e.g., because it is an SGPR pair that would violate alignment + // requirements. return std::pair(0U, nullptr); - MCRegister Reg = RC->getRegister(Idx); - if (SIRegisterInfo::isVGPRClass(RC)) - RC = TRI->getVGPRClassForBitWidth(Width); - else if (SIRegisterInfo::isSGPRClass(RC)) - RC = TRI->getSGPRClassForBitWidth(Width); - else if (SIRegisterInfo::isAGPRClass(RC)) - RC = TRI->getAGPRClassForBitWidth(Width); - if (RC) { - Reg = TRI->getMatchingSuperReg(Reg, AMDGPU::sub0, RC); - if (!Reg) { - // The register class does not contain the requested register, - // e.g., because it is an SGPR pair that would violate alignment - // requirements. - return std::pair(0U, nullptr); - } - return std::pair(Reg, RC); } + return std::pair(Reg, RC); } - } else { - // Check for lossy scalar/vector conversions. - if (VT.isVector() && VT.getSizeInBits() != 32) - return std::pair(0U, nullptr); - bool Failed = RegName.getAsInteger(10, Idx); - if (!Failed && Idx < RC->getNumRegs()) - return std::pair(RC->getRegister(Idx), RC); } + + // Check for lossy scalar/vector conversions. 
+ if (VT.isVector() && VT.getSizeInBits() != 32) + return std::pair(0U, nullptr); + if (Idx < RC->getNumRegs()) + return std::pair(RC->getRegister(Idx), RC); } } diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.td b/llvm/lib/Target/AMDGPU/SIInstrInfo.td index efcc88e..a3e20ba 100644 --- a/llvm/lib/Target/AMDGPU/SIInstrInfo.td +++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.td @@ -1313,6 +1313,10 @@ def MatrixBFMT : CustomOperand<i32, 1, "MatrixBFMT">; def MatrixAReuse : NamedBitOperand<"matrix_a_reuse">; def MatrixBReuse : NamedBitOperand<"matrix_b_reuse">; +def ScaleSel : NamedIntOperand<"scale_sel"> { + let Validator = "isUInt<3>"; +} + class KImmFPOperand<ValueType vt> : ImmOperand<vt> { let OperandNamespace = "AMDGPU"; let OperandType = "OPERAND_KIMM"#vt.Size; @@ -2928,6 +2932,7 @@ def VOP_V32F32_V6I32_F32 : VOPProfile <[v32f32, v6i32, f32, untyped]>; def VOP_V32F16_V6I32_F32 : VOPProfile <[v32f16, v6i32, f32, untyped]>; def VOP_V32BF16_V6I32_F32 : VOPProfile <[v32bf16, v6i32, f32, untyped]>; def VOP_V2BF16_F32_F32_I32 : VOPProfile <[v2bf16, f32, f32, i32]>; +def VOP_V2F16_F32_F32_I32 : VOPProfile <[v2f16, f32, f32, i32]>; def VOP_V6I32_V32F16_F32 : VOPProfile<[v6i32, v32f16, f32, untyped]>; def VOP_V6I32_V32BF16_F32 : VOPProfile<[v6i32, v32bf16, f32, untyped]>; def VOP_V6I32_V16F32_V16F32_F32 : VOPProfile<[v6i32, v16f32, v16f32, f32]>; @@ -2943,6 +2948,13 @@ def VOP_BF16_F32_I32 : VOPProfile<[bf16, f32, i32, untyped]>; def VOP_F16_F32_I32 : VOPProfile<[f16, f32, i32, untyped]>; def VOP_I32_BF16_I32_F32 : VOPProfile<[i32, bf16, i32, f32]>; def VOP_I32_F16_I32_F32 : VOPProfile<[i32, f16, i32, f32]>; +def VOP_V8F16_V2I32_I32 : VOPProfile<[v8f16, v2i32, i32, untyped]>; +def VOP_V8BF16_V2I32_I32 : VOPProfile<[v8bf16, v2i32, i32, untyped]>; +def VOP_V8F16_I32_I32 : VOPProfile<[v8f16, i32, i32, untyped]>; +def VOP_V8BF16_I32_I32 : VOPProfile<[v8bf16, i32, i32, untyped]>; +def VOP_V16F32_V3I32_I32 : VOPProfile<[v16f32, v3i32, i32, untyped]>; +def VOP_V8F32_V2I32_I32 : VOPProfile<[v8f32, v2i32, i32, untyped]>; +def VOP_V8F32_I32_I32 : VOPProfile<[v8f32, i32, i32, untyped]>; def VOP_I32_F32_I32_F32 : VOPProfile<[i32, f32, i32, f32]>; def VOP_V6I32_V32BF16_I32_F32 : VOPProfile<[v6i32, v32bf16, i32, f32]>; diff --git a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp index 5827f18..65fa088 100644 --- a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp +++ b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp @@ -1548,6 +1548,42 @@ bool shouldEmitConstantsToTextSection(const Triple &TT) { return TT.getArch() == Triple::r600; } +static bool isValidRegPrefix(char C) { + return C == 'v' || C == 's' || C == 'a'; +} + +std::tuple<char, unsigned, unsigned> +parseAsmConstraintPhysReg(StringRef Constraint) { + StringRef RegName = Constraint; + if (!RegName.consume_front("{") || !RegName.consume_back("}")) + return {}; + + char Kind = RegName.front(); + if (!isValidRegPrefix(Kind)) + return {}; + + RegName = RegName.drop_front(); + if (RegName.consume_front("[")) { + unsigned Idx, End; + bool Failed = RegName.consumeInteger(10, Idx); + Failed |= !RegName.consume_front(":"); + Failed |= RegName.consumeInteger(10, End); + Failed |= !RegName.consume_back("]"); + if (!Failed) { + unsigned NumRegs = End - Idx + 1; + if (NumRegs > 1) + return {Kind, Idx, NumRegs}; + } + } else { + unsigned Idx; + bool Failed = RegName.getAsInteger(10, Idx); + if (!Failed) + return {Kind, Idx, 1}; + } + + return {}; +} + std::pair<unsigned, unsigned> getIntegerPairAttribute(const Function &F, 
StringRef Name, std::pair<unsigned, unsigned> Default, diff --git a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h index 74d59f4..1252e35 100644 --- a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h +++ b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h @@ -1012,6 +1012,12 @@ bool isReadOnlySegment(const GlobalValue *GV); /// target triple \p TT, false otherwise. bool shouldEmitConstantsToTextSection(const Triple &TT); +/// Returns a valid charcode or 0 in the first entry if this is a valid physical +/// register constraint. Followed by the start register number, and the register +/// width. Does not validate the number of registers exists in the class. +std::tuple<char, unsigned, unsigned> +parseAsmConstraintPhysReg(StringRef Constraint); + /// \returns Integer value requested using \p F's \p Name attribute. /// /// \returns \p Default if attribute is not present. diff --git a/llvm/lib/Target/AMDGPU/VOP3Instructions.td b/llvm/lib/Target/AMDGPU/VOP3Instructions.td index 2d3caec..1ffe39d 100644 --- a/llvm/lib/Target/AMDGPU/VOP3Instructions.td +++ b/llvm/lib/Target/AMDGPU/VOP3Instructions.td @@ -625,8 +625,9 @@ def shl_0_to_4 : PatFrag< }]; } -def VOP3_CVT_PK_F8_F32_Profile : VOP3_Profile<VOP_I32_F32_F32, VOP3_OPSEL> { - defvar Tail = (ins VGPR_32:$vdst_in, op_sel0:$op_sel); +class VOP3_CVT_PK_F8_F32_Profile<bit _HasClamp = 0> : VOP3_Profile<VOP_I32_F32_F32, VOP3_OPSEL> { + defvar Tail = !con(!if(_HasClamp, (ins Clamp:$clamp), (ins)), + (ins VGPR_32:$vdst_in, op_sel0:$op_sel)); let InsVOP3OpSel = !con(getIns64<Src0RC64, Src1RC64, Src2RC64, NumSrcArgs, 0, HasModifiers, HasSrc2Mods, HasOMod, Src0Mod, Src1Mod, Src2Mod>.ret, @@ -636,12 +637,13 @@ def VOP3_CVT_PK_F8_F32_Profile : VOP3_Profile<VOP_I32_F32_F32, VOP3_OPSEL> { HasSrc2Mods, HasOMod, Src0ModVOP3DPP, Src1ModVOP3DPP, Src2ModVOP3DPP, false>.ret, Tail); - let HasClamp = 0; + let HasClamp = _HasClamp; let HasExtVOP3DPP = 1; } -def VOP3_CVT_PK_F8_F32_Profile_fake16 : VOP3_Profile_Fake16<VOP_I16_F32_F32, VOP3_OPSEL> { - defvar Tail = (ins VGPR_32:$vdst_in, op_sel0:$op_sel); +class VOP3_CVT_PK_F8_F32_Profile_fake16<bit _HasClamp = 0> : VOP3_Profile_Fake16<VOP_I16_F32_F32, VOP3_OPSEL> { + defvar Tail = !con(!if(_HasClamp, (ins Clamp:$clamp), (ins)), + (ins VGPR_32:$vdst_in, op_sel0:$op_sel)); let InsVOP3OpSel = !con(getIns64<Src0RC64, Src1RC64, Src2RC64, NumSrcArgs, 0, HasModifiers, HasSrc2Mods, HasOMod, Src0Mod, Src1Mod, Src2Mod>.ret, @@ -651,14 +653,15 @@ def VOP3_CVT_PK_F8_F32_Profile_fake16 : VOP3_Profile_Fake16<VOP_I16_F32_F32, VOP HasSrc2Mods, HasOMod, Src0ModVOP3DPP, Src1ModVOP3DPP, Src2ModVOP3DPP, false>.ret, Tail); - let HasClamp = 0; + let HasClamp = _HasClamp; let HasExtVOP3DPP = 1; } // This t16 profile with vdst_in operand is for backward compatibility and is used // for user controlled packing -def VOP3_CVT_PK_F8_F32_Profile_t16 : VOP3_Profile_True16<VOP_I16_F32_F32, VOP3_OPSEL> { - defvar Tail = (ins VGPR_16:$vdst_in, op_sel0:$op_sel); +class VOP3_CVT_PK_F8_F32_Profile_t16<bit _HasClamp = 0> : VOP3_Profile_True16<VOP_I16_F32_F32, VOP3_OPSEL> { + defvar Tail = !con(!if(_HasClamp, (ins Clamp:$clamp), (ins)), + (ins VGPR_16:$vdst_in, op_sel0:$op_sel)); let InsVOP3OpSel = !con(getIns64<Src0RC64, Src1RC64, Src2RC64, NumSrcArgs, 0, HasModifiers, HasSrc2Mods, HasOMod, Src0Mod, Src1Mod, Src2Mod>.ret, @@ -668,7 +671,7 @@ def VOP3_CVT_PK_F8_F32_Profile_t16 : VOP3_Profile_True16<VOP_I16_F32_F32, VOP3_O HasSrc2Mods, HasOMod, Src0ModVOP3DPP, Src1ModVOP3DPP, Src2ModVOP3DPP, false>.ret, Tail); - let 
HasClamp = 0; + let HasClamp = _HasClamp; let HasExtVOP3DPP = 1; } @@ -702,10 +705,10 @@ def VOP3_CVT_SR_F8_F32_Profile : VOP3_Profile<VOPProfile<[i32, f32, i32, f32]>, HasModifiers, DstVT>.ret); } -class VOP3_CVT_SR_F8_ByteSel_Profile<ValueType SrcVT> : +class VOP3_CVT_SR_F8_ByteSel_Profile<ValueType SrcVT, bit _HasClamp = 0> : VOP3_Profile<VOPProfile<[i32, SrcVT, i32, untyped]>> { let HasFP8DstByteSel = 1; - let HasClamp = 0; + let HasClamp = _HasClamp; } def IsPow2Plus1: PatLeaf<(i32 imm), [{ @@ -780,15 +783,23 @@ defm V_LSHL_ADD_U64 : VOP3Inst <"v_lshl_add_u64", V_LSHL_ADD_U64_PROF>; let OtherPredicates = [HasFP8ConversionInsts], mayRaiseFPException = 0, SchedRW = [WriteFloatCvt] in { let Constraints = "$vdst = $vdst_in", DisableEncoding = "$vdst_in" in { - defm V_CVT_PK_FP8_F32 : VOP3Inst_t16_with_profiles<"v_cvt_pk_fp8_f32", VOP3_CVT_PK_F8_F32_Profile, - VOP3_CVT_PK_F8_F32_Profile_t16, - VOP3_CVT_PK_F8_F32_Profile_fake16>; - defm V_CVT_PK_BF8_F32 : VOP3Inst_t16_with_profiles<"v_cvt_pk_bf8_f32", VOP3_CVT_PK_F8_F32_Profile, - VOP3_CVT_PK_F8_F32_Profile_t16, - VOP3_CVT_PK_F8_F32_Profile_fake16>; + let OtherPredicates = [HasFP8ConversionInsts, NotHasFP8E5M3Insts] in + defm V_CVT_PK_FP8_F32 : VOP3Inst_t16_with_profiles<"v_cvt_pk_fp8_f32", VOP3_CVT_PK_F8_F32_Profile<>, + VOP3_CVT_PK_F8_F32_Profile_t16<>, + VOP3_CVT_PK_F8_F32_Profile_fake16<>>; + let OtherPredicates = [HasFP8ConversionInsts, HasFP8E5M3Insts] in + defm V_CVT_PK_FP8_F32_gfx1250 : VOP3Inst_t16_with_profiles<"v_cvt_pk_fp8_f32_gfx1250", VOP3_CVT_PK_F8_F32_Profile<true>, + VOP3_CVT_PK_F8_F32_Profile_t16<true>, + VOP3_CVT_PK_F8_F32_Profile_fake16<true>>; + defm V_CVT_PK_BF8_F32 : VOP3Inst_t16_with_profiles<"v_cvt_pk_bf8_f32", VOP3_CVT_PK_F8_F32_Profile<>, + VOP3_CVT_PK_F8_F32_Profile_t16<>, + VOP3_CVT_PK_F8_F32_Profile_fake16<>>; let SubtargetPredicate = isGFX12Plus in { - defm V_CVT_SR_FP8_F32_gfx12 : VOP3Inst<"v_cvt_sr_fp8_f32_gfx12", VOP3_CVT_SR_F8_ByteSel_Profile<f32>>; + let OtherPredicates = [HasFP8ConversionInsts, NotHasFP8E5M3Insts] in + defm V_CVT_SR_FP8_F32_gfx12 : VOP3Inst<"v_cvt_sr_fp8_f32_gfx12", VOP3_CVT_SR_F8_ByteSel_Profile<f32>>; + let OtherPredicates = [HasFP8ConversionInsts, HasFP8E5M3Insts] in + defm V_CVT_SR_FP8_F32_gfx1250 : VOP3Inst<"v_cvt_sr_fp8_f32_gfx1250", VOP3_CVT_SR_F8_ByteSel_Profile<f32, true>>; defm V_CVT_SR_BF8_F32_gfx12 : VOP3Inst<"v_cvt_sr_bf8_f32_gfx12", VOP3_CVT_SR_F8_ByteSel_Profile<f32>>; } } @@ -807,6 +818,11 @@ class Cvt_PK_F8_F32_Pat<SDPatternOperator node, int index, VOP3_Pseudo inst> : G (inst !if(index, SRCMODS.DST_OP_SEL, 0), $src0, 0, $src1, $old, 0) >; +class Cvt_PK_F8_F32_E5M3_Pat<SDPatternOperator node, int index, VOP3_Pseudo inst, int Clamp> : GCNPat< + (i32 (node f32:$src0, f32:$src1, i32:$old, index)), + (inst !if(index, SRCMODS.DST_OP_SEL, 0), $src0, 0, $src1, Clamp, $old, 0) +>; + multiclass Cvt_PK_F8_F32_t16_Pat<SDPatternOperator node, VOP3_Pseudo inst> { def : GCNPat< (i32 (node f32:$src0, f32:$src1, i32:$old, -1)), @@ -822,6 +838,21 @@ def : GCNPat< >; } +multiclass Cvt_PK_F8_F32_E5M3_t16_Pat<SDPatternOperator node, VOP3_Pseudo inst, int Clamp> { +def : GCNPat< + (i32 (node f32:$src0, f32:$src1, i32:$old, -1)), + (REG_SEQUENCE VGPR_32, + (i16 (EXTRACT_SUBREG $old, lo16)), lo16, + (i16 (inst SRCMODS.DST_OP_SEL, $src0, 0, $src1, Clamp, (i16 (EXTRACT_SUBREG $old, hi16)), 0)), hi16) +>; +def : GCNPat< + (i32 (node f32:$src0, f32:$src1, i32:$old, 0)), + (REG_SEQUENCE VGPR_32, + (i16 (inst 0, $src0, 0, $src1, Clamp, (i16 (EXTRACT_SUBREG $old, lo16)), 0)), lo16, + (i16 
(EXTRACT_SUBREG $old, hi16)), hi16) +>; +} + class Cvt_SR_F8_F32_Pat<SDPatternOperator node, bits<2> index, VOP3_Pseudo inst> : GCNPat< (i32 (node f32:$src0, i32:$src1, i32:$old, index)), (inst !if(index{1}, SRCMODS.DST_OP_SEL, 0), $src0, 0, $src1, @@ -834,21 +865,37 @@ class Cvt_SR_F8_ByteSel_Pat<SDPatternOperator node, VOP3_Pseudo inst, ValueType (inst $src0_modifiers, $src0, $src1_modifiers, $src1, $old, (as_i32timm $byte_sel)) >; +class Cvt_SR_F8_ByteSel_E5M3_Pat<SDPatternOperator node, VOP3_Pseudo inst, + ValueType SrcVT, int Clamp> : GCNPat< + (i32 (node (VOP3Mods SrcVT:$src0, i32:$src0_modifiers), (VOP3Mods i32:$src1, i32:$src1_modifiers), + i32:$old, timm:$byte_sel)), + (inst $src0_modifiers, $src0, $src1_modifiers, $src1, Clamp, $old, (as_i32timm $byte_sel)) +>; + let OtherPredicates = [HasFP8ConversionInsts] in { foreach Index = [0, -1] in { let True16Predicate = NotHasTrue16BitInsts in { - def : Cvt_PK_F8_F32_Pat<int_amdgcn_cvt_pk_fp8_f32, Index, V_CVT_PK_FP8_F32_e64>; + let OtherPredicates = [HasFP8ConversionInsts, NotHasFP8E5M3Insts] in + def : Cvt_PK_F8_F32_Pat<int_amdgcn_cvt_pk_fp8_f32, Index, V_CVT_PK_FP8_F32_e64>; def : Cvt_PK_F8_F32_Pat<int_amdgcn_cvt_pk_bf8_f32, Index, V_CVT_PK_BF8_F32_e64>; } let True16Predicate = UseFakeTrue16Insts in { def : Cvt_PK_F8_F32_Pat<int_amdgcn_cvt_pk_fp8_f32, Index, V_CVT_PK_FP8_F32_fake16_e64>; def : Cvt_PK_F8_F32_Pat<int_amdgcn_cvt_pk_bf8_f32, Index, V_CVT_PK_BF8_F32_fake16_e64>; + let OtherPredicates = [HasFP8ConversionInsts, HasFP8E5M3Insts] in { + def : Cvt_PK_F8_F32_E5M3_Pat<int_amdgcn_cvt_pk_fp8_f32, Index, V_CVT_PK_FP8_F32_gfx1250_fake16_e64, DSTCLAMP.NONE>; + def : Cvt_PK_F8_F32_E5M3_Pat<int_amdgcn_cvt_pk_fp8_f32_e5m3, Index, V_CVT_PK_FP8_F32_gfx1250_fake16_e64, DSTCLAMP.ENABLE>; + } } } let True16Predicate = UseRealTrue16Insts in { defm : Cvt_PK_F8_F32_t16_Pat<int_amdgcn_cvt_pk_fp8_f32, V_CVT_PK_FP8_F32_t16_e64>; defm : Cvt_PK_F8_F32_t16_Pat<int_amdgcn_cvt_pk_bf8_f32, V_CVT_PK_BF8_F32_t16_e64>; + let OtherPredicates = [HasFP8ConversionInsts, HasFP8E5M3Insts] in { + defm : Cvt_PK_F8_F32_E5M3_t16_Pat<int_amdgcn_cvt_pk_fp8_f32, V_CVT_PK_FP8_F32_gfx1250_t16_e64, DSTCLAMP.NONE>; + defm : Cvt_PK_F8_F32_E5M3_t16_Pat<int_amdgcn_cvt_pk_fp8_f32_e5m3, V_CVT_PK_FP8_F32_gfx1250_t16_e64, DSTCLAMP.ENABLE>; + } } let SubtargetPredicate = isGFX940Plus in { @@ -859,7 +906,12 @@ let SubtargetPredicate = isGFX940Plus in { } let SubtargetPredicate = isGFX12Plus in { - def : Cvt_SR_F8_ByteSel_Pat<int_amdgcn_cvt_sr_fp8_f32, V_CVT_SR_FP8_F32_gfx12_e64, f32>; + let OtherPredicates = [HasFP8ConversionInsts, NotHasFP8E5M3Insts] in + def : Cvt_SR_F8_ByteSel_Pat<int_amdgcn_cvt_sr_fp8_f32, V_CVT_SR_FP8_F32_gfx12_e64, f32>; + let OtherPredicates = [HasFP8ConversionInsts, HasFP8E5M3Insts] in { + def : Cvt_SR_F8_ByteSel_E5M3_Pat<int_amdgcn_cvt_sr_fp8_f32, V_CVT_SR_FP8_F32_gfx1250_e64, f32, DSTCLAMP.NONE>; + def : Cvt_SR_F8_ByteSel_E5M3_Pat<int_amdgcn_cvt_sr_fp8_f32_e5m3, V_CVT_SR_FP8_F32_gfx1250_e64, f32, DSTCLAMP.ENABLE>; + } def : Cvt_SR_F8_ByteSel_Pat<int_amdgcn_cvt_sr_bf8_f32, V_CVT_SR_BF8_F32_gfx12_e64, f32>; } } @@ -1623,6 +1675,23 @@ let SubtargetPredicate = HasBF16ConversionInsts in { (V_CVT_PK_BF16_F32_e64 $src0_modifiers, $src0, 0, (f32 (IMPLICIT_DEF)))>; } +class VOP3_CVT_SCALE_PK_F16_F864_Profile<VOPProfile P> : VOP3_CVT_SCALEF32_PK_F864_Profile<P> { + let Src0RC64 = getVOP3VRegSrcForVT<Src0VT>.ret; + let Ins64 = !con(getIns64<Src0RC64, Src1RC64, Src2RC64, NumSrcArgs, + HasClamp, HasModifiers, HasSrc2Mods, + HasOMod, Src0Mod, Src1Mod, Src2Mod>.ret, + 
(ins ScaleSel:$scale_sel)); + let Asm64 = getAsmVOP3Base<NumSrcArgs, HasDst, HasClamp, + HasOpSel, HasOMod, IsVOP3P, HasNeg, HasSrc0Mods, HasSrc1Mods, + HasSrc2Mods, DstVT>.ret # "$scale_sel"; +} + +multiclass VOP3CvtScaleSelInst<string OpName, VOPProfile P, SDPatternOperator node> { + def _e64 : VOP3InstBase<OpName, VOP3_CVT_SCALE_PK_F16_F864_Profile<P>> { + let Pattern = [(set P.DstVT:$vdst, (node (P.Src0VT (VOP3Mods0 P.Src0VT:$src0)), i32:$src1, i32:$scale_sel))]; + } +} + let Src0RC64 = VSrc_NoInline_v2f16 in { def VOP3_CVT_PK_F8_F16_Profile : VOP3_Profile<VOP_I16_V2F16>; def VOP3_CVT_PK_F8_F16_True16_Profile : VOP3_Profile_True16<VOP3_CVT_PK_F8_F16_Profile>; @@ -1650,6 +1719,8 @@ def VOP3_CVT_SR_F8_F16_Fake16_Profile : VOP3_Profile_Fake16<VOP3_CVT_SR_F8_F16_P let SubtargetPredicate = isGFX1250Plus in { let ReadsModeReg = 0 in { + defm V_CVT_SR_PK_F16_F32 : VOP3Inst<"v_cvt_sr_pk_f16_f32", VOP3_Profile<VOP_V2F16_F32_F32_I32>, int_amdgcn_cvt_sr_pk_f16_f32>; + // These instructions have non-standard use of op_sel. They are using bits 2 and 3 of opsel // to select a byte in the vdst. Bits 0 and 1 are unused. let Constraints = "$vdst = $vdst_in", DisableEncoding = "$vdst_in" in { @@ -1658,6 +1729,19 @@ let SubtargetPredicate = isGFX1250Plus in { defm V_CVT_SR_BF8_F16 : VOP3Inst_t16_with_profiles<"v_cvt_sr_bf8_f16", VOP3_CVT_SR_F8_F16_Profile, VOP3_CVT_SR_F8_F16_True16_Profile, VOP3_CVT_SR_F8_F16_Fake16_Profile>; } + + let Constraints = "@earlyclobber $vdst" in { + defm V_CVT_SCALE_PK8_F16_FP8 : VOP3CvtScaleSelInst<"v_cvt_scale_pk8_f16_fp8", VOP_V8F16_V2I32_I32, int_amdgcn_cvt_scale_pk8_f16_fp8>; + defm V_CVT_SCALE_PK8_BF16_FP8 : VOP3CvtScaleSelInst<"v_cvt_scale_pk8_bf16_fp8", VOP_V8BF16_V2I32_I32, int_amdgcn_cvt_scale_pk8_bf16_fp8>; + defm V_CVT_SCALE_PK8_F16_BF8 : VOP3CvtScaleSelInst<"v_cvt_scale_pk8_f16_bf8", VOP_V8F16_V2I32_I32, int_amdgcn_cvt_scale_pk8_f16_bf8>; + defm V_CVT_SCALE_PK8_BF16_BF8 : VOP3CvtScaleSelInst<"v_cvt_scale_pk8_bf16_bf8", VOP_V8BF16_V2I32_I32, int_amdgcn_cvt_scale_pk8_bf16_bf8>; + defm V_CVT_SCALE_PK8_F32_FP8 : VOP3CvtScaleSelInst<"v_cvt_scale_pk8_f32_fp8", VOP_V8F32_V2I32_I32, int_amdgcn_cvt_scale_pk8_f32_fp8>; + defm V_CVT_SCALE_PK8_F32_BF8 : VOP3CvtScaleSelInst<"v_cvt_scale_pk8_f32_bf8", VOP_V8F32_V2I32_I32, int_amdgcn_cvt_scale_pk8_f32_bf8>; + } // End Constraints = "@earlyclobber $vdst" + + defm V_CVT_SCALE_PK8_F16_FP4 : VOP3CvtScaleSelInst<"v_cvt_scale_pk8_f16_fp4", VOP_V8F16_I32_I32, int_amdgcn_cvt_scale_pk8_f16_fp4>; + defm V_CVT_SCALE_PK8_BF16_FP4 : VOP3CvtScaleSelInst<"v_cvt_scale_pk8_bf16_fp4", VOP_V8BF16_I32_I32, int_amdgcn_cvt_scale_pk8_bf16_fp4>; + defm V_CVT_SCALE_PK8_F32_FP4 : VOP3CvtScaleSelInst<"v_cvt_scale_pk8_f32_fp4", VOP_V8F32_I32_I32, int_amdgcn_cvt_scale_pk8_f32_fp4>; } // End ReadsModeReg = 0 let True16Predicate = UseRealTrue16Insts in { @@ -1890,11 +1974,6 @@ defm V_ADD_MAX_U32 : VOP3Only_Realtriple_gfx1250<0x25f>; defm V_ADD_MIN_I32 : VOP3Only_Realtriple_gfx1250<0x260>; defm V_ADD_MIN_U32 : VOP3Only_Realtriple_gfx1250<0x261>; -defm V_CVT_PK_FP8_F32 : VOP3Only_Realtriple_t16_and_fake16_gfx12<0x369, "v_cvt_pk_fp8_f32">; -defm V_CVT_PK_BF8_F32 : VOP3Only_Realtriple_t16_and_fake16_gfx12<0x36a, "v_cvt_pk_bf8_f32">; -defm V_CVT_SR_FP8_F32_gfx12 : VOP3_Realtriple_with_name_gfx12<0x36b, "V_CVT_SR_FP8_F32_gfx12", "v_cvt_sr_fp8_f32" >; -defm V_CVT_SR_BF8_F32_gfx12 : VOP3_Realtriple_with_name_gfx12<0x36c, "V_CVT_SR_BF8_F32_gfx12", "v_cvt_sr_bf8_f32">; - //===----------------------------------------------------------------------===// // GFX11, GFX12 
//===----------------------------------------------------------------------===// @@ -2055,6 +2134,13 @@ defm V_AND_B16 : VOP3Only_Realtriple_t16_and_fake16_gfx11_gfx12<0x36 defm V_OR_B16 : VOP3Only_Realtriple_t16_and_fake16_gfx11_gfx12<0x363, "v_or_b16">; defm V_XOR_B16 : VOP3Only_Realtriple_t16_and_fake16_gfx11_gfx12<0x364, "v_xor_b16">; +defm V_CVT_PK_FP8_F32 : VOP3Only_Realtriple_t16_and_fake16_gfx11_gfx12_not_gfx1250<0x369, "v_cvt_pk_fp8_f32">; +defm V_CVT_PK_FP8_F32_gfx1250 : VOP3Only_Realtriple_t16_and_fake16_gfx1250<0x369, "v_cvt_pk_fp8_f32">; +defm V_CVT_PK_BF8_F32 : VOP3Only_Realtriple_t16_and_fake16_gfx11_gfx12<0x36a, "v_cvt_pk_bf8_f32">; +defm V_CVT_SR_FP8_F32_gfx12 : VOP3_Realtriple_with_name_gfx11_gfx12_not_gfx1250<0x36b, "V_CVT_SR_FP8_F32_gfx12", "v_cvt_sr_fp8_f32">; +defm V_CVT_SR_FP8_F32_gfx1250 : VOP3Only_Realtriple_with_name_gfx1250<0x36b, "V_CVT_SR_FP8_F32_gfx1250", "v_cvt_sr_fp8_f32">; +defm V_CVT_SR_BF8_F32_gfx12 : VOP3_Realtriple_with_name_gfx11_gfx12<0x36c, "V_CVT_SR_BF8_F32_gfx12", "v_cvt_sr_bf8_f32">; + let AssemblerPredicate = isGFX11Plus in { def : AMDGPUMnemonicAlias<"v_add3_nc_u32", "v_add3_u32">; def : AMDGPUMnemonicAlias<"v_xor_add_u32", "v_xad_u32">; @@ -2064,8 +2150,19 @@ let AssemblerPredicate = isGFX11Plus in { defm V_LSHL_ADD_U64 : VOP3Only_Realtriple_gfx1250<0x252>; defm V_ASHR_PK_I8_I32 : VOP3Only_Realtriple_gfx1250<0x290>; defm V_ASHR_PK_U8_I32 : VOP3Only_Realtriple_gfx1250<0x291>; +defm V_CVT_SCALE_PK8_F16_FP4 : VOP3Only_ScaleSel_Real_gfx1250<0x29f>; +defm V_CVT_SCALE_PK8_BF16_FP4 : VOP3Only_ScaleSel_Real_gfx1250<0x2a0>; +defm V_CVT_SCALE_PK8_F32_FP4 : VOP3Only_ScaleSel_Real_gfx1250<0x2a1>; +defm V_CVT_SCALE_PK8_F16_FP8 : VOP3Only_ScaleSel_Real_gfx1250<0x2a8>; +defm V_CVT_SCALE_PK8_BF16_FP8 : VOP3Only_ScaleSel_Real_gfx1250<0x2a9>; +defm V_CVT_SCALE_PK8_F32_FP8 : VOP3Only_ScaleSel_Real_gfx1250<0x2aa>; +defm V_CVT_SCALE_PK8_F16_BF8 : VOP3Only_ScaleSel_Real_gfx1250<0x2ab>; +defm V_CVT_SCALE_PK8_BF16_BF8 : VOP3Only_ScaleSel_Real_gfx1250<0x2ac>; +defm V_CVT_SCALE_PK8_F32_BF8 : VOP3Only_ScaleSel_Real_gfx1250<0x2ad>; defm V_CVT_PK_BF16_F32 : VOP3Only_Realtriple_gfx1250<0x36d>; defm V_CVT_SR_PK_BF16_F32 : VOP3Only_Realtriple_gfx1250<0x36e>; +defm V_CVT_PK_F16_F32 : VOP3Only_Realtriple_gfx1250<0x36f>; +defm V_CVT_SR_PK_F16_F32 : VOP3Only_Realtriple_gfx1250<0x370>; defm V_CVT_PK_FP8_F16_gfx1250 : VOP3Only_Realtriple_t16_and_fake16_gfx1250<0x372, "v_cvt_pk_fp8_f16">; defm V_CVT_PK_BF8_F16_gfx1250 : VOP3Only_Realtriple_t16_and_fake16_gfx1250<0x373, "v_cvt_pk_bf8_f16">; defm V_CVT_SR_FP8_F16 : VOP3Only_Realtriple_t16_and_fake16_gfx1250<0x374>; diff --git a/llvm/lib/Target/AMDGPU/VOPInstructions.td b/llvm/lib/Target/AMDGPU/VOPInstructions.td index a029376..f027ab0 100644 --- a/llvm/lib/Target/AMDGPU/VOPInstructions.td +++ b/llvm/lib/Target/AMDGPU/VOPInstructions.td @@ -414,6 +414,13 @@ class VOP3a_BITOP3_gfx12<bits<10> op, VOPProfile p> : VOP3e_gfx11_gfx12<op, p> { let Inst{14} = !if(p.HasOpSel, src0_modifiers{3}, 0); } +class VOP3a_ScaleSel_gfx1250<bits<10> op, VOPProfile p> : VOP3e_gfx11_gfx12<op, p> { + bits<3> scale_sel; + + let Inst{13-11} = scale_sel; + let Inst{14} = 0; +} + class VOP3Interp_gfx10<bits<10> op, VOPProfile p> : VOP3e_gfx10<op, p> { bits<6> attr; bits<2> attrchan; @@ -2010,6 +2017,30 @@ multiclass VOP3_BITOP3_Real_Base<GFXGen Gen, bits<10> op, string asmName> { } } +multiclass VOP3Only_ScaleSel_Real_gfx1250<bits<10> op> { + defvar ps = !cast<VOP_Pseudo>(NAME#"_e64"); + def _e64_gfx1250 : + VOP3_Real_Gen<ps, GFX1250Gen>, + 
VOP3a_ScaleSel_gfx1250<op, ps.Pfl>; +} + +multiclass VOP3Only_Realtriple_t16_gfx11_gfx12_not_gfx1250<bits<10> op, string asmName, string opName = NAME, + string pseudo_mnemonic = "", bit isSingle = 0> : + VOP3_Realtriple_with_name<GFX11Gen, op, opName, asmName, pseudo_mnemonic, isSingle>, + VOP3_Realtriple_with_name<GFX12Not12_50Gen, op, opName, asmName, pseudo_mnemonic, isSingle>; + +multiclass VOP3Only_Realtriple_t16_and_fake16_gfx11_gfx12_not_gfx1250<bits<10> op, string asmName, + string opName = NAME, string pseudo_mnemonic = ""> { + defm _t16 : VOP3Only_Realtriple_t16_gfx11_gfx12_not_gfx1250<op, asmName, opName#"_t16", pseudo_mnemonic, 1>; + defm _fake16 : VOP3Only_Realtriple_t16_gfx11_gfx12_not_gfx1250<op, asmName, opName#"_fake16", pseudo_mnemonic, 1>; +} + +multiclass VOP3_Realtriple_with_name_gfx11_gfx12_not_gfx1250<bits<10> op, string opName, + string asmName, string pseudo_mnemonic = "", + bit isSingle = 0> : + VOP3_Realtriple_with_name<GFX11Gen, op, opName, asmName, pseudo_mnemonic, isSingle>, + VOP3_Realtriple_with_name<GFX12Not12_50Gen, op, opName, asmName, pseudo_mnemonic, isSingle>; + //===----------------------------------------------------------------------===// // VOP3 GFX11 //===----------------------------------------------------------------------===// diff --git a/llvm/lib/Target/ARM/MCTargetDesc/ARMAsmBackend.cpp b/llvm/lib/Target/ARM/MCTargetDesc/ARMAsmBackend.cpp index 146fc67..dfa3de3c 100644 --- a/llvm/lib/Target/ARM/MCTargetDesc/ARMAsmBackend.cpp +++ b/llvm/lib/Target/ARM/MCTargetDesc/ARMAsmBackend.cpp @@ -1125,7 +1125,7 @@ void ARMAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, const unsigned NumBytes = getFixupKindNumBytes(Kind); unsigned Offset = Fixup.getOffset(); - assert(Offset + NumBytes <= Data.size() && "Invalid fixup offset!"); + assert(Offset + NumBytes <= F.getSize() && "Invalid fixup offset!"); // Used to point to big endian bytes. unsigned FullSizeBytes; diff --git a/llvm/lib/Target/AVR/MCTargetDesc/AVRAsmBackend.cpp b/llvm/lib/Target/AVR/MCTargetDesc/AVRAsmBackend.cpp index 128cc0b..38444f9 100644 --- a/llvm/lib/Target/AVR/MCTargetDesc/AVRAsmBackend.cpp +++ b/llvm/lib/Target/AVR/MCTargetDesc/AVRAsmBackend.cpp @@ -398,7 +398,7 @@ void AVRAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, Value <<= Info.TargetOffset; unsigned Offset = Fixup.getOffset(); - assert(Offset + NumBytes <= Data.size() && "Invalid fixup offset!"); + assert(Offset + NumBytes <= F.getSize() && "Invalid fixup offset!"); // For each byte of the fragment that the fixup touches, mask in the // bits from the fixup value. diff --git a/llvm/lib/Target/CSKY/MCTargetDesc/CSKYAsmBackend.cpp b/llvm/lib/Target/CSKY/MCTargetDesc/CSKYAsmBackend.cpp index 694d9ea..1bd82fad 100644 --- a/llvm/lib/Target/CSKY/MCTargetDesc/CSKYAsmBackend.cpp +++ b/llvm/lib/Target/CSKY/MCTargetDesc/CSKYAsmBackend.cpp @@ -220,7 +220,7 @@ void CSKYAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, unsigned Offset = Fixup.getOffset(); unsigned NumBytes = alignTo(Info.TargetSize + Info.TargetOffset, 8) / 8; - assert(Offset + NumBytes <= Data.size() && "Invalid fixup offset!"); + assert(Offset + NumBytes <= F.getSize() && "Invalid fixup offset!"); // For each byte of the fragment that the fixup touches, mask in the // bits from the fixup value. 
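The CSKY hunk above (like the MSP430 and LoongArch ones below) derives the byte count a fixup touches from its bit-level size and offset before checking it against the fragment size. A worked example with illustrative values, not taken from the patch:

#include "llvm/Support/MathExtras.h"

// A hypothetical fixup encoding 20 bits starting 12 bits into the word spans
// alignTo(20 + 12, 8) / 8 == 4 bytes, which is what the new assert compares
// against F.getSize().
static unsigned fixupNumBytes() {
  unsigned TargetSize = 20;   // illustrative: bits encoded by the fixup
  unsigned TargetOffset = 12; // illustrative: bit offset within the word
  return llvm::alignTo(TargetSize + TargetOffset, 8) / 8;
}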
diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp index ebdfcaa..a4f5086 100644 --- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp +++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp @@ -17,7 +17,6 @@ #include "llvm/Analysis/DXILMetadataAnalysis.h" #include "llvm/BinaryFormat/DXContainer.h" #include "llvm/Frontend/HLSL/RootSignatureMetadata.h" -#include "llvm/Frontend/HLSL/RootSignatureValidations.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Function.h" @@ -111,14 +110,25 @@ analyzeModule(Module &M) { reportError(Ctx, "Root Element is not a metadata node."); continue; } - mcdxbc::RootSignatureDesc RSD; - if (std::optional<uint32_t> Version = extractMdIntValue(RSDefNode, 2)) - RSD.Version = *Version; - else { + std::optional<uint32_t> V = extractMdIntValue(RSDefNode, 2); + if (!V.has_value()) { reportError(Ctx, "Invalid RSDefNode value, expected constant int"); continue; } + llvm::hlsl::rootsig::MetadataParser MDParser(RootElementListNode); + llvm::Expected<mcdxbc::RootSignatureDesc> RSDOrErr = + MDParser.ParseRootSignature(V.value()); + + if (!RSDOrErr) { + handleAllErrors(RSDOrErr.takeError(), [&](ErrorInfoBase &EIB) { + Ctx->emitError(EIB.message()); + }); + continue; + } + + auto &RSD = *RSDOrErr; + // Clang emits the root signature data in dxcontainer following a specific // sequence. First the header, then the root parameters. So the header // offset will always equal to the header size. @@ -127,12 +137,6 @@ analyzeModule(Module &M) { // static sampler offset is calculated when writting dxcontainer. RSD.StaticSamplersOffset = 0u; - hlsl::rootsig::MetadataParser MDParser(RootElementListNode); - - if (MDParser.ParseRootSignature(Ctx, RSD)) { - return RSDMap; - } - RSDMap.insert(std::make_pair(F, RSD)); } diff --git a/llvm/lib/Target/Hexagon/MCTargetDesc/HexagonAsmBackend.cpp b/llvm/lib/Target/Hexagon/MCTargetDesc/HexagonAsmBackend.cpp index 7d3074b..d5b7a75 100644 --- a/llvm/lib/Target/Hexagon/MCTargetDesc/HexagonAsmBackend.cpp +++ b/llvm/lib/Target/Hexagon/MCTargetDesc/HexagonAsmBackend.cpp @@ -669,7 +669,7 @@ void HexagonAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, // to a real offset before we can use it. uint32_t Offset = Fixup.getOffset(); unsigned NumBytes = getFixupKindNumBytes(Kind); - assert(Offset + NumBytes <= Data.size() && "Invalid fixup offset!"); + assert(Offset + NumBytes <= F.getSize() && "Invalid fixup offset!"); char *InstAddr = Data.data() + Offset; Value = adjustFixupValue(Kind, FixupValue); diff --git a/llvm/lib/Target/LoongArch/MCTargetDesc/LoongArchAsmBackend.cpp b/llvm/lib/Target/LoongArch/MCTargetDesc/LoongArchAsmBackend.cpp index d9ea88c..858f3d0 100644 --- a/llvm/lib/Target/LoongArch/MCTargetDesc/LoongArchAsmBackend.cpp +++ b/llvm/lib/Target/LoongArch/MCTargetDesc/LoongArchAsmBackend.cpp @@ -169,7 +169,7 @@ void LoongArchAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, unsigned Offset = Fixup.getOffset(); unsigned NumBytes = alignTo(Info.TargetSize + Info.TargetOffset, 8) / 8; - assert(Offset + NumBytes <= Data.size() && "Invalid fixup offset!"); + assert(Offset + NumBytes <= F.getSize() && "Invalid fixup offset!"); // For each byte of the fragment that the fixup touches, mask in the // bits from the fixup value. 
for (unsigned I = 0; I != NumBytes; ++I) { diff --git a/llvm/lib/Target/M68k/MCTargetDesc/M68kAsmBackend.cpp b/llvm/lib/Target/M68k/MCTargetDesc/M68kAsmBackend.cpp index 5e03903..7ef705d 100644 --- a/llvm/lib/Target/M68k/MCTargetDesc/M68kAsmBackend.cpp +++ b/llvm/lib/Target/M68k/MCTargetDesc/M68kAsmBackend.cpp @@ -85,7 +85,7 @@ void M68kAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, Asm->getWriter().recordRelocation(F, Fixup, Target, Value); unsigned Size = 1 << getFixupKindLog2Size(Fixup.getKind()); - assert(Fixup.getOffset() + Size <= Data.size() && "Invalid fixup offset!"); + assert(Fixup.getOffset() + Size <= F.getSize() && "Invalid fixup offset!"); // Check that uppper bits are either all zeros or all ones. // Specifically ignore overflow/underflow as long as the leakage is // limited to the lower bits. This is to remain compatible with diff --git a/llvm/lib/Target/MSP430/MCTargetDesc/MSP430AsmBackend.cpp b/llvm/lib/Target/MSP430/MCTargetDesc/MSP430AsmBackend.cpp index 29e5bfa..b513503 100644 --- a/llvm/lib/Target/MSP430/MCTargetDesc/MSP430AsmBackend.cpp +++ b/llvm/lib/Target/MSP430/MCTargetDesc/MSP430AsmBackend.cpp @@ -120,7 +120,7 @@ void MSP430AsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, unsigned Offset = Fixup.getOffset(); unsigned NumBytes = alignTo(Info.TargetSize + Info.TargetOffset, 8) / 8; - assert(Offset + NumBytes <= Data.size() && "Invalid fixup offset!"); + assert(Offset + NumBytes <= F.getSize() && "Invalid fixup offset!"); // For each byte of the fragment that the fixup touches, mask in the // bits from the fixup value. diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp index 8eec915..ee1ca45 100644 --- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp +++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp @@ -391,16 +391,6 @@ void NVPTXInstPrinter::printMemOperand(const MCInst *MI, int OpNum, } } -void NVPTXInstPrinter::printOffseti32imm(const MCInst *MI, int OpNum, - raw_ostream &O) { - auto &Op = MI->getOperand(OpNum); - assert(Op.isImm() && "Invalid operand"); - if (Op.getImm() != 0) { - O << "+"; - printOperand(MI, OpNum, O); - } -} - void NVPTXInstPrinter::printHexu32imm(const MCInst *MI, int OpNum, raw_ostream &O) { int64_t Imm = MI->getOperand(OpNum).getImm(); diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h index c3ff346..92155b0 100644 --- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h +++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h @@ -46,7 +46,6 @@ public: StringRef Modifier = {}); void printMemOperand(const MCInst *MI, int OpNum, raw_ostream &O, StringRef Modifier = {}); - void printOffseti32imm(const MCInst *MI, int OpNum, raw_ostream &O); void printHexu32imm(const MCInst *MI, int OpNum, raw_ostream &O); void printProtoIdent(const MCInst *MI, int OpNum, raw_ostream &O); void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O); diff --git a/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp b/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp index cd40481..a349609 100644 --- a/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp @@ -56,15 +56,12 @@ static bool traverseMoveUse(MachineInstr &U, const MachineRegisterInfo &MRI, case NVPTX::LD_i16: case NVPTX::LD_i32: case NVPTX::LD_i64: - case NVPTX::LD_i8: case NVPTX::LDV_i16_v2: case NVPTX::LDV_i16_v4: case NVPTX::LDV_i32_v2: case NVPTX::LDV_i32_v4: 
case NVPTX::LDV_i64_v2: - case NVPTX::LDV_i64_v4: - case NVPTX::LDV_i8_v2: - case NVPTX::LDV_i8_v4: { + case NVPTX::LDV_i64_v4: { LoadInsts.push_back(&U); return true; } diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp index 95abcde..6068035 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp @@ -1003,14 +1003,10 @@ void NVPTXDAGToDAGISel::SelectAddrSpaceCast(SDNode *N) { // Helper function template to reduce amount of boilerplate code for // opcode selection. static std::optional<unsigned> -pickOpcodeForVT(MVT::SimpleValueType VT, std::optional<unsigned> Opcode_i8, - std::optional<unsigned> Opcode_i16, +pickOpcodeForVT(MVT::SimpleValueType VT, std::optional<unsigned> Opcode_i16, std::optional<unsigned> Opcode_i32, std::optional<unsigned> Opcode_i64) { switch (VT) { - case MVT::i1: - case MVT::i8: - return Opcode_i8; case MVT::f16: case MVT::i16: case MVT::bf16: @@ -1078,8 +1074,8 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) { Chain}; const MVT::SimpleValueType TargetVT = LD->getSimpleValueType(0).SimpleTy; - const std::optional<unsigned> Opcode = pickOpcodeForVT( - TargetVT, NVPTX::LD_i8, NVPTX::LD_i16, NVPTX::LD_i32, NVPTX::LD_i64); + const std::optional<unsigned> Opcode = + pickOpcodeForVT(TargetVT, NVPTX::LD_i16, NVPTX::LD_i32, NVPTX::LD_i64); if (!Opcode) return false; @@ -1164,17 +1160,15 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) { default: llvm_unreachable("Unexpected opcode"); case NVPTXISD::LoadV2: - Opcode = - pickOpcodeForVT(EltVT.SimpleTy, NVPTX::LDV_i8_v2, NVPTX::LDV_i16_v2, - NVPTX::LDV_i32_v2, NVPTX::LDV_i64_v2); + Opcode = pickOpcodeForVT(EltVT.SimpleTy, NVPTX::LDV_i16_v2, + NVPTX::LDV_i32_v2, NVPTX::LDV_i64_v2); break; case NVPTXISD::LoadV4: - Opcode = - pickOpcodeForVT(EltVT.SimpleTy, NVPTX::LDV_i8_v4, NVPTX::LDV_i16_v4, - NVPTX::LDV_i32_v4, NVPTX::LDV_i64_v4); + Opcode = pickOpcodeForVT(EltVT.SimpleTy, NVPTX::LDV_i16_v4, + NVPTX::LDV_i32_v4, NVPTX::LDV_i64_v4); break; case NVPTXISD::LoadV8: - Opcode = pickOpcodeForVT(EltVT.SimpleTy, {/* no v8i8 */}, {/* no v8i16 */}, + Opcode = pickOpcodeForVT(EltVT.SimpleTy, {/* no v8i16 */}, NVPTX::LDV_i32_v8, {/* no v8i64 */}); break; } @@ -1230,22 +1224,21 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) { default: llvm_unreachable("Unexpected opcode"); case ISD::LOAD: - Opcode = pickOpcodeForVT(TargetVT, NVPTX::LD_GLOBAL_NC_i8, - NVPTX::LD_GLOBAL_NC_i16, NVPTX::LD_GLOBAL_NC_i32, - NVPTX::LD_GLOBAL_NC_i64); + Opcode = pickOpcodeForVT(TargetVT, NVPTX::LD_GLOBAL_NC_i16, + NVPTX::LD_GLOBAL_NC_i32, NVPTX::LD_GLOBAL_NC_i64); break; case NVPTXISD::LoadV2: - Opcode = pickOpcodeForVT( - TargetVT, NVPTX::LD_GLOBAL_NC_v2i8, NVPTX::LD_GLOBAL_NC_v2i16, - NVPTX::LD_GLOBAL_NC_v2i32, NVPTX::LD_GLOBAL_NC_v2i64); + Opcode = + pickOpcodeForVT(TargetVT, NVPTX::LD_GLOBAL_NC_v2i16, + NVPTX::LD_GLOBAL_NC_v2i32, NVPTX::LD_GLOBAL_NC_v2i64); break; case NVPTXISD::LoadV4: - Opcode = pickOpcodeForVT( - TargetVT, NVPTX::LD_GLOBAL_NC_v4i8, NVPTX::LD_GLOBAL_NC_v4i16, - NVPTX::LD_GLOBAL_NC_v4i32, NVPTX::LD_GLOBAL_NC_v4i64); + Opcode = + pickOpcodeForVT(TargetVT, NVPTX::LD_GLOBAL_NC_v4i16, + NVPTX::LD_GLOBAL_NC_v4i32, NVPTX::LD_GLOBAL_NC_v4i64); break; case NVPTXISD::LoadV8: - Opcode = pickOpcodeForVT(TargetVT, {/* no v8i8 */}, {/* no v8i16 */}, + Opcode = pickOpcodeForVT(TargetVT, {/* no v8i16 */}, NVPTX::LD_GLOBAL_NC_v8i32, {/* no v8i64 */}); break; } @@ -1276,8 +1269,9 @@ bool NVPTXDAGToDAGISel::tryLDU(SDNode *N) { break; } - const 
MVT::SimpleValueType SelectVT = - MVT::getIntegerVT(LD->getMemoryVT().getSizeInBits() / NumElts).SimpleTy; + SDLoc DL(N); + const unsigned FromTypeWidth = LD->getMemoryVT().getSizeInBits() / NumElts; + const MVT::SimpleValueType TargetVT = LD->getSimpleValueType(0).SimpleTy; // If this is an LDU intrinsic, the address is the third operand. If its an // LDU SD node (from custom vector handling), then its the second operand @@ -1286,32 +1280,28 @@ bool NVPTXDAGToDAGISel::tryLDU(SDNode *N) { SDValue Base, Offset; SelectADDR(Addr, Base, Offset); - SDValue Ops[] = {Base, Offset, LD->getChain()}; + SDValue Ops[] = {getI32Imm(FromTypeWidth, DL), Base, Offset, LD->getChain()}; std::optional<unsigned> Opcode; switch (N->getOpcode()) { default: llvm_unreachable("Unexpected opcode"); case ISD::INTRINSIC_W_CHAIN: - Opcode = - pickOpcodeForVT(SelectVT, NVPTX::LDU_GLOBAL_i8, NVPTX::LDU_GLOBAL_i16, - NVPTX::LDU_GLOBAL_i32, NVPTX::LDU_GLOBAL_i64); + Opcode = pickOpcodeForVT(TargetVT, NVPTX::LDU_GLOBAL_i16, + NVPTX::LDU_GLOBAL_i32, NVPTX::LDU_GLOBAL_i64); break; case NVPTXISD::LDUV2: - Opcode = pickOpcodeForVT(SelectVT, NVPTX::LDU_GLOBAL_v2i8, - NVPTX::LDU_GLOBAL_v2i16, NVPTX::LDU_GLOBAL_v2i32, - NVPTX::LDU_GLOBAL_v2i64); + Opcode = pickOpcodeForVT(TargetVT, NVPTX::LDU_GLOBAL_v2i16, + NVPTX::LDU_GLOBAL_v2i32, NVPTX::LDU_GLOBAL_v2i64); break; case NVPTXISD::LDUV4: - Opcode = pickOpcodeForVT(SelectVT, NVPTX::LDU_GLOBAL_v4i8, - NVPTX::LDU_GLOBAL_v4i16, NVPTX::LDU_GLOBAL_v4i32, - {/* no v4i64 */}); + Opcode = pickOpcodeForVT(TargetVT, NVPTX::LDU_GLOBAL_v4i16, + NVPTX::LDU_GLOBAL_v4i32, {/* no v4i64 */}); break; } if (!Opcode) return false; - SDLoc DL(N); SDNode *NVPTXLDU = CurDAG->getMachineNode(*Opcode, DL, LD->getVTList(), Ops); ReplaceNode(LD, NVPTXLDU); @@ -1362,8 +1352,8 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) { Chain}; const std::optional<unsigned> Opcode = - pickOpcodeForVT(Value.getSimpleValueType().SimpleTy, NVPTX::ST_i8, - NVPTX::ST_i16, NVPTX::ST_i32, NVPTX::ST_i64); + pickOpcodeForVT(Value.getSimpleValueType().SimpleTy, NVPTX::ST_i16, + NVPTX::ST_i32, NVPTX::ST_i64); if (!Opcode) return false; @@ -1423,16 +1413,16 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) { default: return false; case NVPTXISD::StoreV2: - Opcode = pickOpcodeForVT(EltVT, NVPTX::STV_i8_v2, NVPTX::STV_i16_v2, - NVPTX::STV_i32_v2, NVPTX::STV_i64_v2); + Opcode = pickOpcodeForVT(EltVT, NVPTX::STV_i16_v2, NVPTX::STV_i32_v2, + NVPTX::STV_i64_v2); break; case NVPTXISD::StoreV4: - Opcode = pickOpcodeForVT(EltVT, NVPTX::STV_i8_v4, NVPTX::STV_i16_v4, - NVPTX::STV_i32_v4, NVPTX::STV_i64_v4); + Opcode = pickOpcodeForVT(EltVT, NVPTX::STV_i16_v4, NVPTX::STV_i32_v4, + NVPTX::STV_i64_v4); break; case NVPTXISD::StoreV8: - Opcode = pickOpcodeForVT(EltVT, {/* no v8i8 */}, {/* no v8i16 */}, - NVPTX::STV_i32_v8, {/* no v8i64 */}); + Opcode = pickOpcodeForVT(EltVT, {/* no v8i16 */}, NVPTX::STV_i32_v8, + {/* no v8i64 */}); break; } @@ -1687,10 +1677,11 @@ bool NVPTXDAGToDAGISel::tryBF16ArithToFMA(SDNode *N) { auto API = APF.bitcastToAPInt(); API = API.concat(API); auto Const = CurDAG->getTargetConstant(API, DL, MVT::i32); - return SDValue(CurDAG->getMachineNode(NVPTX::IMOV32i, DL, VT, Const), 0); + return SDValue(CurDAG->getMachineNode(NVPTX::MOV_B32_i, DL, VT, Const), + 0); } auto Const = CurDAG->getTargetConstantFP(APF, DL, VT); - return SDValue(CurDAG->getMachineNode(NVPTX::BFMOV16i, DL, VT, Const), 0); + return SDValue(CurDAG->getMachineNode(NVPTX::MOV_BF16_i, DL, VT, Const), 0); }; switch (N->getOpcode()) { diff --git 
a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index 4fd3623..65d1be3 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -4917,7 +4917,6 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { return SDValue(); auto *LD = cast<MemSDNode>(N); - EVT MemVT = LD->getMemoryVT(); SDLoc DL(LD); // the new opcode after we double the number of operands @@ -4958,9 +4957,9 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { NewVTs.append(LD->value_begin() + OldNumOutputs, LD->value_end()); // Create the new load - SDValue NewLoad = - DCI.DAG.getMemIntrinsicNode(Opcode, DL, DCI.DAG.getVTList(NewVTs), - Operands, MemVT, LD->getMemOperand()); + SDValue NewLoad = DCI.DAG.getMemIntrinsicNode( + Opcode, DL, DCI.DAG.getVTList(NewVTs), Operands, LD->getMemoryVT(), + LD->getMemOperand()); // Now we use a combination of BUILD_VECTORs and a MERGE_VALUES node to keep // the outputs the same. These nodes will be optimized away in later @@ -5002,7 +5001,6 @@ static SDValue combinePackingMovIntoStore(SDNode *N, return SDValue(); auto *ST = cast<MemSDNode>(N); - EVT MemVT = ElementVT.getVectorElementType(); // The new opcode after we double the number of operands. NVPTXISD::NodeType Opcode; @@ -5011,11 +5009,9 @@ static SDValue combinePackingMovIntoStore(SDNode *N, // Any packed type is legal, so the legalizer will not have lowered // ISD::STORE -> NVPTXISD::Store (unless it's under-aligned). We have to do // it here. - MemVT = ST->getMemoryVT(); Opcode = NVPTXISD::StoreV2; break; case NVPTXISD::StoreV2: - MemVT = ST->getMemoryVT(); Opcode = NVPTXISD::StoreV4; break; case NVPTXISD::StoreV4: @@ -5066,7 +5062,7 @@ static SDValue combinePackingMovIntoStore(SDNode *N, // Now we replace the store return DCI.DAG.getMemIntrinsicNode(Opcode, SDLoc(N), N->getVTList(), Operands, - MemVT, ST->getMemOperand()); + ST->getMemoryVT(), ST->getMemOperand()); } static SDValue PerformStoreCombine(SDNode *N, diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrFormats.td b/llvm/lib/Target/NVPTX/NVPTXInstrFormats.td index 86dcb4a..719be03 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrFormats.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrFormats.td @@ -11,15 +11,9 @@ // //===----------------------------------------------------------------------===// -// Vector instruction type enum -class VecInstTypeEnum<bits<4> val> { - bits<4> Value=val; -} -def VecNOP : VecInstTypeEnum<0>; - // Generic NVPTX Format -class NVPTXInst<dag outs, dag ins, string asmstr, list<dag> pattern> +class NVPTXInst<dag outs, dag ins, string asmstr, list<dag> pattern = []> : Instruction { field bits<14> Inst; @@ -30,7 +24,6 @@ class NVPTXInst<dag outs, dag ins, string asmstr, list<dag> pattern> let Pattern = pattern; // TSFlagFields - bits<4> VecInstType = VecNOP.Value; bit IsLoad = false; bit IsStore = false; @@ -45,7 +38,6 @@ class NVPTXInst<dag outs, dag ins, string asmstr, list<dag> pattern> // 2**(2-1) = 2. 
bits<2> IsSuld = 0; - let TSFlags{3...0} = VecInstType; let TSFlags{4} = IsLoad; let TSFlags{5} = IsStore; let TSFlags{6} = IsTex; diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp index e218ef1..34fe467 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp @@ -35,23 +35,23 @@ void NVPTXInstrInfo::copyPhysReg(MachineBasicBlock &MBB, const TargetRegisterClass *DestRC = MRI.getRegClass(DestReg); const TargetRegisterClass *SrcRC = MRI.getRegClass(SrcReg); - if (RegInfo.getRegSizeInBits(*DestRC) != RegInfo.getRegSizeInBits(*SrcRC)) + if (DestRC != SrcRC) report_fatal_error("Copy one register into another with a different width"); unsigned Op; - if (DestRC == &NVPTX::B1RegClass) { - Op = NVPTX::IMOV1r; - } else if (DestRC == &NVPTX::B16RegClass) { - Op = NVPTX::MOV16r; - } else if (DestRC == &NVPTX::B32RegClass) { - Op = NVPTX::IMOV32r; - } else if (DestRC == &NVPTX::B64RegClass) { - Op = NVPTX::IMOV64r; - } else if (DestRC == &NVPTX::B128RegClass) { - Op = NVPTX::IMOV128r; - } else { + if (DestRC == &NVPTX::B1RegClass) + Op = NVPTX::MOV_B1_r; + else if (DestRC == &NVPTX::B16RegClass) + Op = NVPTX::MOV_B16_r; + else if (DestRC == &NVPTX::B32RegClass) + Op = NVPTX::MOV_B32_r; + else if (DestRC == &NVPTX::B64RegClass) + Op = NVPTX::MOV_B64_r; + else if (DestRC == &NVPTX::B128RegClass) + Op = NVPTX::MOV_B128_r; + else llvm_unreachable("Bad register copy"); - } + BuildMI(MBB, I, DL, get(Op), DestReg) .addReg(SrcReg, getKillRegState(KillSrc)); } diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index 6000b40..d8047d3 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -15,19 +15,8 @@ include "NVPTXInstrFormats.td" let OperandType = "OPERAND_IMMEDIATE" in { def f16imm : Operand<f16>; def bf16imm : Operand<bf16>; - } -// List of vector specific properties -def isVecLD : VecInstTypeEnum<1>; -def isVecST : VecInstTypeEnum<2>; -def isVecBuild : VecInstTypeEnum<3>; -def isVecShuffle : VecInstTypeEnum<4>; -def isVecExtract : VecInstTypeEnum<5>; -def isVecInsert : VecInstTypeEnum<6>; -def isVecDest : VecInstTypeEnum<7>; -def isVecOther : VecInstTypeEnum<15>; - //===----------------------------------------------------------------------===// // NVPTX Operand Definitions. //===----------------------------------------------------------------------===// @@ -484,46 +473,28 @@ let hasSideEffects = false in { // takes a CvtMode immediate that defines the conversion mode to use. It can // be CvtNONE to omit a conversion mode. multiclass CVT_FROM_ALL<string ToType, RegisterClass RC, list<Predicate> Preds = []> { - def _s8 : - BasicFlagsNVPTXInst<(outs RC:$dst), - (ins B16:$src), (ins CvtMode:$mode), - "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".s8">, - Requires<Preds>; - def _u8 : - BasicFlagsNVPTXInst<(outs RC:$dst), - (ins B16:$src), (ins CvtMode:$mode), - "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".u8">, - Requires<Preds>; - def _s16 : - BasicFlagsNVPTXInst<(outs RC:$dst), - (ins B16:$src), (ins CvtMode:$mode), - "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".s16">, - Requires<Preds>; - def _u16 : - BasicFlagsNVPTXInst<(outs RC:$dst), - (ins B16:$src), (ins CvtMode:$mode), - "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".u16">, - Requires<Preds>; - def _s32 : - BasicFlagsNVPTXInst<(outs RC:$dst), - (ins B32:$src), (ins CvtMode:$mode), - "cvt${mode:base}${mode:ftz}${mode:sat}." 
# ToType # ".s32">, - Requires<Preds>; - def _u32 : - BasicFlagsNVPTXInst<(outs RC:$dst), - (ins B32:$src), (ins CvtMode:$mode), - "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".u32">, - Requires<Preds>; - def _s64 : - BasicFlagsNVPTXInst<(outs RC:$dst), - (ins B64:$src), (ins CvtMode:$mode), - "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".s64">, - Requires<Preds>; - def _u64 : - BasicFlagsNVPTXInst<(outs RC:$dst), - (ins B64:$src), (ins CvtMode:$mode), - "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".u64">, - Requires<Preds>; + foreach sign = ["s", "u"] in { + def _ # sign # "8" : + BasicFlagsNVPTXInst<(outs RC:$dst), + (ins B16:$src), (ins CvtMode:$mode), + "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # "." # sign # "8">, + Requires<Preds>; + def _ # sign # "16" : + BasicFlagsNVPTXInst<(outs RC:$dst), + (ins B16:$src), (ins CvtMode:$mode), + "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # "." # sign # "16">, + Requires<Preds>; + def _ # sign # "32" : + BasicFlagsNVPTXInst<(outs RC:$dst), + (ins B32:$src), (ins CvtMode:$mode), + "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # "." # sign # "32">, + Requires<Preds>; + def _ # sign # "64" : + BasicFlagsNVPTXInst<(outs RC:$dst), + (ins B64:$src), (ins CvtMode:$mode), + "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # "." # sign # "64">, + Requires<Preds>; + } def _f16 : BasicFlagsNVPTXInst<(outs RC:$dst), (ins B16:$src), (ins CvtMode:$mode), @@ -554,14 +525,12 @@ let hasSideEffects = false in { } // Generate cvts from all types to all types. - defm CVT_s8 : CVT_FROM_ALL<"s8", B16>; - defm CVT_u8 : CVT_FROM_ALL<"u8", B16>; - defm CVT_s16 : CVT_FROM_ALL<"s16", B16>; - defm CVT_u16 : CVT_FROM_ALL<"u16", B16>; - defm CVT_s32 : CVT_FROM_ALL<"s32", B32>; - defm CVT_u32 : CVT_FROM_ALL<"u32", B32>; - defm CVT_s64 : CVT_FROM_ALL<"s64", B64>; - defm CVT_u64 : CVT_FROM_ALL<"u64", B64>; + foreach sign = ["s", "u"] in { + defm CVT_ # sign # "8" : CVT_FROM_ALL<sign # "8", B16>; + defm CVT_ # sign # "16" : CVT_FROM_ALL<sign # "16", B16>; + defm CVT_ # sign # "32" : CVT_FROM_ALL<sign # "32", B32>; + defm CVT_ # sign # "64" : CVT_FROM_ALL<sign # "64", B64>; + } defm CVT_f16 : CVT_FROM_ALL<"f16", B16>; defm CVT_bf16 : CVT_FROM_ALL<"bf16", B16, [hasPTX<78>, hasSM<90>]>; defm CVT_f32 : CVT_FROM_ALL<"f32", B32>; @@ -569,18 +538,12 @@ let hasSideEffects = false in { // These cvts are different from those above: The source and dest registers // are of the same type. 
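For reference, the foreach over sign = ["s", "u"] above generates the same records as the eight hand-written defs it replaces. A small stand-alone C++ sketch that enumerates the record-name / mnemonic pairs produced for one destination type (the ${mode:...} modifiers are omitted, and this listing is purely illustrative; TableGen performs the real expansion at build time):

    #include <cstdio>

    int main() {
      const char *ToType = "f32";               // one example destination type
      const char *Signs[] = {"s", "u"};
      const int Widths[] = {8, 16, 32, 64};
      for (const char *Sign : Signs)
        for (int Width : Widths)
          // e.g. CVT_f32_s8  ->  cvt.f32.s8
          std::printf("CVT_%s_%s%d  ->  cvt.%s.%s%d\n",
                      ToType, Sign, Width, ToType, Sign, Width);
      return 0;
    }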
- def CVT_INREG_s16_s8 : BasicNVPTXInst<(outs B16:$dst), (ins B16:$src), - "cvt.s16.s8">; - def CVT_INREG_s32_s8 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src), - "cvt.s32.s8">; - def CVT_INREG_s32_s16 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src), - "cvt.s32.s16">; - def CVT_INREG_s64_s8 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), - "cvt.s64.s8">; - def CVT_INREG_s64_s16 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), - "cvt.s64.s16">; - def CVT_INREG_s64_s32 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), - "cvt.s64.s32">; + def CVT_INREG_s16_s8 : BasicNVPTXInst<(outs B16:$dst), (ins B16:$src), "cvt.s16.s8">; + def CVT_INREG_s32_s8 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src), "cvt.s32.s8">; + def CVT_INREG_s32_s16 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src), "cvt.s32.s16">; + def CVT_INREG_s64_s8 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), "cvt.s64.s8">; + def CVT_INREG_s64_s16 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), "cvt.s64.s16">; + def CVT_INREG_s64_s32 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), "cvt.s64.s32">; multiclass CVT_FROM_FLOAT_V2_SM80<string FromName, RegisterClass RC> { def _f32 : @@ -782,7 +745,7 @@ defm SUB : I3<"sub.s", sub, commutative = false>; def ADD16x2 : I16x2<"add.s", add>; -// in32 and int64 addition and subtraction with carry-out. +// int32 and int64 addition and subtraction with carry-out. defm ADDCC : ADD_SUB_INT_CARRY<"add.cc", addc, commutative = true>; defm SUBCC : ADD_SUB_INT_CARRY<"sub.cc", subc, commutative = false>; @@ -803,17 +766,17 @@ defm UDIV : I3<"div.u", udiv, commutative = false>; defm SREM : I3<"rem.s", srem, commutative = false>; defm UREM : I3<"rem.u", urem, commutative = false>; -// Integer absolute value. NumBits should be one minus the bit width of RC. -// This idiom implements the algorithm at -// http://graphics.stanford.edu/~seander/bithacks.html#IntegerAbs. -multiclass ABS<ValueType T, RegisterClass RC, string SizeName> { - def : BasicNVPTXInst<(outs RC:$dst), (ins RC:$a), - "abs" # SizeName, - [(set T:$dst, (abs T:$a))]>; +foreach t = [I16RT, I32RT, I64RT] in { + def ABS_S # t.Size : + BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a), + "abs.s" # t.Size, + [(set t.Ty:$dst, (abs t.Ty:$a))]>; + + def NEG_S # t.Size : + BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src), + "neg.s" # t.Size, + [(set t.Ty:$dst, (ineg t.Ty:$src))]>; } -defm ABS_16 : ABS<i16, B16, ".s16">; -defm ABS_32 : ABS<i32, B32, ".s32">; -defm ABS_64 : ABS<i64, B64, ".s64">; // Integer min/max. 
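The multiclass removed above carried a stale reference to the branch-free integer-abs idiom from the Stanford bit-hacks page; the new defs simply match the generic abs node and emit PTX abs.sN. For readers following that reference, the idiom itself can be sketched in C++ as follows; this is background only, not code from the patch:

    #include <cstdint>

    inline int32_t abs_branchless(int32_t x) {
      // Arithmetic right shift yields 0 for non-negative x and -1 for negative x
      // (guaranteed since C++20, and the behaviour of mainstream compilers before that).
      int32_t mask = x >> 31;
      // Like std::abs, this is undefined only for INT32_MIN.
      return (x ^ mask) - mask;
    }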
defm SMAX : I3<"max.s", smax, commutative = true>; @@ -830,116 +793,63 @@ def UMIN16x2 : I16x2<"min.u", umin>; // // Wide multiplication // -def MULWIDES64 : - BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, B32:$b), "mul.wide.s32">; -def MULWIDES64Imm : - BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i32imm:$b), "mul.wide.s32">; - -def MULWIDEU64 : - BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, B32:$b), "mul.wide.u32">; -def MULWIDEU64Imm : - BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i32imm:$b), "mul.wide.u32">; - -def MULWIDES32 : - BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B16:$b), "mul.wide.s16">; -def MULWIDES32Imm : - BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i16imm:$b), "mul.wide.s16">; - -def MULWIDEU32 : - BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B16:$b), "mul.wide.u16">; -def MULWIDEU32Imm : - BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i16imm:$b), "mul.wide.u16">; def SDTMulWide : SDTypeProfile<1, 2, [SDTCisInt<0>, SDTCisInt<1>, SDTCisSameAs<1, 2>]>; -def mul_wide_signed : SDNode<"NVPTXISD::MUL_WIDE_SIGNED", SDTMulWide, [SDNPCommutative]>; -def mul_wide_unsigned : SDNode<"NVPTXISD::MUL_WIDE_UNSIGNED", SDTMulWide, [SDNPCommutative]>; +def smul_wide : SDNode<"NVPTXISD::MUL_WIDE_SIGNED", SDTMulWide, [SDNPCommutative]>; +def umul_wide : SDNode<"NVPTXISD::MUL_WIDE_UNSIGNED", SDTMulWide, [SDNPCommutative]>; -// Matchers for signed, unsigned mul.wide ISD nodes. -let Predicates = [hasOptEnabled] in { - def : Pat<(i32 (mul_wide_signed i16:$a, i16:$b)), (MULWIDES32 $a, $b)>; - def : Pat<(i32 (mul_wide_signed i16:$a, imm:$b)), (MULWIDES32Imm $a, imm:$b)>; - def : Pat<(i32 (mul_wide_unsigned i16:$a, i16:$b)), (MULWIDEU32 $a, $b)>; - def : Pat<(i32 (mul_wide_unsigned i16:$a, imm:$b)), (MULWIDEU32Imm $a, imm:$b)>; - def : Pat<(i64 (mul_wide_signed i32:$a, i32:$b)), (MULWIDES64 $a, $b)>; - def : Pat<(i64 (mul_wide_signed i32:$a, imm:$b)), (MULWIDES64Imm $a, imm:$b)>; - def : Pat<(i64 (mul_wide_unsigned i32:$a, i32:$b)), (MULWIDEU64 $a, $b)>; - def : Pat<(i64 (mul_wide_unsigned i32:$a, imm:$b)), (MULWIDEU64Imm $a, imm:$b)>; +multiclass MULWIDEInst<string suffix, SDPatternOperator op, RegTyInfo big_t, RegTyInfo small_t> { + def suffix # _rr : + BasicNVPTXInst<(outs big_t.RC:$dst), (ins small_t.RC:$a, small_t.RC:$b), + "mul.wide." # suffix, + [(set big_t.Ty:$dst, (op small_t.Ty:$a, small_t.Ty:$b))]>; + def suffix # _ri : + BasicNVPTXInst<(outs big_t.RC:$dst), (ins small_t.RC:$a, small_t.Imm:$b), + "mul.wide." 
# suffix, + [(set big_t.Ty:$dst, (op small_t.Ty:$a, imm:$b))]>; } +defm MUL_WIDE : MULWIDEInst<"s32", smul_wide, I64RT, I32RT>; +defm MUL_WIDE : MULWIDEInst<"u32", umul_wide, I64RT, I32RT>; +defm MUL_WIDE : MULWIDEInst<"s16", smul_wide, I32RT, I16RT>; +defm MUL_WIDE : MULWIDEInst<"u16", umul_wide, I32RT, I16RT>; + // // Integer multiply-add // -def mul_oneuse : OneUse2<mul>; - -multiclass MAD<string Ptx, ValueType VT, NVPTXRegClass Reg, Operand Imm> { - def rrr: - BasicNVPTXInst<(outs Reg:$dst), - (ins Reg:$a, Reg:$b, Reg:$c), - Ptx, - [(set VT:$dst, (add (mul_oneuse VT:$a, VT:$b), VT:$c))]>; - - def rir: - BasicNVPTXInst<(outs Reg:$dst), - (ins Reg:$a, Imm:$b, Reg:$c), - Ptx, - [(set VT:$dst, (add (mul_oneuse VT:$a, imm:$b), VT:$c))]>; - def rri: - BasicNVPTXInst<(outs Reg:$dst), - (ins Reg:$a, Reg:$b, Imm:$c), - Ptx, - [(set VT:$dst, (add (mul_oneuse VT:$a, VT:$b), imm:$c))]>; - def rii: - BasicNVPTXInst<(outs Reg:$dst), - (ins Reg:$a, Imm:$b, Imm:$c), - Ptx, - [(set VT:$dst, (add (mul_oneuse VT:$a, imm:$b), imm:$c))]>; -} - -let Predicates = [hasOptEnabled] in { -defm MAD16 : MAD<"mad.lo.s16", i16, B16, i16imm>; -defm MAD32 : MAD<"mad.lo.s32", i32, B32, i32imm>; -defm MAD64 : MAD<"mad.lo.s64", i64, B64, i64imm>; -} - -multiclass MAD_WIDE<string PtxSuffix, OneUse2 Op, RegTyInfo BigT, RegTyInfo SmallT> { +multiclass MADInst<string suffix, SDPatternOperator op, RegTyInfo big_t, RegTyInfo small_t> { def rrr: - BasicNVPTXInst<(outs BigT.RC:$dst), - (ins SmallT.RC:$a, SmallT.RC:$b, BigT.RC:$c), - "mad.wide." # PtxSuffix, - [(set BigT.Ty:$dst, (add (Op SmallT.Ty:$a, SmallT.Ty:$b), BigT.Ty:$c))]>; + BasicNVPTXInst<(outs big_t.RC:$dst), + (ins small_t.RC:$a, small_t.RC:$b, big_t.RC:$c), + "mad." # suffix, + [(set big_t.Ty:$dst, (add (OneUse2<op> small_t.Ty:$a, small_t.Ty:$b), big_t.Ty:$c))]>; def rri: - BasicNVPTXInst<(outs BigT.RC:$dst), - (ins SmallT.RC:$a, SmallT.RC:$b, BigT.Imm:$c), - "mad.wide." # PtxSuffix, - [(set BigT.Ty:$dst, (add (Op SmallT.Ty:$a, SmallT.Ty:$b), imm:$c))]>; + BasicNVPTXInst<(outs big_t.RC:$dst), + (ins small_t.RC:$a, small_t.RC:$b, big_t.Imm:$c), + "mad." # suffix, + [(set big_t.Ty:$dst, (add (OneUse2<op> small_t.Ty:$a, small_t.Ty:$b), imm:$c))]>; def rir: - BasicNVPTXInst<(outs BigT.RC:$dst), - (ins SmallT.RC:$a, SmallT.Imm:$b, BigT.RC:$c), - "mad.wide." # PtxSuffix, - [(set BigT.Ty:$dst, (add (Op SmallT.Ty:$a, imm:$b), BigT.Ty:$c))]>; + BasicNVPTXInst<(outs big_t.RC:$dst), + (ins small_t.RC:$a, small_t.Imm:$b, big_t.RC:$c), + "mad." # suffix, + [(set big_t.Ty:$dst, (add (OneUse2<op> small_t.Ty:$a, imm:$b), big_t.Ty:$c))]>; def rii: - BasicNVPTXInst<(outs BigT.RC:$dst), - (ins SmallT.RC:$a, SmallT.Imm:$b, BigT.Imm:$c), - "mad.wide." # PtxSuffix, - [(set BigT.Ty:$dst, (add (Op SmallT.Ty:$a, imm:$b), imm:$c))]>; + BasicNVPTXInst<(outs big_t.RC:$dst), + (ins small_t.RC:$a, small_t.Imm:$b, big_t.Imm:$c), + "mad." 
# suffix, + [(set big_t.Ty:$dst, (add (OneUse2<op> small_t.Ty:$a, imm:$b), imm:$c))]>; } -def mul_wide_unsigned_oneuse : OneUse2<mul_wide_unsigned>; -def mul_wide_signed_oneuse : OneUse2<mul_wide_signed>; - let Predicates = [hasOptEnabled] in { -defm MAD_WIDE_U16 : MAD_WIDE<"u16", mul_wide_unsigned_oneuse, I32RT, I16RT>; -defm MAD_WIDE_S16 : MAD_WIDE<"s16", mul_wide_signed_oneuse, I32RT, I16RT>; -defm MAD_WIDE_U32 : MAD_WIDE<"u32", mul_wide_unsigned_oneuse, I64RT, I32RT>; -defm MAD_WIDE_S32 : MAD_WIDE<"s32", mul_wide_signed_oneuse, I64RT, I32RT>; -} + defm MAD_LO_S16 : MADInst<"lo.s16", mul, I16RT, I16RT>; + defm MAD_LO_S32 : MADInst<"lo.s32", mul, I32RT, I32RT>; + defm MAD_LO_S64 : MADInst<"lo.s64", mul, I64RT, I64RT>; -foreach t = [I16RT, I32RT, I64RT] in { - def NEG_S # t.Size : - BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src), - "neg.s" # t.Size, - [(set t.Ty:$dst, (ineg t.Ty:$src))]>; + defm MAD_WIDE_U16 : MADInst<"wide.u16", umul_wide, I32RT, I16RT>; + defm MAD_WIDE_S16 : MADInst<"wide.s16", smul_wide, I32RT, I16RT>; + defm MAD_WIDE_U32 : MADInst<"wide.u32", umul_wide, I64RT, I32RT>; + defm MAD_WIDE_S32 : MADInst<"wide.s32", smul_wide, I64RT, I32RT>; } //----------------------------------- @@ -1050,8 +960,7 @@ def fdiv_approx : PatFrag<(ops node:$a, node:$b), def FRCP32_approx_r : BasicFlagsNVPTXInst<(outs B32:$dst), - (ins B32:$b), - (ins FTZFlag:$ftz), + (ins B32:$b), (ins FTZFlag:$ftz), "rcp.approx$ftz.f32", [(set f32:$dst, (fdiv_approx f32imm_1, f32:$b))]>; @@ -1060,14 +969,12 @@ def FRCP32_approx_r : // def FDIV32_approx_rr : BasicFlagsNVPTXInst<(outs B32:$dst), - (ins B32:$a, B32:$b), - (ins FTZFlag:$ftz), + (ins B32:$a, B32:$b), (ins FTZFlag:$ftz), "div.approx$ftz.f32", [(set f32:$dst, (fdiv_approx f32:$a, f32:$b))]>; def FDIV32_approx_ri : BasicFlagsNVPTXInst<(outs B32:$dst), - (ins B32:$a, f32imm:$b), - (ins FTZFlag:$ftz), + (ins B32:$a, f32imm:$b), (ins FTZFlag:$ftz), "div.approx$ftz.f32", [(set f32:$dst, (fdiv_approx f32:$a, fpimm:$b))]>; // @@ -1090,14 +997,12 @@ def : Pat<(fdiv_full f32imm_1, f32:$b), // def FDIV32rr : BasicFlagsNVPTXInst<(outs B32:$dst), - (ins B32:$a, B32:$b), - (ins FTZFlag:$ftz), + (ins B32:$a, B32:$b), (ins FTZFlag:$ftz), "div.full$ftz.f32", [(set f32:$dst, (fdiv_full f32:$a, f32:$b))]>; def FDIV32ri : BasicFlagsNVPTXInst<(outs B32:$dst), - (ins B32:$a, f32imm:$b), - (ins FTZFlag:$ftz), + (ins B32:$a, f32imm:$b), (ins FTZFlag:$ftz), "div.full$ftz.f32", [(set f32:$dst, (fdiv_full f32:$a, fpimm:$b))]>; // @@ -1111,8 +1016,7 @@ def fdiv_ftz : PatFrag<(ops node:$a, node:$b), def FRCP32r_prec : BasicFlagsNVPTXInst<(outs B32:$dst), - (ins B32:$b), - (ins FTZFlag:$ftz), + (ins B32:$b), (ins FTZFlag:$ftz), "rcp.rn$ftz.f32", [(set f32:$dst, (fdiv_ftz f32imm_1, f32:$b))]>; // @@ -1120,14 +1024,12 @@ def FRCP32r_prec : // def FDIV32rr_prec : BasicFlagsNVPTXInst<(outs B32:$dst), - (ins B32:$a, B32:$b), - (ins FTZFlag:$ftz), + (ins B32:$a, B32:$b), (ins FTZFlag:$ftz), "div.rn$ftz.f32", [(set f32:$dst, (fdiv_ftz f32:$a, f32:$b))]>; def FDIV32ri_prec : BasicFlagsNVPTXInst<(outs B32:$dst), - (ins B32:$a, f32imm:$b), - (ins FTZFlag:$ftz), + (ins B32:$a, f32imm:$b), (ins FTZFlag:$ftz), "div.rn$ftz.f32", [(set f32:$dst, (fdiv_ftz f32:$a, fpimm:$b))]>; @@ -1206,10 +1108,8 @@ def TANH_APPROX_f32 : // Template for three-arg bitwise operations. Takes three args, Creates .b16, // .b32, .b64, and .pred (predicate registers -- i.e., i1) versions of OpcStr. 
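For context on the mul.wide/mad.wide patterns above: mul.wide multiplies two N-bit operands into a 2N-bit result, and mad.wide additionally accumulates a 2N-bit addend. A scalar C++ sketch of the signed 32-bit case these patterns correspond to (whether a given IR multiply is actually turned into NVPTXISD::MUL_WIDE_SIGNED is decided elsewhere in ISel):

    #include <cstdint>

    inline int64_t mul_wide_s32(int32_t a, int32_t b) {
      // Widen first, so the full 64-bit product is computed without 32-bit overflow.
      return static_cast<int64_t>(a) * static_cast<int64_t>(b);
    }

    inline int64_t mad_wide_s32(int32_t a, int32_t b, int64_t c) {
      // Widening multiply followed by a 64-bit add, the shape the wide MAD
      // rrr pattern is written to match.
      return mul_wide_s32(a, b) + c;
    }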
multiclass BITWISE<string OpcStr, SDNode OpNode> { - defm b1 : I3Inst<OpcStr # ".pred", OpNode, I1RT, commutative = true>; - defm b16 : I3Inst<OpcStr # ".b16", OpNode, I16RT, commutative = true>; - defm b32 : I3Inst<OpcStr # ".b32", OpNode, I32RT, commutative = true>; - defm b64 : I3Inst<OpcStr # ".b64", OpNode, I64RT, commutative = true>; + foreach t = [I1RT, I16RT, I32RT, I64RT] in + defm _ # t.PtxType : I3Inst<OpcStr # "." # t.PtxType, OpNode, t, commutative = true>; } defm OR : BITWISE<"or", or>; @@ -1217,48 +1117,40 @@ defm AND : BITWISE<"and", and>; defm XOR : BITWISE<"xor", xor>; // PTX does not support mul on predicates, convert to and instructions -def : Pat<(mul i1:$a, i1:$b), (ANDb1rr $a, $b)>; -def : Pat<(mul i1:$a, imm:$b), (ANDb1ri $a, imm:$b)>; +def : Pat<(mul i1:$a, i1:$b), (AND_predrr $a, $b)>; +def : Pat<(mul i1:$a, imm:$b), (AND_predri $a, imm:$b)>; foreach op = [add, sub] in { - def : Pat<(op i1:$a, i1:$b), (XORb1rr $a, $b)>; - def : Pat<(op i1:$a, imm:$b), (XORb1ri $a, imm:$b)>; + def : Pat<(op i1:$a, i1:$b), (XOR_predrr $a, $b)>; + def : Pat<(op i1:$a, imm:$b), (XOR_predri $a, imm:$b)>; } // These transformations were once reliably performed by instcombine, but thanks // to poison semantics they are no longer safe for LLVM IR, perform them here // instead. -def : Pat<(select i1:$a, i1:$b, 0), (ANDb1rr $a, $b)>; -def : Pat<(select i1:$a, 1, i1:$b), (ORb1rr $a, $b)>; +def : Pat<(select i1:$a, i1:$b, 0), (AND_predrr $a, $b)>; +def : Pat<(select i1:$a, 1, i1:$b), (OR_predrr $a, $b)>; // Lower logical v2i16/v4i8 ops as bitwise ops on b32. foreach vt = [v2i16, v4i8] in { - def : Pat<(or vt:$a, vt:$b), (ORb32rr $a, $b)>; - def : Pat<(xor vt:$a, vt:$b), (XORb32rr $a, $b)>; - def : Pat<(and vt:$a, vt:$b), (ANDb32rr $a, $b)>; + def : Pat<(or vt:$a, vt:$b), (OR_b32rr $a, $b)>; + def : Pat<(xor vt:$a, vt:$b), (XOR_b32rr $a, $b)>; + def : Pat<(and vt:$a, vt:$b), (AND_b32rr $a, $b)>; // The constants get legalized into a bitcast from i32, so that's what we need // to match here. def: Pat<(or vt:$a, (vt (bitconvert (i32 imm:$b)))), - (ORb32ri $a, imm:$b)>; + (OR_b32ri $a, imm:$b)>; def: Pat<(xor vt:$a, (vt (bitconvert (i32 imm:$b)))), - (XORb32ri $a, imm:$b)>; + (XOR_b32ri $a, imm:$b)>; def: Pat<(and vt:$a, (vt (bitconvert (i32 imm:$b)))), - (ANDb32ri $a, imm:$b)>; -} - -def NOT1 : BasicNVPTXInst<(outs B1:$dst), (ins B1:$src), - "not.pred", - [(set i1:$dst, (not i1:$src))]>; -def NOT16 : BasicNVPTXInst<(outs B16:$dst), (ins B16:$src), - "not.b16", - [(set i16:$dst, (not i16:$src))]>; -def NOT32 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src), - "not.b32", - [(set i32:$dst, (not i32:$src))]>; -def NOT64 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), - "not.b64", - [(set i64:$dst, (not i64:$src))]>; + (AND_b32ri $a, imm:$b)>; +} + +foreach t = [I1RT, I16RT, I32RT, I64RT] in + def NOT_ # t.PtxType : BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src), + "not." # t.PtxType, + [(set t.Ty:$dst, (not t.Ty:$src))]>; // Template for left/right shifts. Takes three operands, // [dest (reg), src (reg), shift (reg or imm)]. @@ -1266,34 +1158,22 @@ def NOT64 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), // // This template also defines a 32-bit shift (imm, imm) instruction. 
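The i1 patterns above rest on standard Boolean identities: for one-bit values, a * b equals a & b, a + b and a - b equal a ^ b modulo 2, select(a, b, 0) equals a & b, and select(a, 1, b) equals a | b. A self-contained exhaustive check of those identities, purely illustrative:

    #include <cassert>

    int main() {
      for (int a = 0; a <= 1; ++a)
        for (int b = 0; b <= 1; ++b) {
          assert(((a * b) & 1) == (a & b)); // i1 mul is selected as and.pred
          assert(((a + b) & 1) == (a ^ b)); // i1 add is selected as xor.pred
          assert(((a - b) & 1) == (a ^ b)); // i1 sub is selected as xor.pred
          assert((a ? b : 0) == (a & b));   // select(a, b, 0) -> and.pred
          assert((a ? 1 : b) == (a | b));   // select(a, 1, b) -> or.pred
        }
      return 0;
    }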
multiclass SHIFT<string OpcStr, SDNode OpNode> { - def i64rr : - BasicNVPTXInst<(outs B64:$dst), (ins B64:$a, B32:$b), - OpcStr # "64", - [(set i64:$dst, (OpNode i64:$a, i32:$b))]>; - def i64ri : - BasicNVPTXInst<(outs B64:$dst), (ins B64:$a, i32imm:$b), - OpcStr # "64", - [(set i64:$dst, (OpNode i64:$a, (i32 imm:$b)))]>; - def i32rr : - BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, B32:$b), - OpcStr # "32", - [(set i32:$dst, (OpNode i32:$a, i32:$b))]>; - def i32ri : - BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, i32imm:$b), - OpcStr # "32", - [(set i32:$dst, (OpNode i32:$a, (i32 imm:$b)))]>; - def i32ii : - BasicNVPTXInst<(outs B32:$dst), (ins i32imm:$a, i32imm:$b), - OpcStr # "32", - [(set i32:$dst, (OpNode (i32 imm:$a), (i32 imm:$b)))]>; - def i16rr : - BasicNVPTXInst<(outs B16:$dst), (ins B16:$a, B32:$b), - OpcStr # "16", - [(set i16:$dst, (OpNode i16:$a, i32:$b))]>; - def i16ri : - BasicNVPTXInst<(outs B16:$dst), (ins B16:$a, i32imm:$b), - OpcStr # "16", - [(set i16:$dst, (OpNode i16:$a, (i32 imm:$b)))]>; + let hasSideEffects = false in { + foreach t = [I64RT, I32RT, I16RT] in { + def t.Size # _rr : + BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, B32:$b), + OpcStr # t.Size, + [(set t.Ty:$dst, (OpNode t.Ty:$a, i32:$b))]>; + def t.Size # _ri : + BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, i32imm:$b), + OpcStr # t.Size, + [(set t.Ty:$dst, (OpNode t.Ty:$a, (i32 imm:$b)))]>; + def t.Size # _ii : + BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, i32imm:$b), + OpcStr # t.Size, + [(set t.Ty:$dst, (OpNode (t.Ty imm:$a), (i32 imm:$b)))]>; + } + } } defm SHL : SHIFT<"shl.b", shl>; @@ -1301,14 +1181,11 @@ defm SRA : SHIFT<"shr.s", sra>; defm SRL : SHIFT<"shr.u", srl>; // Bit-reverse -def BREV32 : - BasicNVPTXInst<(outs B32:$dst), (ins B32:$a), - "brev.b32", - [(set i32:$dst, (bitreverse i32:$a))]>; -def BREV64 : - BasicNVPTXInst<(outs B64:$dst), (ins B64:$a), - "brev.b64", - [(set i64:$dst, (bitreverse i64:$a))]>; +foreach t = [I64RT, I32RT] in + def BREV_ # t.PtxType : + BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a), + "brev." # t.PtxType, + [(set t.Ty:$dst, (bitreverse t.Ty:$a))]>; // @@ -1562,10 +1439,7 @@ def SETP_bf16x2rr : def addr : ComplexPattern<pAny, 2, "SelectADDR">; -def ADDR_base : Operand<pAny> { - let PrintMethod = "printOperand"; -} - +def ADDR_base : Operand<pAny>; def ADDR : Operand<pAny> { let PrintMethod = "printMemOperand"; let MIOperandInfo = (ops ADDR_base, i32imm); @@ -1579,10 +1453,6 @@ def MmaCode : Operand<i32> { let PrintMethod = "printMmaCode"; } -def Offseti32imm : Operand<i32> { - let PrintMethod = "printOffseti32imm"; -} - // Get pointer to local stack. let hasSideEffects = false in { def MOV_DEPOT_ADDR : NVPTXInst<(outs B32:$d), (ins i32imm:$num), @@ -1594,33 +1464,31 @@ let hasSideEffects = false in { // copyPhysreg is hard-coded in NVPTXInstrInfo.cpp let hasSideEffects = false, isAsCheapAsAMove = true in { - // Class for register-to-register moves - class MOVr<RegisterClass RC, string OpStr> : - BasicNVPTXInst<(outs RC:$dst), (ins RC:$src), - "mov." # OpStr>; - - // Class for immediate-to-register moves - class MOVi<RegisterClass RC, string OpStr, ValueType VT, Operand IMMType, SDNode ImmNode> : - BasicNVPTXInst<(outs RC:$dst), (ins IMMType:$src), - "mov." # OpStr, - [(set VT:$dst, ImmNode:$src)]>; -} + let isMoveReg = true in + class MOVr<RegisterClass RC, string OpStr> : + BasicNVPTXInst<(outs RC:$dst), (ins RC:$src), "mov." 
# OpStr>; -def IMOV1r : MOVr<B1, "pred">; -def MOV16r : MOVr<B16, "b16">; -def IMOV32r : MOVr<B32, "b32">; -def IMOV64r : MOVr<B64, "b64">; -def IMOV128r : MOVr<B128, "b128">; + let isMoveImm = true in + class MOVi<RegTyInfo t, string suffix> : + BasicNVPTXInst<(outs t.RC:$dst), (ins t.Imm:$src), + "mov." # suffix, + [(set t.Ty:$dst, t.ImmNode:$src)]>; +} +def MOV_B1_r : MOVr<B1, "pred">; +def MOV_B16_r : MOVr<B16, "b16">; +def MOV_B32_r : MOVr<B32, "b32">; +def MOV_B64_r : MOVr<B64, "b64">; +def MOV_B128_r : MOVr<B128, "b128">; -def IMOV1i : MOVi<B1, "pred", i1, i1imm, imm>; -def IMOV16i : MOVi<B16, "b16", i16, i16imm, imm>; -def IMOV32i : MOVi<B32, "b32", i32, i32imm, imm>; -def IMOV64i : MOVi<B64, "b64", i64, i64imm, imm>; -def FMOV16i : MOVi<B16, "b16", f16, f16imm, fpimm>; -def BFMOV16i : MOVi<B16, "b16", bf16, bf16imm, fpimm>; -def FMOV32i : MOVi<B32, "b32", f32, f32imm, fpimm>; -def FMOV64i : MOVi<B64, "b64", f64, f64imm, fpimm>; +def MOV_B1_i : MOVi<I1RT, "pred">; +def MOV_B16_i : MOVi<I16RT, "b16">; +def MOV_B32_i : MOVi<I32RT, "b32">; +def MOV_B64_i : MOVi<I64RT, "b64">; +def MOV_F16_i : MOVi<F16RT, "b16">; +def MOV_BF16_i : MOVi<BF16RT, "b16">; +def MOV_F32_i : MOVi<F32RT, "b32">; +def MOV_F64_i : MOVi<F64RT, "b64">; def to_tglobaladdr : SDNodeXForm<globaladdr, [{ @@ -1638,11 +1506,11 @@ def to_tframeindex : SDNodeXForm<frameindex, [{ return CurDAG->getTargetFrameIndex(N->getIndex(), N->getValueType(0)); }]>; -def : Pat<(i32 globaladdr:$dst), (IMOV32i (to_tglobaladdr $dst))>; -def : Pat<(i64 globaladdr:$dst), (IMOV64i (to_tglobaladdr $dst))>; +def : Pat<(i32 globaladdr:$dst), (MOV_B32_i (to_tglobaladdr $dst))>; +def : Pat<(i64 globaladdr:$dst), (MOV_B64_i (to_tglobaladdr $dst))>; -def : Pat<(i32 externalsym:$dst), (IMOV32i (to_texternsym $dst))>; -def : Pat<(i64 externalsym:$dst), (IMOV64i (to_texternsym $dst))>; +def : Pat<(i32 externalsym:$dst), (MOV_B32_i (to_texternsym $dst))>; +def : Pat<(i64 externalsym:$dst), (MOV_B64_i (to_texternsym $dst))>; //---- Copy Frame Index ---- def LEA_ADDRi : NVPTXInst<(outs B32:$dst), (ins ADDR:$addr), @@ -1831,7 +1699,6 @@ class LD<NVPTXRegClass regclass> "\t$dst, [$addr];", []>; let mayLoad=1, hasSideEffects=0 in { - def LD_i8 : LD<B16>; def LD_i16 : LD<B16>; def LD_i32 : LD<B32>; def LD_i64 : LD<B64>; @@ -1847,7 +1714,6 @@ class ST<DAGOperand O> " \t[$addr], $src;", []>; let mayStore=1, hasSideEffects=0 in { - def ST_i8 : ST<RI16>; def ST_i16 : ST<RI16>; def ST_i32 : ST<RI32>; def ST_i64 : ST<RI64>; @@ -1880,7 +1746,6 @@ multiclass LD_VEC<NVPTXRegClass regclass, bit support_v8 = false> { "[$addr];", []>; } let mayLoad=1, hasSideEffects=0 in { - defm LDV_i8 : LD_VEC<B16>; defm LDV_i16 : LD_VEC<B16>; defm LDV_i32 : LD_VEC<B32, support_v8 = true>; defm LDV_i64 : LD_VEC<B64>; @@ -1914,7 +1779,6 @@ multiclass ST_VEC<DAGOperand O, bit support_v8 = false> { } let mayStore=1, hasSideEffects=0 in { - defm STV_i8 : ST_VEC<RI16>; defm STV_i16 : ST_VEC<RI16>; defm STV_i32 : ST_VEC<RI32, support_v8 = true>; defm STV_i64 : ST_VEC<RI64>; @@ -2084,14 +1948,14 @@ def : Pat<(i64 (anyext i32:$a)), (CVT_u64_u32 $a, CvtNONE)>; // truncate i64 def : Pat<(i32 (trunc i64:$a)), (CVT_u32_u64 $a, CvtNONE)>; def : Pat<(i16 (trunc i64:$a)), (CVT_u16_u64 $a, CvtNONE)>; -def : Pat<(i1 (trunc i64:$a)), (SETP_i64ri (ANDb64ri $a, 1), 0, CmpNE)>; +def : Pat<(i1 (trunc i64:$a)), (SETP_i64ri (AND_b64ri $a, 1), 0, CmpNE)>; // truncate i32 def : Pat<(i16 (trunc i32:$a)), (CVT_u16_u32 $a, CvtNONE)>; -def : Pat<(i1 (trunc i32:$a)), (SETP_i32ri (ANDb32ri $a, 1), 0, CmpNE)>; +def : 
Pat<(i1 (trunc i32:$a)), (SETP_i32ri (AND_b32ri $a, 1), 0, CmpNE)>; // truncate i16 -def : Pat<(i1 (trunc i16:$a)), (SETP_i16ri (ANDb16ri $a, 1), 0, CmpNE)>; +def : Pat<(i1 (trunc i16:$a)), (SETP_i16ri (AND_b16ri $a, 1), 0, CmpNE)>; // sext_inreg def : Pat<(sext_inreg i16:$a, i8), (CVT_INREG_s16_s8 $a)>; @@ -2335,32 +2199,20 @@ defm : CVT_ROUND<frint, CvtRNI, CvtRNI_FTZ>; //----------------------------------- let isTerminator=1 in { - let isReturn=1, isBarrier=1 in + let isReturn=1, isBarrier=1 in def Return : BasicNVPTXInst<(outs), (ins), "ret", [(retglue)]>; - let isBranch=1 in - def CBranch : NVPTXInst<(outs), (ins B1:$a, brtarget:$target), + let isBranch=1 in { + def CBranch : NVPTXInst<(outs), (ins B1:$a, brtarget:$target), "@$a bra \t$target;", [(brcond i1:$a, bb:$target)]>; - let isBranch=1 in - def CBranchOther : NVPTXInst<(outs), (ins B1:$a, brtarget:$target), - "@!$a bra \t$target;", []>; - let isBranch=1, isBarrier=1 in + let isBarrier=1 in def GOTO : BasicNVPTXInst<(outs), (ins brtarget:$target), - "bra.uni", [(br bb:$target)]>; + "bra.uni", [(br bb:$target)]>; + } } -def : Pat<(brcond i32:$a, bb:$target), - (CBranch (SETP_i32ri $a, 0, CmpNE), bb:$target)>; - -// SelectionDAGBuilder::visitSWitchCase() will invert the condition of a -// conditional branch if the target block is the next block so that the code -// can fall through to the target block. The inversion is done by 'xor -// condition, 1', which will be translated to (setne condition, -1). Since ptx -// supports '@!pred bra target', we should use it. -def : Pat<(brcond (i1 (setne i1:$a, -1)), bb:$target), - (CBranchOther $a, bb:$target)>; // trap instruction def trapinst : BasicNVPTXInst<(outs), (ins), "trap", [(trap)]>, Requires<[noPTXASUnreachableBug]>; diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index 0a00220..d337192 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -243,63 +243,82 @@ foreach sync = [false, true] in { } // vote.{all,any,uni,ballot} -multiclass VOTE<NVPTXRegClass regclass, string mode, Intrinsic IntOp> { - def : BasicNVPTXInst<(outs regclass:$dest), (ins B1:$pred), - "vote." # mode, - [(set regclass:$dest, (IntOp i1:$pred))]>, - Requires<[hasPTX<60>, hasSM<30>]>; -} +let Predicates = [hasPTX<60>, hasSM<30>] in { + multiclass VOTE<string mode, RegTyInfo t, Intrinsic op> { + def : BasicNVPTXInst<(outs t.RC:$dest), (ins B1:$pred), + "vote." # mode # "." # t.PtxType, + [(set t.Ty:$dest, (op i1:$pred))]>; + } -defm VOTE_ALL : VOTE<B1, "all.pred", int_nvvm_vote_all>; -defm VOTE_ANY : VOTE<B1, "any.pred", int_nvvm_vote_any>; -defm VOTE_UNI : VOTE<B1, "uni.pred", int_nvvm_vote_uni>; -defm VOTE_BALLOT : VOTE<B32, "ballot.b32", int_nvvm_vote_ballot>; + defm VOTE_ALL : VOTE<"all", I1RT, int_nvvm_vote_all>; + defm VOTE_ANY : VOTE<"any", I1RT, int_nvvm_vote_any>; + defm VOTE_UNI : VOTE<"uni", I1RT, int_nvvm_vote_uni>; + defm VOTE_BALLOT : VOTE<"ballot", I32RT, int_nvvm_vote_ballot>; + + // vote.sync.{all,any,uni,ballot} + multiclass VOTE_SYNC<string mode, RegTyInfo t, Intrinsic op> { + def i : BasicNVPTXInst<(outs t.RC:$dest), (ins B1:$pred, i32imm:$mask), + "vote.sync." # mode # "." # t.PtxType, + [(set t.Ty:$dest, (op imm:$mask, i1:$pred))]>; + def r : BasicNVPTXInst<(outs t.RC:$dest), (ins B1:$pred, B32:$mask), + "vote.sync." # mode # "." 
# t.PtxType, + [(set t.Ty:$dest, (op i32:$mask, i1:$pred))]>; + } -// vote.sync.{all,any,uni,ballot} -multiclass VOTE_SYNC<NVPTXRegClass regclass, string mode, Intrinsic IntOp> { - def i : BasicNVPTXInst<(outs regclass:$dest), (ins B1:$pred, i32imm:$mask), - "vote.sync." # mode, - [(set regclass:$dest, (IntOp imm:$mask, i1:$pred))]>, - Requires<[hasPTX<60>, hasSM<30>]>; - def r : BasicNVPTXInst<(outs regclass:$dest), (ins B1:$pred, B32:$mask), - "vote.sync." # mode, - [(set regclass:$dest, (IntOp i32:$mask, i1:$pred))]>, - Requires<[hasPTX<60>, hasSM<30>]>; + defm VOTE_SYNC_ALL : VOTE_SYNC<"all", I1RT, int_nvvm_vote_all_sync>; + defm VOTE_SYNC_ANY : VOTE_SYNC<"any", I1RT, int_nvvm_vote_any_sync>; + defm VOTE_SYNC_UNI : VOTE_SYNC<"uni", I1RT, int_nvvm_vote_uni_sync>; + defm VOTE_SYNC_BALLOT : VOTE_SYNC<"ballot", I32RT, int_nvvm_vote_ballot_sync>; } - -defm VOTE_SYNC_ALL : VOTE_SYNC<B1, "all.pred", int_nvvm_vote_all_sync>; -defm VOTE_SYNC_ANY : VOTE_SYNC<B1, "any.pred", int_nvvm_vote_any_sync>; -defm VOTE_SYNC_UNI : VOTE_SYNC<B1, "uni.pred", int_nvvm_vote_uni_sync>; -defm VOTE_SYNC_BALLOT : VOTE_SYNC<B32, "ballot.b32", int_nvvm_vote_ballot_sync>; - // elect.sync +let Predicates = [hasPTX<80>, hasSM<90>] in { def INT_ELECT_SYNC_I : BasicNVPTXInst<(outs B32:$dest, B1:$pred), (ins i32imm:$mask), "elect.sync", - [(set i32:$dest, i1:$pred, (int_nvvm_elect_sync imm:$mask))]>, - Requires<[hasPTX<80>, hasSM<90>]>; + [(set i32:$dest, i1:$pred, (int_nvvm_elect_sync imm:$mask))]>; def INT_ELECT_SYNC_R : BasicNVPTXInst<(outs B32:$dest, B1:$pred), (ins B32:$mask), "elect.sync", - [(set i32:$dest, i1:$pred, (int_nvvm_elect_sync i32:$mask))]>, - Requires<[hasPTX<80>, hasSM<90>]>; + [(set i32:$dest, i1:$pred, (int_nvvm_elect_sync i32:$mask))]>; +} + +let Predicates = [hasPTX<60>, hasSM<70>] in { + multiclass MATCH_ANY_SYNC<Intrinsic op, RegTyInfo t> { + def ii : BasicNVPTXInst<(outs B32:$dest), (ins t.Imm:$value, i32imm:$mask), + "match.any.sync." # t.PtxType, + [(set i32:$dest, (op imm:$mask, imm:$value))]>; + def ir : BasicNVPTXInst<(outs B32:$dest), (ins t.Imm:$value, B32:$mask), + "match.any.sync." # t.PtxType, + [(set i32:$dest, (op i32:$mask, imm:$value))]>; + def ri : BasicNVPTXInst<(outs B32:$dest), (ins t.RC:$value, i32imm:$mask), + "match.any.sync." # t.PtxType, + [(set i32:$dest, (op imm:$mask, t.Ty:$value))]>; + def rr : BasicNVPTXInst<(outs B32:$dest), (ins t.RC:$value, B32:$mask), + "match.any.sync." # t.PtxType, + [(set i32:$dest, (op i32:$mask, t.Ty:$value))]>; + } -multiclass MATCH_ANY_SYNC<NVPTXRegClass regclass, string ptxtype, Intrinsic IntOp, - Operand ImmOp> { - def ii : BasicNVPTXInst<(outs B32:$dest), (ins ImmOp:$value, i32imm:$mask), - "match.any.sync." # ptxtype, - [(set i32:$dest, (IntOp imm:$mask, imm:$value))]>, - Requires<[hasPTX<60>, hasSM<70>]>; - def ir : BasicNVPTXInst<(outs B32:$dest), (ins ImmOp:$value, B32:$mask), - "match.any.sync." # ptxtype, - [(set i32:$dest, (IntOp i32:$mask, imm:$value))]>, - Requires<[hasPTX<60>, hasSM<70>]>; - def ri : BasicNVPTXInst<(outs B32:$dest), (ins regclass:$value, i32imm:$mask), - "match.any.sync." # ptxtype, - [(set i32:$dest, (IntOp imm:$mask, regclass:$value))]>, - Requires<[hasPTX<60>, hasSM<70>]>; - def rr : BasicNVPTXInst<(outs B32:$dest), (ins regclass:$value, B32:$mask), - "match.any.sync." 
# ptxtype, - [(set i32:$dest, (IntOp i32:$mask, regclass:$value))]>, - Requires<[hasPTX<60>, hasSM<70>]>; + defm MATCH_ANY_SYNC_32 : MATCH_ANY_SYNC<int_nvvm_match_any_sync_i32, I32RT>; + defm MATCH_ANY_SYNC_64 : MATCH_ANY_SYNC<int_nvvm_match_any_sync_i64, I64RT>; + + multiclass MATCH_ALLP_SYNC<RegTyInfo t, Intrinsic op> { + def ii : BasicNVPTXInst<(outs B32:$dest, B1:$pred), + (ins t.Imm:$value, i32imm:$mask), + "match.all.sync." # t.PtxType, + [(set i32:$dest, i1:$pred, (op imm:$mask, imm:$value))]>; + def ir : BasicNVPTXInst<(outs B32:$dest, B1:$pred), + (ins t.Imm:$value, B32:$mask), + "match.all.sync." # t.PtxType, + [(set i32:$dest, i1:$pred, (op i32:$mask, imm:$value))]>; + def ri : BasicNVPTXInst<(outs B32:$dest, B1:$pred), + (ins t.RC:$value, i32imm:$mask), + "match.all.sync." # t.PtxType, + [(set i32:$dest, i1:$pred, (op imm:$mask, t.Ty:$value))]>; + def rr : BasicNVPTXInst<(outs B32:$dest, B1:$pred), + (ins t.RC:$value, B32:$mask), + "match.all.sync." # t.PtxType, + [(set i32:$dest, i1:$pred, (op i32:$mask, t.Ty:$value))]>; + } + defm MATCH_ALLP_SYNC_32 : MATCH_ALLP_SYNC<I32RT, int_nvvm_match_all_sync_i32p>; + defm MATCH_ALLP_SYNC_64 : MATCH_ALLP_SYNC<I64RT, int_nvvm_match_all_sync_i64p>; } // activemask.b32 @@ -308,39 +327,6 @@ def ACTIVEMASK : BasicNVPTXInst<(outs B32:$dest), (ins), [(set i32:$dest, (int_nvvm_activemask))]>, Requires<[hasPTX<62>, hasSM<30>]>; -defm MATCH_ANY_SYNC_32 : MATCH_ANY_SYNC<B32, "b32", int_nvvm_match_any_sync_i32, - i32imm>; -defm MATCH_ANY_SYNC_64 : MATCH_ANY_SYNC<B64, "b64", int_nvvm_match_any_sync_i64, - i64imm>; - -multiclass MATCH_ALLP_SYNC<NVPTXRegClass regclass, string ptxtype, Intrinsic IntOp, - Operand ImmOp> { - def ii : BasicNVPTXInst<(outs B32:$dest, B1:$pred), - (ins ImmOp:$value, i32imm:$mask), - "match.all.sync." # ptxtype, - [(set i32:$dest, i1:$pred, (IntOp imm:$mask, imm:$value))]>, - Requires<[hasPTX<60>, hasSM<70>]>; - def ir : BasicNVPTXInst<(outs B32:$dest, B1:$pred), - (ins ImmOp:$value, B32:$mask), - "match.all.sync." # ptxtype, - [(set i32:$dest, i1:$pred, (IntOp i32:$mask, imm:$value))]>, - Requires<[hasPTX<60>, hasSM<70>]>; - def ri : BasicNVPTXInst<(outs B32:$dest, B1:$pred), - (ins regclass:$value, i32imm:$mask), - "match.all.sync." # ptxtype, - [(set i32:$dest, i1:$pred, (IntOp imm:$mask, regclass:$value))]>, - Requires<[hasPTX<60>, hasSM<70>]>; - def rr : BasicNVPTXInst<(outs B32:$dest, B1:$pred), - (ins regclass:$value, B32:$mask), - "match.all.sync." # ptxtype, - [(set i32:$dest, i1:$pred, (IntOp i32:$mask, regclass:$value))]>, - Requires<[hasPTX<60>, hasSM<70>]>; -} -defm MATCH_ALLP_SYNC_32 : MATCH_ALLP_SYNC<B32, "b32", int_nvvm_match_all_sync_i32p, - i32imm>; -defm MATCH_ALLP_SYNC_64 : MATCH_ALLP_SYNC<B64, "b64", int_nvvm_match_all_sync_i64p, - i64imm>; - multiclass REDUX_SYNC<string BinOp, string PTXType, Intrinsic Intrin> { def : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src, B32:$mask), "redux.sync." # BinOp # "." 
# PTXType, @@ -381,24 +367,20 @@ defm REDUX_SYNC_FMAX_ABS_NAN: REDUX_SYNC_F<"max", ".abs", ".NaN">; //----------------------------------- // Explicit Memory Fence Functions //----------------------------------- -class MEMBAR<string StrOp, Intrinsic IntOP> : - BasicNVPTXInst<(outs), (ins), - StrOp, [(IntOP)]>; +class NullaryInst<string StrOp, Intrinsic IntOP> : + BasicNVPTXInst<(outs), (ins), StrOp, [(IntOP)]>; -def INT_MEMBAR_CTA : MEMBAR<"membar.cta", int_nvvm_membar_cta>; -def INT_MEMBAR_GL : MEMBAR<"membar.gl", int_nvvm_membar_gl>; -def INT_MEMBAR_SYS : MEMBAR<"membar.sys", int_nvvm_membar_sys>; +def INT_MEMBAR_CTA : NullaryInst<"membar.cta", int_nvvm_membar_cta>; +def INT_MEMBAR_GL : NullaryInst<"membar.gl", int_nvvm_membar_gl>; +def INT_MEMBAR_SYS : NullaryInst<"membar.sys", int_nvvm_membar_sys>; def INT_FENCE_SC_CLUSTER: - MEMBAR<"fence.sc.cluster", int_nvvm_fence_sc_cluster>, + NullaryInst<"fence.sc.cluster", int_nvvm_fence_sc_cluster>, Requires<[hasPTX<78>, hasSM<90>]>; // Proxy fence (uni-directional) -// fence.proxy.tensormap.release variants - class FENCE_PROXY_TENSORMAP_GENERIC_RELEASE<string Scope, Intrinsic Intr> : - BasicNVPTXInst<(outs), (ins), - "fence.proxy.tensormap::generic.release." # Scope, [(Intr)]>, + NullaryInst<"fence.proxy.tensormap::generic.release." # Scope, Intr>, Requires<[hasPTX<83>, hasSM<90>]>; def INT_FENCE_PROXY_TENSORMAP_GENERIC_RELEASE_CTA: @@ -488,35 +470,31 @@ defm CP_ASYNC_CG_SHARED_GLOBAL_16 : CP_ASYNC_SHARED_GLOBAL_I<"cg", "16", int_nvvm_cp_async_cg_shared_global_16, int_nvvm_cp_async_cg_shared_global_16_s>; -def CP_ASYNC_COMMIT_GROUP : - BasicNVPTXInst<(outs), (ins), "cp.async.commit_group", [(int_nvvm_cp_async_commit_group)]>, - Requires<[hasPTX<70>, hasSM<80>]>; +let Predicates = [hasPTX<70>, hasSM<80>] in { + def CP_ASYNC_COMMIT_GROUP : + NullaryInst<"cp.async.commit_group", int_nvvm_cp_async_commit_group>; -def CP_ASYNC_WAIT_GROUP : - BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.wait_group", - [(int_nvvm_cp_async_wait_group timm:$n)]>, - Requires<[hasPTX<70>, hasSM<80>]>; + def CP_ASYNC_WAIT_GROUP : + BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.wait_group", + [(int_nvvm_cp_async_wait_group timm:$n)]>; -def CP_ASYNC_WAIT_ALL : - BasicNVPTXInst<(outs), (ins), "cp.async.wait_all", - [(int_nvvm_cp_async_wait_all)]>, - Requires<[hasPTX<70>, hasSM<80>]>; + def CP_ASYNC_WAIT_ALL : + NullaryInst<"cp.async.wait_all", int_nvvm_cp_async_wait_all>; +} -// cp.async.bulk variants of the commit/wait group -def CP_ASYNC_BULK_COMMIT_GROUP : - BasicNVPTXInst<(outs), (ins), "cp.async.bulk.commit_group", - [(int_nvvm_cp_async_bulk_commit_group)]>, - Requires<[hasPTX<80>, hasSM<90>]>; +let Predicates = [hasPTX<80>, hasSM<90>] in { + // cp.async.bulk variants of the commit/wait group + def CP_ASYNC_BULK_COMMIT_GROUP : + NullaryInst<"cp.async.bulk.commit_group", int_nvvm_cp_async_bulk_commit_group>; -def CP_ASYNC_BULK_WAIT_GROUP : - BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.bulk.wait_group", - [(int_nvvm_cp_async_bulk_wait_group timm:$n)]>, - Requires<[hasPTX<80>, hasSM<90>]>; + def CP_ASYNC_BULK_WAIT_GROUP : + BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.bulk.wait_group", + [(int_nvvm_cp_async_bulk_wait_group timm:$n)]>; -def CP_ASYNC_BULK_WAIT_GROUP_READ : - BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.bulk.wait_group.read", - [(int_nvvm_cp_async_bulk_wait_group_read timm:$n)]>, - Requires<[hasPTX<80>, hasSM<90>]>; + def CP_ASYNC_BULK_WAIT_GROUP_READ : + BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.bulk.wait_group.read", + 
[(int_nvvm_cp_async_bulk_wait_group_read timm:$n)]>; +} //------------------------------ // TMA Async Bulk Copy Functions @@ -974,33 +952,30 @@ defm TMA_TENSOR_PF_TILE_GATHER4_2D : TMA_TENSOR_PREFETCH_INTR<5, "tile_gather4", //Prefetch and Prefetchu -class PREFETCH_INTRS<string InstName> : - BasicNVPTXInst<(outs), (ins ADDR:$addr), - InstName, - [(!cast<Intrinsic>(!strconcat("int_nvvm_", - !subst(".", "_", InstName))) addr:$addr)]>, - Requires<[hasPTX<80>, hasSM<90>]>; - +let Predicates = [hasPTX<80>, hasSM<90>] in { + class PREFETCH_INTRS<string InstName> : + BasicNVPTXInst<(outs), (ins ADDR:$addr), + InstName, + [(!cast<Intrinsic>(!strconcat("int_nvvm_", + !subst(".", "_", InstName))) addr:$addr)]>; -def PREFETCH_L1 : PREFETCH_INTRS<"prefetch.L1">; -def PREFETCH_L2 : PREFETCH_INTRS<"prefetch.L2">; -def PREFETCH_GLOBAL_L1 : PREFETCH_INTRS<"prefetch.global.L1">; -def PREFETCH_LOCAL_L1 : PREFETCH_INTRS<"prefetch.local.L1">; -def PREFETCH_GLOBAL_L2 : PREFETCH_INTRS<"prefetch.global.L2">; -def PREFETCH_LOCAL_L2 : PREFETCH_INTRS<"prefetch.local.L2">; + def PREFETCH_L1 : PREFETCH_INTRS<"prefetch.L1">; + def PREFETCH_L2 : PREFETCH_INTRS<"prefetch.L2">; + def PREFETCH_GLOBAL_L1 : PREFETCH_INTRS<"prefetch.global.L1">; + def PREFETCH_LOCAL_L1 : PREFETCH_INTRS<"prefetch.local.L1">; + def PREFETCH_GLOBAL_L2 : PREFETCH_INTRS<"prefetch.global.L2">; + def PREFETCH_LOCAL_L2 : PREFETCH_INTRS<"prefetch.local.L2">; -def PREFETCH_GLOBAL_L2_EVICT_NORMAL : BasicNVPTXInst<(outs), (ins ADDR:$addr), - "prefetch.global.L2::evict_normal", - [(int_nvvm_prefetch_global_L2_evict_normal addr:$addr)]>, - Requires<[hasPTX<80>, hasSM<90>]>; + def PREFETCH_GLOBAL_L2_EVICT_NORMAL : BasicNVPTXInst<(outs), (ins ADDR:$addr), + "prefetch.global.L2::evict_normal", + [(int_nvvm_prefetch_global_L2_evict_normal addr:$addr)]>; -def PREFETCH_GLOBAL_L2_EVICT_LAST : BasicNVPTXInst<(outs), (ins ADDR:$addr), - "prefetch.global.L2::evict_last", - [(int_nvvm_prefetch_global_L2_evict_last addr:$addr)]>, - Requires<[hasPTX<80>, hasSM<90>]>; + def PREFETCH_GLOBAL_L2_EVICT_LAST : BasicNVPTXInst<(outs), (ins ADDR:$addr), + "prefetch.global.L2::evict_last", + [(int_nvvm_prefetch_global_L2_evict_last addr:$addr)]>; - -def PREFETCHU_L1 : PREFETCH_INTRS<"prefetchu.L1">; + def PREFETCHU_L1 : PREFETCH_INTRS<"prefetchu.L1">; +} //Applypriority intrinsics class APPLYPRIORITY_L2_INTRS<string addrspace> : @@ -1031,99 +1006,82 @@ def DISCARD_GLOBAL_L2 : DISCARD_L2_INTRS<"global">; // MBarrier Functions //----------------------------------- -multiclass MBARRIER_INIT<string AddrSpace, Intrinsic Intrin> { - def "" : BasicNVPTXInst<(outs), (ins ADDR:$addr, B32:$count), - "mbarrier.init" # AddrSpace # ".b64", - [(Intrin addr:$addr, i32:$count)]>, - Requires<[hasPTX<70>, hasSM<80>]>; -} - -defm MBARRIER_INIT : MBARRIER_INIT<"", int_nvvm_mbarrier_init>; -defm MBARRIER_INIT_SHARED : MBARRIER_INIT<".shared", - int_nvvm_mbarrier_init_shared>; - -multiclass MBARRIER_INVAL<string AddrSpace, Intrinsic Intrin> { - def "" : BasicNVPTXInst<(outs), (ins ADDR:$addr), - "mbarrier.inval" # AddrSpace # ".b64", - [(Intrin addr:$addr)]>, - Requires<[hasPTX<70>, hasSM<80>]>; -} - -defm MBARRIER_INVAL : MBARRIER_INVAL<"", int_nvvm_mbarrier_inval>; -defm MBARRIER_INVAL_SHARED : MBARRIER_INVAL<".shared", - int_nvvm_mbarrier_inval_shared>; - -multiclass MBARRIER_ARRIVE<string AddrSpace, Intrinsic Intrin> { - def "" : BasicNVPTXInst<(outs B64:$state), (ins ADDR:$addr), - "mbarrier.arrive" # AddrSpace # ".b64", - [(set i64:$state, (Intrin addr:$addr))]>, - Requires<[hasPTX<70>, 
hasSM<80>]>; -} - -defm MBARRIER_ARRIVE : MBARRIER_ARRIVE<"", int_nvvm_mbarrier_arrive>; -defm MBARRIER_ARRIVE_SHARED : - MBARRIER_ARRIVE<".shared", int_nvvm_mbarrier_arrive_shared>; - -multiclass MBARRIER_ARRIVE_NOCOMPLETE<string AddrSpace, Intrinsic Intrin> { - def "" : BasicNVPTXInst<(outs B64:$state), - (ins ADDR:$addr, B32:$count), - "mbarrier.arrive.noComplete" # AddrSpace # ".b64", - [(set i64:$state, (Intrin addr:$addr, i32:$count))]>, - Requires<[hasPTX<70>, hasSM<80>]>; -} - -defm MBARRIER_ARRIVE_NOCOMPLETE : - MBARRIER_ARRIVE_NOCOMPLETE<"", int_nvvm_mbarrier_arrive_noComplete>; -defm MBARRIER_ARRIVE_NOCOMPLETE_SHARED : - MBARRIER_ARRIVE_NOCOMPLETE<".shared", int_nvvm_mbarrier_arrive_noComplete_shared>; - -multiclass MBARRIER_ARRIVE_DROP<string AddrSpace, Intrinsic Intrin> { - def "" : BasicNVPTXInst<(outs B64:$state), (ins ADDR:$addr), - "mbarrier.arrive_drop" # AddrSpace # ".b64", - [(set i64:$state, (Intrin addr:$addr))]>, - Requires<[hasPTX<70>, hasSM<80>]>; -} - -defm MBARRIER_ARRIVE_DROP : - MBARRIER_ARRIVE_DROP<"", int_nvvm_mbarrier_arrive_drop>; -defm MBARRIER_ARRIVE_DROP_SHARED : - MBARRIER_ARRIVE_DROP<".shared", int_nvvm_mbarrier_arrive_drop_shared>; - -multiclass MBARRIER_ARRIVE_DROP_NOCOMPLETE<string AddrSpace, Intrinsic Intrin> { - def "" : BasicNVPTXInst<(outs B64:$state), - (ins ADDR:$addr, B32:$count), - "mbarrier.arrive_drop.noComplete" # AddrSpace # ".b64", - [(set i64:$state, (Intrin addr:$addr, i32:$count))]>, - Requires<[hasPTX<70>, hasSM<80>]>; -} - -defm MBARRIER_ARRIVE_DROP_NOCOMPLETE : - MBARRIER_ARRIVE_DROP_NOCOMPLETE<"", int_nvvm_mbarrier_arrive_drop_noComplete>; -defm MBARRIER_ARRIVE_DROP_NOCOMPLETE_SHARED : - MBARRIER_ARRIVE_DROP_NOCOMPLETE<".shared", - int_nvvm_mbarrier_arrive_drop_noComplete_shared>; - -multiclass MBARRIER_TEST_WAIT<string AddrSpace, Intrinsic Intrin> { - def "" : BasicNVPTXInst<(outs B1:$res), (ins ADDR:$addr, B64:$state), - "mbarrier.test_wait" # AddrSpace # ".b64", - [(set i1:$res, (Intrin addr:$addr, i64:$state))]>, - Requires<[hasPTX<70>, hasSM<80>]>; +let Predicates = [hasPTX<70>, hasSM<80>] in { + class MBARRIER_INIT<string AddrSpace, Intrinsic Intrin> : + BasicNVPTXInst<(outs), (ins ADDR:$addr, B32:$count), + "mbarrier.init" # AddrSpace # ".b64", + [(Intrin addr:$addr, i32:$count)]>; + + def MBARRIER_INIT : MBARRIER_INIT<"", int_nvvm_mbarrier_init>; + def MBARRIER_INIT_SHARED : MBARRIER_INIT<".shared", + int_nvvm_mbarrier_init_shared>; + + class MBARRIER_INVAL<string AddrSpace, Intrinsic Intrin> : + BasicNVPTXInst<(outs), (ins ADDR:$addr), + "mbarrier.inval" # AddrSpace # ".b64", + [(Intrin addr:$addr)]>; + + def MBARRIER_INVAL : MBARRIER_INVAL<"", int_nvvm_mbarrier_inval>; + def MBARRIER_INVAL_SHARED : MBARRIER_INVAL<".shared", + int_nvvm_mbarrier_inval_shared>; + + class MBARRIER_ARRIVE<string AddrSpace, Intrinsic Intrin> : + BasicNVPTXInst<(outs B64:$state), (ins ADDR:$addr), + "mbarrier.arrive" # AddrSpace # ".b64", + [(set i64:$state, (Intrin addr:$addr))]>; + + def MBARRIER_ARRIVE : MBARRIER_ARRIVE<"", int_nvvm_mbarrier_arrive>; + def MBARRIER_ARRIVE_SHARED : + MBARRIER_ARRIVE<".shared", int_nvvm_mbarrier_arrive_shared>; + + class MBARRIER_ARRIVE_NOCOMPLETE<string AddrSpace, Intrinsic Intrin> : + BasicNVPTXInst<(outs B64:$state), + (ins ADDR:$addr, B32:$count), + "mbarrier.arrive.noComplete" # AddrSpace # ".b64", + [(set i64:$state, (Intrin addr:$addr, i32:$count))]>; + + def MBARRIER_ARRIVE_NOCOMPLETE : + MBARRIER_ARRIVE_NOCOMPLETE<"", int_nvvm_mbarrier_arrive_noComplete>; + def MBARRIER_ARRIVE_NOCOMPLETE_SHARED : + 
MBARRIER_ARRIVE_NOCOMPLETE<".shared", int_nvvm_mbarrier_arrive_noComplete_shared>; + + class MBARRIER_ARRIVE_DROP<string AddrSpace, Intrinsic Intrin> : + BasicNVPTXInst<(outs B64:$state), (ins ADDR:$addr), + "mbarrier.arrive_drop" # AddrSpace # ".b64", + [(set i64:$state, (Intrin addr:$addr))]>; + + def MBARRIER_ARRIVE_DROP : + MBARRIER_ARRIVE_DROP<"", int_nvvm_mbarrier_arrive_drop>; + def MBARRIER_ARRIVE_DROP_SHARED : + MBARRIER_ARRIVE_DROP<".shared", int_nvvm_mbarrier_arrive_drop_shared>; + + class MBARRIER_ARRIVE_DROP_NOCOMPLETE<string AddrSpace, Intrinsic Intrin> : + BasicNVPTXInst<(outs B64:$state), + (ins ADDR:$addr, B32:$count), + "mbarrier.arrive_drop.noComplete" # AddrSpace # ".b64", + [(set i64:$state, (Intrin addr:$addr, i32:$count))]>; + + def MBARRIER_ARRIVE_DROP_NOCOMPLETE : + MBARRIER_ARRIVE_DROP_NOCOMPLETE<"", int_nvvm_mbarrier_arrive_drop_noComplete>; + def MBARRIER_ARRIVE_DROP_NOCOMPLETE_SHARED : + MBARRIER_ARRIVE_DROP_NOCOMPLETE<".shared", + int_nvvm_mbarrier_arrive_drop_noComplete_shared>; + + class MBARRIER_TEST_WAIT<string AddrSpace, Intrinsic Intrin> : + BasicNVPTXInst<(outs B1:$res), (ins ADDR:$addr, B64:$state), + "mbarrier.test_wait" # AddrSpace # ".b64", + [(set i1:$res, (Intrin addr:$addr, i64:$state))]>; + + def MBARRIER_TEST_WAIT : + MBARRIER_TEST_WAIT<"", int_nvvm_mbarrier_test_wait>; + def MBARRIER_TEST_WAIT_SHARED : + MBARRIER_TEST_WAIT<".shared", int_nvvm_mbarrier_test_wait_shared>; + + def MBARRIER_PENDING_COUNT : + BasicNVPTXInst<(outs B32:$res), (ins B64:$state), + "mbarrier.pending_count.b64", + [(set i32:$res, (int_nvvm_mbarrier_pending_count i64:$state))]>; } - -defm MBARRIER_TEST_WAIT : - MBARRIER_TEST_WAIT<"", int_nvvm_mbarrier_test_wait>; -defm MBARRIER_TEST_WAIT_SHARED : - MBARRIER_TEST_WAIT<".shared", int_nvvm_mbarrier_test_wait_shared>; - -class MBARRIER_PENDING_COUNT<Intrinsic Intrin> : - BasicNVPTXInst<(outs B32:$res), (ins B64:$state), - "mbarrier.pending_count.b64", - [(set i32:$res, (Intrin i64:$state))]>, - Requires<[hasPTX<70>, hasSM<80>]>; - -def MBARRIER_PENDING_COUNT : - MBARRIER_PENDING_COUNT<int_nvvm_mbarrier_pending_count>; - //----------------------------------- // Math Functions //----------------------------------- @@ -1449,15 +1407,11 @@ defm ABS_F64 : F_ABS<"f64", F64RT, support_ftz = false>; def fcopysign_nvptx : SDNode<"NVPTXISD::FCOPYSIGN", SDTFPBinOp>; -def COPYSIGN_F : - BasicNVPTXInst<(outs B32:$dst), (ins B32:$src0, B32:$src1), - "copysign.f32", - [(set f32:$dst, (fcopysign_nvptx f32:$src1, f32:$src0))]>; - -def COPYSIGN_D : - BasicNVPTXInst<(outs B64:$dst), (ins B64:$src0, B64:$src1), - "copysign.f64", - [(set f64:$dst, (fcopysign_nvptx f64:$src1, f64:$src0))]>; +foreach t = [F32RT, F64RT] in + def COPYSIGN_ # t : + BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src0, t.RC:$src1), + "copysign." # t.PtxType, + [(set t.Ty:$dst, (fcopysign_nvptx t.Ty:$src1, t.Ty:$src0))]>; // // Neg bf16, bf16x2 @@ -2255,38 +2209,35 @@ defm INT_PTX_SATOM_XOR : ATOM2_bitwise_impl<"xor">; // Scalar -class LDU_G<string TyStr, NVPTXRegClass regclass> - : NVPTXInst<(outs regclass:$result), (ins ADDR:$src), - "ldu.global." 
# TyStr # " \t$result, [$src];", []>; +class LDU_G<NVPTXRegClass regclass> + : NVPTXInst<(outs regclass:$result), (ins i32imm:$fromWidth, ADDR:$src), + "ldu.global.b$fromWidth \t$result, [$src];", []>; -def LDU_GLOBAL_i8 : LDU_G<"b8", B16>; -def LDU_GLOBAL_i16 : LDU_G<"b16", B16>; -def LDU_GLOBAL_i32 : LDU_G<"b32", B32>; -def LDU_GLOBAL_i64 : LDU_G<"b64", B64>; +def LDU_GLOBAL_i16 : LDU_G<B16>; +def LDU_GLOBAL_i32 : LDU_G<B32>; +def LDU_GLOBAL_i64 : LDU_G<B64>; // vector // Elementized vector ldu -class VLDU_G_ELE_V2<string TyStr, NVPTXRegClass regclass> +class VLDU_G_ELE_V2<NVPTXRegClass regclass> : NVPTXInst<(outs regclass:$dst1, regclass:$dst2), - (ins ADDR:$src), - "ldu.global.v2." # TyStr # " \t{{$dst1, $dst2}}, [$src];", []>; + (ins i32imm:$fromWidth, ADDR:$src), + "ldu.global.v2.b$fromWidth \t{{$dst1, $dst2}}, [$src];", []>; -class VLDU_G_ELE_V4<string TyStr, NVPTXRegClass regclass> - : NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, - regclass:$dst4), (ins ADDR:$src), - "ldu.global.v4." # TyStr # " \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", []>; +class VLDU_G_ELE_V4<NVPTXRegClass regclass> + : NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4), + (ins i32imm:$fromWidth, ADDR:$src), + "ldu.global.v4.b$fromWidth \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", []>; -def LDU_GLOBAL_v2i8 : VLDU_G_ELE_V2<"b8", B16>; -def LDU_GLOBAL_v2i16 : VLDU_G_ELE_V2<"b16", B16>; -def LDU_GLOBAL_v2i32 : VLDU_G_ELE_V2<"b32", B32>; -def LDU_GLOBAL_v2i64 : VLDU_G_ELE_V2<"b64", B64>; +def LDU_GLOBAL_v2i16 : VLDU_G_ELE_V2<B16>; +def LDU_GLOBAL_v2i32 : VLDU_G_ELE_V2<B32>; +def LDU_GLOBAL_v2i64 : VLDU_G_ELE_V2<B64>; -def LDU_GLOBAL_v4i8 : VLDU_G_ELE_V4<"b8", B16>; -def LDU_GLOBAL_v4i16 : VLDU_G_ELE_V4<"b16", B16>; -def LDU_GLOBAL_v4i32 : VLDU_G_ELE_V4<"b32", B32>; +def LDU_GLOBAL_v4i16 : VLDU_G_ELE_V4<B16>; +def LDU_GLOBAL_v4i32 : VLDU_G_ELE_V4<B32>; //----------------------------------- @@ -2327,12 +2278,10 @@ class VLDG_G_ELE_V8<NVPTXRegClass regclass> : "ld.global.nc.v8.${Sign:sign}$fromWidth \t{{$dst1, $dst2, $dst3, $dst4, $dst5, $dst6, $dst7, $dst8}}, [$src];", []>; // FIXME: 8-bit LDG should be fixed once LDG/LDU nodes are made into proper loads. -def LD_GLOBAL_NC_v2i8 : VLDG_G_ELE_V2<B16>; def LD_GLOBAL_NC_v2i16 : VLDG_G_ELE_V2<B16>; def LD_GLOBAL_NC_v2i32 : VLDG_G_ELE_V2<B32>; def LD_GLOBAL_NC_v2i64 : VLDG_G_ELE_V2<B64>; -def LD_GLOBAL_NC_v4i8 : VLDG_G_ELE_V4<B16>; def LD_GLOBAL_NC_v4i16 : VLDG_G_ELE_V4<B16>; def LD_GLOBAL_NC_v4i32 : VLDG_G_ELE_V4<B32>; @@ -2342,19 +2291,19 @@ def LD_GLOBAL_NC_v8i32 : VLDG_G_ELE_V8<B32>; multiclass NG_TO_G<string Str, bit Supports32 = 1, list<Predicate> Preds = []> { if Supports32 then def "" : BasicNVPTXInst<(outs B32:$result), (ins B32:$src), - "cvta." # Str # ".u32", []>, Requires<Preds>; + "cvta." # Str # ".u32">, Requires<Preds>; def _64 : BasicNVPTXInst<(outs B64:$result), (ins B64:$src), - "cvta." # Str # ".u64", []>, Requires<Preds>; + "cvta." # Str # ".u64">, Requires<Preds>; } multiclass G_TO_NG<string Str, bit Supports32 = 1, list<Predicate> Preds = []> { if Supports32 then def "" : BasicNVPTXInst<(outs B32:$result), (ins B32:$src), - "cvta.to." # Str # ".u32", []>, Requires<Preds>; + "cvta.to." # Str # ".u32">, Requires<Preds>; def _64 : BasicNVPTXInst<(outs B64:$result), (ins B64:$src), - "cvta.to." # Str # ".u64", []>, Requires<Preds>; + "cvta.to." 
# Str # ".u64">, Requires<Preds>; } foreach space = ["local", "shared", "global", "const", "param"] in { @@ -4614,9 +4563,9 @@ def INT_PTX_SREG_LANEMASK_GT : PTX_READ_SREG_R32<"lanemask_gt", int_nvvm_read_ptx_sreg_lanemask_gt>; let hasSideEffects = 1 in { -def SREG_CLOCK : PTX_READ_SREG_R32<"clock", int_nvvm_read_ptx_sreg_clock>; -def SREG_CLOCK64 : PTX_READ_SREG_R64<"clock64", int_nvvm_read_ptx_sreg_clock64>; -def SREG_GLOBALTIMER : PTX_READ_SREG_R64<"globaltimer", int_nvvm_read_ptx_sreg_globaltimer>; + def SREG_CLOCK : PTX_READ_SREG_R32<"clock", int_nvvm_read_ptx_sreg_clock>; + def SREG_CLOCK64 : PTX_READ_SREG_R64<"clock64", int_nvvm_read_ptx_sreg_clock64>; + def SREG_GLOBALTIMER : PTX_READ_SREG_R64<"globaltimer", int_nvvm_read_ptx_sreg_globaltimer>; } def: Pat <(i64 (readcyclecounter)), (SREG_CLOCK64)>; @@ -5096,37 +5045,36 @@ foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs) in def : MMA_PAT<mma>; multiclass MAPA<string suffix, Intrinsic Intr> { - def _32: BasicNVPTXInst<(outs B32:$d), (ins B32:$a, B32:$b), - "mapa" # suffix # ".u32", - [(set i32:$d, (Intr i32:$a, i32:$b))]>, - Requires<[hasSM<90>, hasPTX<78>]>; - def _32i: BasicNVPTXInst<(outs B32:$d), (ins B32:$a, i32imm:$b), - "mapa" # suffix # ".u32", - [(set i32:$d, (Intr i32:$a, imm:$b))]>, - Requires<[hasSM<90>, hasPTX<78>]>; - def _64: BasicNVPTXInst<(outs B64:$d), (ins B64:$a, B32:$b), - "mapa" # suffix # ".u64", - [(set i64:$d, (Intr i64:$a, i32:$b))]>, - Requires<[hasSM<90>, hasPTX<78>]>; - def _64i: BasicNVPTXInst<(outs B64:$d), (ins B64:$a, i32imm:$b), - "mapa" # suffix # ".u64", - [(set i64:$d, (Intr i64:$a, imm:$b))]>, - Requires<[hasSM<90>, hasPTX<78>]>; + let Predicates = [hasSM<90>, hasPTX<78>] in { + def _32: BasicNVPTXInst<(outs B32:$d), (ins B32:$a, B32:$b), + "mapa" # suffix # ".u32", + [(set i32:$d, (Intr i32:$a, i32:$b))]>; + def _32i: BasicNVPTXInst<(outs B32:$d), (ins B32:$a, i32imm:$b), + "mapa" # suffix # ".u32", + [(set i32:$d, (Intr i32:$a, imm:$b))]>; + def _64: BasicNVPTXInst<(outs B64:$d), (ins B64:$a, B32:$b), + "mapa" # suffix # ".u64", + [(set i64:$d, (Intr i64:$a, i32:$b))]>; + def _64i: BasicNVPTXInst<(outs B64:$d), (ins B64:$a, i32imm:$b), + "mapa" # suffix # ".u64", + [(set i64:$d, (Intr i64:$a, imm:$b))]>; + } } + defm mapa : MAPA<"", int_nvvm_mapa>; defm mapa_shared_cluster : MAPA<".shared::cluster", int_nvvm_mapa_shared_cluster>; multiclass GETCTARANK<string suffix, Intrinsic Intr> { - def _32: BasicNVPTXInst<(outs B32:$d), (ins B32:$a), - "getctarank" # suffix # ".u32", - [(set i32:$d, (Intr i32:$a))]>, - Requires<[hasSM<90>, hasPTX<78>]>; - def _64: BasicNVPTXInst<(outs B32:$d), (ins B64:$a), - "getctarank" # suffix # ".u64", - [(set i32:$d, (Intr i64:$a))]>, - Requires<[hasSM<90>, hasPTX<78>]>; + let Predicates = [hasSM<90>, hasPTX<78>] in { + def _32: BasicNVPTXInst<(outs B32:$d), (ins B32:$a), + "getctarank" # suffix # ".u32", + [(set i32:$d, (Intr i32:$a))]>; + def _64: BasicNVPTXInst<(outs B32:$d), (ins B64:$a), + "getctarank" # suffix # ".u64", + [(set i32:$d, (Intr i64:$a))]>; + } } defm getctarank : GETCTARANK<"", int_nvvm_getctarank>; @@ -5165,29 +5113,25 @@ def INT_NVVM_WGMMA_WAIT_GROUP_SYNC_ALIGNED : BasicNVPTXInst<(outs), (ins i64imm: [(int_nvvm_wgmma_wait_group_sync_aligned timm:$n)]>, Requires<[hasSM90a, hasPTX<80>]>; } // isConvergent = true -def GRIDDEPCONTROL_LAUNCH_DEPENDENTS : - BasicNVPTXInst<(outs), (ins), - "griddepcontrol.launch_dependents", - [(int_nvvm_griddepcontrol_launch_dependents)]>, - Requires<[hasSM<90>, hasPTX<78>]>; - -def 
GRIDDEPCONTROL_WAIT : - BasicNVPTXInst<(outs), (ins), - "griddepcontrol.wait", - [(int_nvvm_griddepcontrol_wait)]>, - Requires<[hasSM<90>, hasPTX<78>]>; +let Predicates = [hasSM<90>, hasPTX<78>] in { + def GRIDDEPCONTROL_LAUNCH_DEPENDENTS : + BasicNVPTXInst<(outs), (ins), "griddepcontrol.launch_dependents", + [(int_nvvm_griddepcontrol_launch_dependents)]>; + def GRIDDEPCONTROL_WAIT : + BasicNVPTXInst<(outs), (ins), "griddepcontrol.wait", + [(int_nvvm_griddepcontrol_wait)]>; +} def INT_EXIT : BasicNVPTXInst<(outs), (ins), "exit", [(int_nvvm_exit)]>; // Tcgen05 intrinsics -let isConvergent = true in { +let isConvergent = true, Predicates = [hasTcgen05Instructions] in { multiclass TCGEN05_ALLOC_INTR<string AS, string num, Intrinsic Intr> { def "" : BasicNVPTXInst<(outs), (ins ADDR:$dst, B32:$ncols), "tcgen05.alloc.cta_group::" # num # ".sync.aligned" # AS # ".b32", - [(Intr addr:$dst, B32:$ncols)]>, - Requires<[hasTcgen05Instructions]>; + [(Intr addr:$dst, B32:$ncols)]>; } defm TCGEN05_ALLOC_CG1 : TCGEN05_ALLOC_INTR<"", "1", int_nvvm_tcgen05_alloc_cg1>; @@ -5200,8 +5144,7 @@ multiclass TCGEN05_DEALLOC_INTR<string num, Intrinsic Intr> { def "" : BasicNVPTXInst<(outs), (ins B32:$tmem_addr, B32:$ncols), "tcgen05.dealloc.cta_group::" # num # ".sync.aligned.b32", - [(Intr B32:$tmem_addr, B32:$ncols)]>, - Requires<[hasTcgen05Instructions]>; + [(Intr B32:$tmem_addr, B32:$ncols)]>; } defm TCGEN05_DEALLOC_CG1: TCGEN05_DEALLOC_INTR<"1", int_nvvm_tcgen05_dealloc_cg1>; defm TCGEN05_DEALLOC_CG2: TCGEN05_DEALLOC_INTR<"2", int_nvvm_tcgen05_dealloc_cg2>; @@ -5209,19 +5152,13 @@ defm TCGEN05_DEALLOC_CG2: TCGEN05_DEALLOC_INTR<"2", int_nvvm_tcgen05_dealloc_cg2 multiclass TCGEN05_RELINQ_PERMIT_INTR<string num, Intrinsic Intr> { def "" : BasicNVPTXInst<(outs), (ins), "tcgen05.relinquish_alloc_permit.cta_group::" # num # ".sync.aligned", - [(Intr)]>, - Requires<[hasTcgen05Instructions]>; + [(Intr)]>; } defm TCGEN05_RELINQ_CG1: TCGEN05_RELINQ_PERMIT_INTR<"1", int_nvvm_tcgen05_relinq_alloc_permit_cg1>; defm TCGEN05_RELINQ_CG2: TCGEN05_RELINQ_PERMIT_INTR<"2", int_nvvm_tcgen05_relinq_alloc_permit_cg2>; -def tcgen05_wait_ld: BasicNVPTXInst<(outs), (ins), "tcgen05.wait::ld.sync.aligned", - [(int_nvvm_tcgen05_wait_ld)]>, - Requires<[hasTcgen05Instructions]>; - -def tcgen05_wait_st: BasicNVPTXInst<(outs), (ins), "tcgen05.wait::st.sync.aligned", - [(int_nvvm_tcgen05_wait_st)]>, - Requires<[hasTcgen05Instructions]>; +def tcgen05_wait_ld: NullaryInst<"tcgen05.wait::ld.sync.aligned", int_nvvm_tcgen05_wait_ld>; +def tcgen05_wait_st: NullaryInst<"tcgen05.wait::st.sync.aligned", int_nvvm_tcgen05_wait_st>; multiclass TCGEN05_COMMIT_INTR<string AS, string num> { defvar prefix = "tcgen05.commit.cta_group::" # num #".mbarrier::arrive::one.shared::cluster"; @@ -5232,12 +5169,10 @@ multiclass TCGEN05_COMMIT_INTR<string AS, string num> { def "" : BasicNVPTXInst<(outs), (ins ADDR:$mbar), prefix # ".b64", - [(Intr addr:$mbar)]>, - Requires<[hasTcgen05Instructions]>; + [(Intr addr:$mbar)]>; def _MC : BasicNVPTXInst<(outs), (ins ADDR:$mbar, B16:$mc), prefix # ".multicast::cluster.b64", - [(IntrMC addr:$mbar, B16:$mc)]>, - Requires<[hasTcgen05Instructions]>; + [(IntrMC addr:$mbar, B16:$mc)]>; } defm TCGEN05_COMMIT_CG1 : TCGEN05_COMMIT_INTR<"", "1">; @@ -5249,8 +5184,7 @@ multiclass TCGEN05_SHIFT_INTR<string num, Intrinsic Intr> { def "" : BasicNVPTXInst<(outs), (ins ADDR:$tmem_addr), "tcgen05.shift.cta_group::" # num # ".down", - [(Intr addr:$tmem_addr)]>, - Requires<[hasTcgen05Instructions]>; + [(Intr addr:$tmem_addr)]>; } defm 
TCGEN05_SHIFT_CG1: TCGEN05_SHIFT_INTR<"1", int_nvvm_tcgen05_shift_down_cg1>; defm TCGEN05_SHIFT_CG2: TCGEN05_SHIFT_INTR<"2", int_nvvm_tcgen05_shift_down_cg2>; @@ -5270,13 +5204,11 @@ multiclass TCGEN05_CP_INTR<string shape, string src_fmt, string mc = ""> { def _cg1 : BasicNVPTXInst<(outs), (ins ADDR:$tmem_addr, B64:$sdesc), "tcgen05.cp.cta_group::1." # shape_mc_asm # fmt_asm, - [(IntrCG1 addr:$tmem_addr, B64:$sdesc)]>, - Requires<[hasTcgen05Instructions]>; + [(IntrCG1 addr:$tmem_addr, B64:$sdesc)]>; def _cg2 : BasicNVPTXInst<(outs), (ins ADDR:$tmem_addr, B64:$sdesc), "tcgen05.cp.cta_group::2." # shape_mc_asm # fmt_asm, - [(IntrCG2 addr:$tmem_addr, B64:$sdesc)]>, - Requires<[hasTcgen05Instructions]>; + [(IntrCG2 addr:$tmem_addr, B64:$sdesc)]>; } foreach src_fmt = ["", "b6x16_p32", "b4x16_p64"] in { @@ -5289,17 +5221,13 @@ foreach src_fmt = ["", "b6x16_p32", "b4x16_p64"] in { } } // isConvergent -let hasSideEffects = 1 in { +let hasSideEffects = 1, Predicates = [hasTcgen05Instructions] in { -def tcgen05_fence_before_thread_sync: BasicNVPTXInst<(outs), (ins), - "tcgen05.fence::before_thread_sync", - [(int_nvvm_tcgen05_fence_before_thread_sync)]>, - Requires<[hasTcgen05Instructions]>; + def tcgen05_fence_before_thread_sync: NullaryInst< + "tcgen05.fence::before_thread_sync", int_nvvm_tcgen05_fence_before_thread_sync>; -def tcgen05_fence_after_thread_sync: BasicNVPTXInst<(outs), (ins), - "tcgen05.fence::after_thread_sync", - [(int_nvvm_tcgen05_fence_after_thread_sync)]>, - Requires<[hasTcgen05Instructions]>; + def tcgen05_fence_after_thread_sync: NullaryInst< + "tcgen05.fence::after_thread_sync", int_nvvm_tcgen05_fence_after_thread_sync>; } // hasSideEffects @@ -5392,17 +5320,17 @@ foreach shape = ["16x64b", "16x128b", "16x256b", "32x32b", "16x32bx2"] in { // Bulk store instructions def st_bulk_imm : TImmLeaf<i64, [{ return Imm == 0; }]>; -def INT_NVVM_ST_BULK_GENERIC : - BasicNVPTXInst<(outs), (ins ADDR:$dest_addr, B64:$size, i64imm:$value), - "st.bulk", - [(int_nvvm_st_bulk addr:$dest_addr, i64:$size, st_bulk_imm:$value)]>, - Requires<[hasSM<100>, hasPTX<86>]>; +let Predicates = [hasSM<100>, hasPTX<86>] in { + def INT_NVVM_ST_BULK_GENERIC : + BasicNVPTXInst<(outs), (ins ADDR:$dest_addr, B64:$size, i64imm:$value), + "st.bulk", + [(int_nvvm_st_bulk addr:$dest_addr, i64:$size, st_bulk_imm:$value)]>; -def INT_NVVM_ST_BULK_SHARED_CTA: - BasicNVPTXInst<(outs), (ins ADDR:$dest_addr, B64:$size, i64imm:$value), - "st.bulk.shared::cta", - [(int_nvvm_st_bulk_shared_cta addr:$dest_addr, i64:$size, st_bulk_imm:$value)]>, - Requires<[hasSM<100>, hasPTX<86>]>; + def INT_NVVM_ST_BULK_SHARED_CTA: + BasicNVPTXInst<(outs), (ins ADDR:$dest_addr, B64:$size, i64imm:$value), + "st.bulk.shared::cta", + [(int_nvvm_st_bulk_shared_cta addr:$dest_addr, i64:$size, st_bulk_imm:$value)]>; +} // // clusterlaunchcontorl Instructions diff --git a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td index d40886a..2e81ab1 100644 --- a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td @@ -38,14 +38,6 @@ foreach i = 0...4 in { def R#i : NVPTXReg<"%r"#i>; // 32-bit def RL#i : NVPTXReg<"%rd"#i>; // 64-bit def RQ#i : NVPTXReg<"%rq"#i>; // 128-bit - def H#i : NVPTXReg<"%h"#i>; // 16-bit float - def HH#i : NVPTXReg<"%hh"#i>; // 2x16-bit float - - // Arguments - def ia#i : NVPTXReg<"%ia"#i>; - def la#i : NVPTXReg<"%la"#i>; - def fa#i : NVPTXReg<"%fa"#i>; - def da#i : NVPTXReg<"%da"#i>; } foreach i = 0...31 in { diff --git 
a/llvm/lib/Target/PowerPC/PPCISelLowering.cpp b/llvm/lib/Target/PowerPC/PPCISelLowering.cpp index 459525e..f179873 100644 --- a/llvm/lib/Target/PowerPC/PPCISelLowering.cpp +++ b/llvm/lib/Target/PowerPC/PPCISelLowering.cpp @@ -7296,9 +7296,17 @@ SDValue PPCTargetLowering::LowerFormalArguments_AIX( if (!ArgVT.isVector() && !ValVT.isVector() && ArgVT.isInteger() && ValVT.isInteger() && ArgVT.getScalarSizeInBits() < ValVT.getScalarSizeInBits()) { - SDValue ArgValueTrunc = DAG.getNode( - ISD::TRUNCATE, dl, ArgVT.getSimpleVT() == MVT::i1 ? MVT::i8 : ArgVT, - ArgValue); + // It is possible to have either real integer values + // or integers that were not originally integers. + // In the latter case, these could have came from structs, + // and these integers would not have an extend on the parameter. + // Since these types of integers do not have an extend specified + // in the first place, the type of extend that we do should not matter. + EVT TruncatedArgVT = ArgVT.isSimple() && ArgVT.getSimpleVT() == MVT::i1 + ? MVT::i8 + : ArgVT; + SDValue ArgValueTrunc = + DAG.getNode(ISD::TRUNCATE, dl, TruncatedArgVT, ArgValue); SDValue ArgValueExt = ArgSignExt ? DAG.getSExtOrTrunc(ArgValueTrunc, dl, ValVT) : DAG.getZExtOrTrunc(ArgValueTrunc, dl, ValVT); diff --git a/llvm/lib/Target/PowerPC/PPCMachineScheduler.cpp b/llvm/lib/Target/PowerPC/PPCMachineScheduler.cpp index 5eb1f01..b7e2263 100644 --- a/llvm/lib/Target/PowerPC/PPCMachineScheduler.cpp +++ b/llvm/lib/Target/PowerPC/PPCMachineScheduler.cpp @@ -100,10 +100,14 @@ bool PPCPreRASchedStrategy::tryCandidate(SchedCandidate &Cand, // This is a best effort to set things up for a post-RA pass. Optimizations // like generating loads of multiple registers should ideally be done within // the scheduler pass by combining the loads during DAG postprocessing. - const ClusterInfo *CandCluster = Cand.AtTop ? TopCluster : BotCluster; - const ClusterInfo *TryCandCluster = TryCand.AtTop ? TopCluster : BotCluster; - if (tryGreater(TryCandCluster && TryCandCluster->contains(TryCand.SU), - CandCluster && CandCluster->contains(Cand.SU), TryCand, Cand, + unsigned CandZoneCluster = Cand.AtTop ? TopClusterID : BotClusterID; + unsigned TryCandZoneCluster = TryCand.AtTop ? TopClusterID : BotClusterID; + bool CandIsClusterSucc = + isTheSameCluster(CandZoneCluster, Cand.SU->ParentClusterIdx); + bool TryCandIsClusterSucc = + isTheSameCluster(TryCandZoneCluster, TryCand.SU->ParentClusterIdx); + + if (tryGreater(TryCandIsClusterSucc, CandIsClusterSucc, TryCand, Cand, Cluster)) return TryCand.Reason != NoCand; @@ -189,10 +193,14 @@ bool PPCPostRASchedStrategy::tryCandidate(SchedCandidate &Cand, return TryCand.Reason != NoCand; // Keep clustered nodes together. - const ClusterInfo *CandCluster = Cand.AtTop ? TopCluster : BotCluster; - const ClusterInfo *TryCandCluster = TryCand.AtTop ? TopCluster : BotCluster; - if (tryGreater(TryCandCluster && TryCandCluster->contains(TryCand.SU), - CandCluster && CandCluster->contains(Cand.SU), TryCand, Cand, + unsigned CandZoneCluster = Cand.AtTop ? TopClusterID : BotClusterID; + unsigned TryCandZoneCluster = TryCand.AtTop ? 
TopClusterID : BotClusterID; + bool CandIsClusterSucc = + isTheSameCluster(CandZoneCluster, Cand.SU->ParentClusterIdx); + bool TryCandIsClusterSucc = + isTheSameCluster(TryCandZoneCluster, TryCand.SU->ParentClusterIdx); + + if (tryGreater(TryCandIsClusterSucc, CandIsClusterSucc, TryCand, Cand, Cluster)) return TryCand.Reason != NoCand; diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVAsmBackend.cpp b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVAsmBackend.cpp index 82e3b5c..9538b20 100644 --- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVAsmBackend.cpp +++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVAsmBackend.cpp @@ -901,7 +901,7 @@ void RISCVAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, unsigned Offset = Fixup.getOffset(); unsigned NumBytes = alignTo(Info.TargetSize + Info.TargetOffset, 8) / 8; - assert(Offset + NumBytes <= Data.size() && "Invalid fixup offset!"); + assert(Offset + NumBytes <= F.getSize() && "Invalid fixup offset!"); // For each byte of the fragment that the fixup touches, mask in the // bits from the fixup value. diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp index f223fdbe..5998653 100644 --- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp @@ -2827,6 +2827,8 @@ static bool selectConstantAddr(SelectionDAG *CurDAG, const SDLoc &DL, static bool isWorthFoldingAdd(SDValue Add) { for (auto *User : Add->users()) { if (User->getOpcode() != ISD::LOAD && User->getOpcode() != ISD::STORE && + User->getOpcode() != RISCVISD::LD_RV32 && + User->getOpcode() != RISCVISD::SD_RV32 && User->getOpcode() != ISD::ATOMIC_LOAD && User->getOpcode() != ISD::ATOMIC_STORE) return false; @@ -2841,6 +2843,9 @@ static bool isWorthFoldingAdd(SDValue Add) { if (User->getOpcode() == ISD::ATOMIC_STORE && cast<AtomicSDNode>(User)->getVal() == Add) return false; + if (User->getOpcode() == RISCVISD::SD_RV32 && + (User->getOperand(0) == Add || User->getOperand(1) == Add)) + return false; if (isStrongerThanMonotonic(cast<MemSDNode>(User)->getSuccessOrdering())) return false; } diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index c0ada51..adbfbeb 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -1819,6 +1819,13 @@ bool RISCVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info, case Intrinsic::riscv_seg6_load_mask: case Intrinsic::riscv_seg7_load_mask: case Intrinsic::riscv_seg8_load_mask: + case Intrinsic::riscv_sseg2_load_mask: + case Intrinsic::riscv_sseg3_load_mask: + case Intrinsic::riscv_sseg4_load_mask: + case Intrinsic::riscv_sseg5_load_mask: + case Intrinsic::riscv_sseg6_load_mask: + case Intrinsic::riscv_sseg7_load_mask: + case Intrinsic::riscv_sseg8_load_mask: return SetRVVLoadStoreInfo(/*PtrOp*/ 0, /*IsStore*/ false, /*IsUnitStrided*/ false, /*UsePtrVal*/ true); case Intrinsic::riscv_seg2_store_mask: @@ -10938,6 +10945,97 @@ static inline SDValue getVCIXISDNodeVOID(SDValue &Op, SelectionDAG &DAG, return DAG.getNode(Type, SDLoc(Op), Op.getValueType(), Operands); } +static SDValue +lowerFixedVectorSegLoadIntrinsics(unsigned IntNo, SDValue Op, + const RISCVSubtarget &Subtarget, + SelectionDAG &DAG) { + bool IsStrided; + switch (IntNo) { + case Intrinsic::riscv_seg2_load_mask: + case Intrinsic::riscv_seg3_load_mask: + case Intrinsic::riscv_seg4_load_mask: + case Intrinsic::riscv_seg5_load_mask: + case Intrinsic::riscv_seg6_load_mask: + case 
Intrinsic::riscv_seg7_load_mask: + case Intrinsic::riscv_seg8_load_mask: + IsStrided = false; + break; + case Intrinsic::riscv_sseg2_load_mask: + case Intrinsic::riscv_sseg3_load_mask: + case Intrinsic::riscv_sseg4_load_mask: + case Intrinsic::riscv_sseg5_load_mask: + case Intrinsic::riscv_sseg6_load_mask: + case Intrinsic::riscv_sseg7_load_mask: + case Intrinsic::riscv_sseg8_load_mask: + IsStrided = true; + break; + default: + llvm_unreachable("unexpected intrinsic ID"); + }; + + static const Intrinsic::ID VlsegInts[7] = { + Intrinsic::riscv_vlseg2_mask, Intrinsic::riscv_vlseg3_mask, + Intrinsic::riscv_vlseg4_mask, Intrinsic::riscv_vlseg5_mask, + Intrinsic::riscv_vlseg6_mask, Intrinsic::riscv_vlseg7_mask, + Intrinsic::riscv_vlseg8_mask}; + static const Intrinsic::ID VlssegInts[7] = { + Intrinsic::riscv_vlsseg2_mask, Intrinsic::riscv_vlsseg3_mask, + Intrinsic::riscv_vlsseg4_mask, Intrinsic::riscv_vlsseg5_mask, + Intrinsic::riscv_vlsseg6_mask, Intrinsic::riscv_vlsseg7_mask, + Intrinsic::riscv_vlsseg8_mask}; + + SDLoc DL(Op); + unsigned NF = Op->getNumValues() - 1; + assert(NF >= 2 && NF <= 8 && "Unexpected seg number"); + MVT XLenVT = Subtarget.getXLenVT(); + MVT VT = Op->getSimpleValueType(0); + MVT ContainerVT = ::getContainerForFixedLengthVector(DAG, VT, Subtarget); + unsigned Sz = NF * ContainerVT.getVectorMinNumElements() * + ContainerVT.getScalarSizeInBits(); + EVT VecTupTy = MVT::getRISCVVectorTupleVT(Sz, NF); + + // Operands: (chain, int_id, pointer, mask, vl) or + // (chain, int_id, pointer, offset, mask, vl) + SDValue VL = Op.getOperand(Op.getNumOperands() - 1); + SDValue Mask = Op.getOperand(Op.getNumOperands() - 2); + MVT MaskVT = Mask.getSimpleValueType(); + MVT MaskContainerVT = + ::getContainerForFixedLengthVector(DAG, MaskVT, Subtarget); + Mask = convertToScalableVector(MaskContainerVT, Mask, DAG, Subtarget); + + SDValue IntID = DAG.getTargetConstant( + IsStrided ? VlssegInts[NF - 2] : VlsegInts[NF - 2], DL, XLenVT); + auto *Load = cast<MemIntrinsicSDNode>(Op); + + SDVTList VTs = DAG.getVTList({VecTupTy, MVT::Other}); + SmallVector<SDValue, 9> Ops = { + Load->getChain(), + IntID, + DAG.getUNDEF(VecTupTy), + Op.getOperand(2), + Mask, + VL, + DAG.getTargetConstant( + RISCVVType::TAIL_AGNOSTIC | RISCVVType::MASK_AGNOSTIC, DL, XLenVT), + DAG.getTargetConstant(Log2_64(VT.getScalarSizeInBits()), DL, XLenVT)}; + // Insert the stride operand. 
+ if (IsStrided) + Ops.insert(std::next(Ops.begin(), 4), Op.getOperand(3)); + + SDValue Result = + DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops, + Load->getMemoryVT(), Load->getMemOperand()); + SmallVector<SDValue, 9> Results; + for (unsigned int RetIdx = 0; RetIdx < NF; RetIdx++) { + SDValue SubVec = DAG.getNode(RISCVISD::TUPLE_EXTRACT, DL, ContainerVT, + Result.getValue(0), + DAG.getTargetConstant(RetIdx, DL, MVT::i32)); + Results.push_back(convertFromScalableVector(VT, SubVec, DAG, Subtarget)); + } + Results.push_back(Result.getValue(1)); + return DAG.getMergeValues(Results, DL); +} + SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op, SelectionDAG &DAG) const { unsigned IntNo = Op.getConstantOperandVal(1); @@ -10950,57 +11048,16 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op, case Intrinsic::riscv_seg5_load_mask: case Intrinsic::riscv_seg6_load_mask: case Intrinsic::riscv_seg7_load_mask: - case Intrinsic::riscv_seg8_load_mask: { - SDLoc DL(Op); - static const Intrinsic::ID VlsegInts[7] = { - Intrinsic::riscv_vlseg2_mask, Intrinsic::riscv_vlseg3_mask, - Intrinsic::riscv_vlseg4_mask, Intrinsic::riscv_vlseg5_mask, - Intrinsic::riscv_vlseg6_mask, Intrinsic::riscv_vlseg7_mask, - Intrinsic::riscv_vlseg8_mask}; - unsigned NF = Op->getNumValues() - 1; - assert(NF >= 2 && NF <= 8 && "Unexpected seg number"); - MVT XLenVT = Subtarget.getXLenVT(); - MVT VT = Op->getSimpleValueType(0); - MVT ContainerVT = getContainerForFixedLengthVector(VT); - unsigned Sz = NF * ContainerVT.getVectorMinNumElements() * - ContainerVT.getScalarSizeInBits(); - EVT VecTupTy = MVT::getRISCVVectorTupleVT(Sz, NF); - - // Operands: (chain, int_id, pointer, mask, vl) - SDValue VL = Op.getOperand(Op.getNumOperands() - 1); - SDValue Mask = Op.getOperand(3); - MVT MaskVT = Mask.getSimpleValueType(); - MVT MaskContainerVT = - ::getContainerForFixedLengthVector(DAG, MaskVT, Subtarget); - Mask = convertToScalableVector(MaskContainerVT, Mask, DAG, Subtarget); - - SDValue IntID = DAG.getTargetConstant(VlsegInts[NF - 2], DL, XLenVT); - auto *Load = cast<MemIntrinsicSDNode>(Op); + case Intrinsic::riscv_seg8_load_mask: + case Intrinsic::riscv_sseg2_load_mask: + case Intrinsic::riscv_sseg3_load_mask: + case Intrinsic::riscv_sseg4_load_mask: + case Intrinsic::riscv_sseg5_load_mask: + case Intrinsic::riscv_sseg6_load_mask: + case Intrinsic::riscv_sseg7_load_mask: + case Intrinsic::riscv_sseg8_load_mask: + return lowerFixedVectorSegLoadIntrinsics(IntNo, Op, Subtarget, DAG); - SDVTList VTs = DAG.getVTList({VecTupTy, MVT::Other}); - SDValue Ops[] = { - Load->getChain(), - IntID, - DAG.getUNDEF(VecTupTy), - Op.getOperand(2), - Mask, - VL, - DAG.getTargetConstant( - RISCVVType::TAIL_AGNOSTIC | RISCVVType::MASK_AGNOSTIC, DL, XLenVT), - DAG.getTargetConstant(Log2_64(VT.getScalarSizeInBits()), DL, XLenVT)}; - SDValue Result = - DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops, - Load->getMemoryVT(), Load->getMemOperand()); - SmallVector<SDValue, 9> Results; - for (unsigned int RetIdx = 0; RetIdx < NF; RetIdx++) { - SDValue SubVec = DAG.getNode(RISCVISD::TUPLE_EXTRACT, DL, ContainerVT, - Result.getValue(0), - DAG.getTargetConstant(RetIdx, DL, MVT::i32)); - Results.push_back(convertFromScalableVector(VT, SubVec, DAG, Subtarget)); - } - Results.push_back(Result.getValue(1)); - return DAG.getMergeValues(Results, DL); - } case Intrinsic::riscv_sf_vc_v_x_se: return getVCIXISDNodeWCHAIN(Op, DAG, RISCVISD::SF_VC_V_X_SE); case Intrinsic::riscv_sf_vc_v_i_se: diff --git 
a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td index 31ea2de..cc2977c 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td @@ -910,7 +910,7 @@ foreach vti = AllIntegerVectors in { foreach vti = I64IntegerVectors in { let Predicates = [HasVInstructionsI64] in { def : Pat<(add (vti.Vector vti.RegClass:$rs1), - (vti.Vector (SplatPat_imm64_neg i64:$rs2))), + (vti.Vector (SplatPat_imm64_neg (i64 GPR:$rs2)))), (!cast<Instruction>("PseudoVSUB_VX_"#vti.LMul.MX) (vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs1, diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td index 695223b..acbccdd 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td @@ -2123,7 +2123,7 @@ foreach vti = AllIntegerVectors in { foreach vti = I64IntegerVectors in { let Predicates = [HasVInstructionsI64] in { def : Pat<(riscv_add_vl (vti.Vector vti.RegClass:$rs1), - (vti.Vector (SplatPat_imm64_neg i64:$rs2)), + (vti.Vector (SplatPat_imm64_neg (i64 GPR:$rs2))), vti.RegClass:$passthru, (vti.Mask VMV0:$vm), VLOpFrag), (!cast<Instruction>("PseudoVSUB_VX_"#vti.LMul.MX#"_MASK") vti.RegClass:$passthru, vti.RegClass:$rs1, diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td b/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td index c0f7ab1..4c31ce4 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td @@ -590,12 +590,12 @@ let Predicates = [HasVendorXTHeadBb, IsRV64] in { def : PatGprImm<riscv_rorw, TH_SRRIW, uimm5>; def : Pat<(riscv_rolw GPR:$rs1, uimm5:$rs2), (TH_SRRIW GPR:$rs1, (ImmSubFrom32 uimm5:$rs2))>; -def : Pat<(sra (bswap i64:$rs1), (i64 32)), - (TH_REVW i64:$rs1)>; -def : Pat<(binop_allwusers<srl> (bswap i64:$rs1), (i64 32)), - (TH_REVW i64:$rs1)>; -def : Pat<(riscv_clzw i64:$rs1), - (TH_FF0 (i64 (SLLI (i64 (XORI i64:$rs1, -1)), 32)))>; +def : Pat<(i64 (sra (bswap GPR:$rs1), (i64 32))), + (TH_REVW GPR:$rs1)>; +def : Pat<(binop_allwusers<srl> (bswap GPR:$rs1), (i64 32)), + (TH_REVW GPR:$rs1)>; +def : Pat<(riscv_clzw GPR:$rs1), + (TH_FF0 (i64 (SLLI (i64 (XORI GPR:$rs1, -1)), 32)))>; } // Predicates = [HasVendorXTHeadBb, IsRV64] let Predicates = [HasVendorXTHeadBs] in { @@ -697,11 +697,13 @@ def uimm2_4 : Operand<XLenVT>, ImmLeaf<XLenVT, [{ }], uimm2_4_XFORM>; let Predicates = [HasVendorXTHeadMemPair, IsRV64] in { -def : Pat<(th_lwud i64:$rs1, uimm2_3:$uimm2_3), (TH_LWUD i64:$rs1, uimm2_3:$uimm2_3, 3)>; -def : Pat<(th_ldd i64:$rs1, uimm2_4:$uimm2_4), (TH_LDD i64:$rs1, uimm2_4:$uimm2_4, 4)>; +def : Pat<(th_lwud GPR:$rs1, (i64 uimm2_3:$uimm2_3)), + (TH_LWUD GPR:$rs1, uimm2_3:$uimm2_3, 3)>; +def : Pat<(th_ldd GPR:$rs1, (i64 uimm2_4:$uimm2_4)), + (TH_LDD GPR:$rs1, uimm2_4:$uimm2_4, 4)>; -def : Pat<(th_sdd i64:$rd1, i64:$rd2, i64:$rs1, uimm2_4:$uimm2_4), - (TH_SDD i64:$rd1, i64:$rd2, i64:$rs1, uimm2_4:$uimm2_4, 4)>; +def : Pat<(th_sdd (i64 GPR:$rd1), GPR:$rd2, GPR:$rs1, uimm2_4:$uimm2_4), + (TH_SDD GPR:$rd1, GPR:$rd2, GPR:$rs1, uimm2_4:$uimm2_4, 4)>; } let Predicates = [HasVendorXTHeadMemPair] in { diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp index 61dbd06..0d5eb86 100644 --- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp @@ -2627,18 +2627,17 @@ void RISCVTTIImpl::getUnrollingPreferences( if 
(L->getNumBlocks() > 4) return; - // Don't unroll vectorized loops, including the remainder loop - if (getBooleanLoopAttribute(L, "llvm.loop.isvectorized")) - return; - // Scan the loop: don't unroll loops with calls as this could prevent - // inlining. + // inlining. Don't unroll auto-vectorized loops either, though do allow + // unrolling of the scalar remainder. + bool IsVectorized = getBooleanLoopAttribute(L, "llvm.loop.isvectorized"); InstructionCost Cost = 0; for (auto *BB : L->getBlocks()) { for (auto &I : *BB) { - // Initial setting - Don't unroll loops containing vectorized - // instructions. - if (I.getType()->isVectorTy()) + // Both auto-vectorized loops and the scalar remainder have the + // isvectorized attribute, so differentiate between them by the presence + // of vector instructions. + if (IsVectorized && I.getType()->isVectorTy()) return; if (isa<CallInst>(I) || isa<InvokeInst>(I)) { diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp index 3c631ce..947b574 100644 --- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp @@ -194,6 +194,42 @@ class SPIRVEmitIntrinsics void useRoundingMode(ConstrainedFPIntrinsic *FPI, IRBuilder<> &B); + // Tries to walk the type accessed by the given GEP instruction. + // For each nested type access, one of the 2 callbacks is called: + // - OnLiteralIndexing when the index is a known constant value. + // Parameters: + // PointedType: the pointed type resulting of this indexing. + // If the parent type is an array, this is the index in the array. + // If the parent type is a struct, this is the field index. + // Index: index of the element in the parent type. + // - OnDynamnicIndexing when the index is a non-constant value. + // This callback is only called when indexing into an array. + // Parameters: + // ElementType: the type of the elements stored in the parent array. + // Offset: the Value* containing the byte offset into the array. + // Return true if an error occured during the walk, false otherwise. + bool walkLogicalAccessChain( + GetElementPtrInst &GEP, + const std::function<void(Type *PointedType, uint64_t Index)> + &OnLiteralIndexing, + const std::function<void(Type *ElementType, Value *Offset)> + &OnDynamicIndexing); + + // Returns the type accessed using the given GEP instruction by relying + // on the GEP type. + // FIXME: GEP types are not supposed to be used to retrieve the pointed + // type. This must be fixed. + Type *getGEPType(GetElementPtrInst *GEP); + + // Returns the type accessed using the given GEP instruction by walking + // the source type using the GEP indices. + // FIXME: without help from the frontend, this method cannot reliably retrieve + // the stored type, nor can robustly determine the depth of the type + // we are accessing. + Type *getGEPTypeLogical(GetElementPtrInst *GEP); + + Instruction *buildLogicalAccessChainFromGEP(GetElementPtrInst &GEP); + public: static char ID; SPIRVEmitIntrinsics(SPIRVTargetMachine *TM = nullptr) @@ -246,6 +282,17 @@ bool expectIgnoredInIRTranslation(const Instruction *I) { } } +// Returns the source pointer from `I` ignoring intermediate ptrcast. 
+Value *getPointerRoot(Value *I) { + if (auto *II = dyn_cast<IntrinsicInst>(I)) { + if (II->getIntrinsicID() == Intrinsic::spv_ptrcast) { + Value *V = II->getArgOperand(0); + return getPointerRoot(V); + } + } + return I; +} + } // namespace char SPIRVEmitIntrinsics::ID = 0; @@ -555,7 +602,112 @@ void SPIRVEmitIntrinsics::maybeAssignPtrType(Type *&Ty, Value *Op, Type *RefTy, Ty = RefTy; } -Type *getGEPType(GetElementPtrInst *Ref) { +bool SPIRVEmitIntrinsics::walkLogicalAccessChain( + GetElementPtrInst &GEP, + const std::function<void(Type *, uint64_t)> &OnLiteralIndexing, + const std::function<void(Type *, Value *)> &OnDynamicIndexing) { + // We only rewrite i8* GEP. Other should be left as-is. + // Valid i8* GEP must always have a single index. + assert(GEP.getSourceElementType() == + IntegerType::getInt8Ty(CurrF->getContext())); + assert(GEP.getNumIndices() == 1); + + auto &DL = CurrF->getDataLayout(); + Value *Src = getPointerRoot(GEP.getPointerOperand()); + Type *CurType = deduceElementType(Src, true); + + Value *Operand = *GEP.idx_begin(); + ConstantInt *CI = dyn_cast<ConstantInt>(Operand); + if (!CI) { + ArrayType *AT = dyn_cast<ArrayType>(CurType); + // Operand is not constant. Either we have an array and accept it, or we + // give up. + if (AT) + OnDynamicIndexing(AT->getElementType(), Operand); + return AT == nullptr; + } + + assert(CI); + uint64_t Offset = CI->getZExtValue(); + + do { + if (ArrayType *AT = dyn_cast<ArrayType>(CurType)) { + uint32_t EltTypeSize = DL.getTypeSizeInBits(AT->getElementType()) / 8; + assert(Offset < AT->getNumElements() * EltTypeSize); + uint64_t Index = Offset / EltTypeSize; + Offset = Offset - (Index * EltTypeSize); + CurType = AT->getElementType(); + OnLiteralIndexing(CurType, Index); + } else if (StructType *ST = dyn_cast<StructType>(CurType)) { + uint32_t StructSize = DL.getTypeSizeInBits(ST) / 8; + assert(Offset < StructSize); + (void)StructSize; + const auto &STL = DL.getStructLayout(ST); + unsigned Element = STL->getElementContainingOffset(Offset); + Offset -= STL->getElementOffset(Element); + CurType = ST->getElementType(Element); + OnLiteralIndexing(CurType, Element); + } else { + // Vector type indexing should not use GEP. + // So if we have an index left, something is wrong. Giving up. 
+ return true; + } + } while (Offset > 0); + + return false; +} + +Instruction * +SPIRVEmitIntrinsics::buildLogicalAccessChainFromGEP(GetElementPtrInst &GEP) { + auto &DL = CurrF->getDataLayout(); + IRBuilder<> B(GEP.getParent()); + B.SetInsertPoint(&GEP); + + std::vector<Value *> Indices; + Indices.push_back(ConstantInt::get( + IntegerType::getInt32Ty(CurrF->getContext()), 0, /* Signed= */ false)); + walkLogicalAccessChain( + GEP, + [&Indices, &B](Type *EltType, uint64_t Index) { + Indices.push_back( + ConstantInt::get(B.getInt64Ty(), Index, /* Signed= */ false)); + }, + [&Indices, &B, &DL](Type *EltType, Value *Offset) { + uint32_t EltTypeSize = DL.getTypeSizeInBits(EltType) / 8; + Value *Index = B.CreateUDiv( + Offset, ConstantInt::get(Offset->getType(), EltTypeSize, + /* Signed= */ false)); + Indices.push_back(Index); + }); + + SmallVector<Type *, 2> Types = {GEP.getType(), GEP.getOperand(0)->getType()}; + SmallVector<Value *, 4> Args; + Args.push_back(B.getInt1(GEP.isInBounds())); + Args.push_back(GEP.getOperand(0)); + llvm::append_range(Args, Indices); + auto *NewI = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args}); + replaceAllUsesWithAndErase(B, &GEP, NewI); + return NewI; +} + +Type *SPIRVEmitIntrinsics::getGEPTypeLogical(GetElementPtrInst *GEP) { + + Type *CurType = GEP->getResultElementType(); + + bool Interrupted = walkLogicalAccessChain( + *GEP, [&CurType](Type *EltType, uint64_t Index) { CurType = EltType; }, + [&CurType](Type *EltType, Value *Index) { CurType = EltType; }); + + return Interrupted ? GEP->getResultElementType() : CurType; +} + +Type *SPIRVEmitIntrinsics::getGEPType(GetElementPtrInst *Ref) { + if (Ref->getSourceElementType() == + IntegerType::getInt8Ty(CurrF->getContext()) && + TM->getSubtargetImpl()->isLogicalSPIRV()) { + return getGEPTypeLogical(Ref); + } + Type *Ty = nullptr; // TODO: not sure if GetElementPtrInst::getTypeAtIndex() does anything // useful here @@ -1395,6 +1547,13 @@ Instruction *SPIRVEmitIntrinsics::visitSwitchInst(SwitchInst &I) { } Instruction *SPIRVEmitIntrinsics::visitGetElementPtrInst(GetElementPtrInst &I) { + if (I.getSourceElementType() == IntegerType::getInt8Ty(CurrF->getContext()) && + TM->getSubtargetImpl()->isLogicalSPIRV()) { + Instruction *Result = buildLogicalAccessChainFromGEP(I); + if (Result) + return Result; + } + IRBuilder<> B(I.getParent()); B.SetInsertPoint(&I); SmallVector<Type *, 2> Types = {I.getType(), I.getOperand(0)->getType()}; @@ -1588,7 +1747,24 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I, } if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(I)) { Value *Pointer = GEPI->getPointerOperand(); - Type *OpTy = GEPI->getSourceElementType(); + Type *OpTy = nullptr; + + // Knowing the accessed type is mandatory for logical SPIR-V. Sadly, + // the GEP source element type should not be used for this purpose, and + // the alternative type-scavenging method is not working. + // Physical SPIR-V can work around this, but not logical, hence still + // try to rely on the broken type scavenging for logical. + bool IsRewrittenGEP = + GEPI->getSourceElementType() == IntegerType::getInt8Ty(I->getContext()); + if (IsRewrittenGEP && TM->getSubtargetImpl()->isLogicalSPIRV()) { + Value *Src = getPointerRoot(Pointer); + OpTy = GR->findDeducedElementType(Src); + } + + // In all cases, fall back to the GEP type if type scavenging failed. 
+ if (!OpTy) + OpTy = GEPI->getSourceElementType(); + replacePointerOperandWithPtrCast(I, Pointer, OpTy, 0, B); if (isNestedPointer(OpTy)) insertTodoType(Pointer); diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetTransformInfo.h b/llvm/lib/Target/SPIRV/SPIRVTargetTransformInfo.h index 43bf6e9..60c4e2d 100644 --- a/llvm/lib/Target/SPIRV/SPIRVTargetTransformInfo.h +++ b/llvm/lib/Target/SPIRV/SPIRVTargetTransformInfo.h @@ -59,6 +59,8 @@ public: Intrinsic::ID IID) const override; Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV, Value *NewV) const override; + + bool allowVectorElementIndexingUsingGEP() const override { return false; } }; } // namespace llvm diff --git a/llvm/lib/Target/SystemZ/MCTargetDesc/SystemZMCAsmBackend.cpp b/llvm/lib/Target/SystemZ/MCTargetDesc/SystemZMCAsmBackend.cpp index d5f8492..b2cfd04 100644 --- a/llvm/lib/Target/SystemZ/MCTargetDesc/SystemZMCAsmBackend.cpp +++ b/llvm/lib/Target/SystemZ/MCTargetDesc/SystemZMCAsmBackend.cpp @@ -165,7 +165,7 @@ void SystemZMCAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, unsigned BitSize = getFixupKindInfo(Kind).TargetSize; unsigned Size = (BitSize + 7) / 8; - assert(Offset + Size <= Data.size() && "Invalid fixup offset!"); + assert(Offset + Size <= F.getSize() && "Invalid fixup offset!"); // Big-endian insertion of Size bytes. Value = extractBitsForFixup(Kind, Value, Fixup, getContext()); diff --git a/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp b/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp index e30d723..fb0a47d 100644 --- a/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp +++ b/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp @@ -9044,7 +9044,7 @@ static unsigned detectEvenOddMultiplyOperand(const SelectionDAG &DAG, if (unsigned(ShuffleMask[Elt]) != 2 * Elt) CanUseEven = false; if (unsigned(ShuffleMask[Elt]) != 2 * Elt + 1) - CanUseEven = true; + CanUseOdd = false; } Op = Op.getOperand(0); if (CanUseEven) diff --git a/llvm/lib/Target/VE/MCTargetDesc/VEAsmBackend.cpp b/llvm/lib/Target/VE/MCTargetDesc/VEAsmBackend.cpp index f987621..b02b6af 100644 --- a/llvm/lib/Target/VE/MCTargetDesc/VEAsmBackend.cpp +++ b/llvm/lib/Target/VE/MCTargetDesc/VEAsmBackend.cpp @@ -174,7 +174,7 @@ void VEAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, unsigned NumBytes = getFixupKindNumBytes(Fixup.getKind()); unsigned Offset = Fixup.getOffset(); - assert(Offset + NumBytes <= Data.size() && "Invalid fixup offset!"); + assert(Offset + NumBytes <= F.getSize() && "Invalid fixup offset!"); // For each byte of the fragment that the fixup touches, mask in the bits // from the fixup value. The Value has been "split up" into the // appropriate bitfields above. diff --git a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyAsmBackend.cpp b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyAsmBackend.cpp index 837fd8e..84eb15f 100644 --- a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyAsmBackend.cpp +++ b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyAsmBackend.cpp @@ -97,7 +97,7 @@ void WebAssemblyAsmBackend::applyFixup(const MCFragment &F, Value <<= Info.TargetOffset; unsigned Offset = Fixup.getOffset(); - assert(Offset + NumBytes <= Data.size() && "Invalid fixup offset!"); + assert(Offset + NumBytes <= F.getSize() && "Invalid fixup offset!"); // For each byte of the fragment that the fixup touches, mask in the // bits from the fixup value. 
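Referring back to the SPIRVEmitIntrinsics hunks earlier in this diff: walkLogicalAccessChain recovers nested array/struct indices from the single byte offset of a rewritten i8 GEP by consulting the struct layout for struct members and dividing by the element size for arrays, and buildLogicalAccessChainFromGEP then emits an spv_gep with a leading 0 followed by those indices. A small self-contained sketch of that arithmetic, using a hypothetical layout struct S { float a; float b[4]; } and byte offset 12 (names and numbers are illustrative only, not taken from this patch):

#include <cstdio>

int main() {
  // Hypothetical layout of struct S { float a; float b[4]; }:
  // 'a' starts at byte 0, 'b' at byte 4, with 4-byte float elements.
  const unsigned FieldOffsets[] = {0, 4};
  const unsigned ElemSize = 4;
  unsigned Offset = 12; // byte offset carried by the i8 GEP

  // Struct step: pick the field containing the offset, then peel it off.
  unsigned Field = (Offset >= FieldOffsets[1]) ? 1 : 0;
  Offset -= FieldOffsets[Field];
  // Array step: divide the remaining offset by the element size.
  unsigned Index = Offset / ElemSize;

  // Matches the indices handed to spv_gep: 0 (pointer deref), 1 ('b'), 2.
  std::printf("access chain indices: 0 %u %u\n", Field, Index);
  return 0;
}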
diff --git a/llvm/lib/Target/X86/MCTargetDesc/X86AsmBackend.cpp b/llvm/lib/Target/X86/MCTargetDesc/X86AsmBackend.cpp index 7f9d474..1efef83 100644 --- a/llvm/lib/Target/X86/MCTargetDesc/X86AsmBackend.cpp +++ b/llvm/lib/Target/X86/MCTargetDesc/X86AsmBackend.cpp @@ -690,7 +690,7 @@ void X86AsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, return; unsigned Size = getFixupKindSize(Kind); - assert(Fixup.getOffset() + Size <= Data.size() && "Invalid fixup offset!"); + assert(Fixup.getOffset() + Size <= F.getSize() && "Invalid fixup offset!"); int64_t SignedValue = static_cast<int64_t>(Value); if (IsResolved && Fixup.isPCRel()) { diff --git a/llvm/lib/TargetParser/Host.cpp b/llvm/lib/TargetParser/Host.cpp index 78bd5b4..7e09d30 100644 --- a/llvm/lib/TargetParser/Host.cpp +++ b/llvm/lib/TargetParser/Host.cpp @@ -587,8 +587,9 @@ StringRef sys::detail::getHostCPUNameForBPF() { #endif } -#if defined(__i386__) || defined(_M_IX86) || defined(__x86_64__) || \ - defined(_M_X64) +#if (defined(__i386__) || defined(_M_IX86) || defined(__x86_64__) || \ + defined(_M_X64)) && \ + !defined(_M_ARM64EC) /// getX86CpuIDAndInfo - Execute the specified cpuid and return the 4 values in /// the specified arguments. If we can't run cpuid on the host, return true. @@ -1853,8 +1854,9 @@ VendorSignatures getVendorSignature(unsigned *MaxLeaf) { } // namespace llvm #endif -#if defined(__i386__) || defined(_M_IX86) || \ - defined(__x86_64__) || defined(_M_X64) +#if (defined(__i386__) || defined(_M_IX86) || defined(__x86_64__) || \ + defined(_M_X64)) && \ + !defined(_M_ARM64EC) StringMap<bool> sys::getHostCPUFeatures() { unsigned EAX = 0, EBX = 0, ECX = 0, EDX = 0; unsigned MaxLevel; @@ -2147,7 +2149,8 @@ StringMap<bool> sys::getHostCPUFeatures() { return Features; } -#elif defined(_WIN32) && (defined(__aarch64__) || defined(_M_ARM64)) +#elif defined(_WIN32) && (defined(__aarch64__) || defined(_M_ARM64) || \ + defined(__arm64ec__) || defined(_M_ARM64EC)) StringMap<bool> sys::getHostCPUFeatures() { StringMap<bool> Features; diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp index 7af5ba4..40a7f80 100644 --- a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp +++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp @@ -458,29 +458,19 @@ static bool foldSqrt(CallInst *Call, LibFunc Func, TargetTransformInfo &TTI, // Check if this array of constants represents a cttz table. // Iterate over the elements from \p Table by trying to find/match all // the numbers from 0 to \p InputBits that should represent cttz results. -static bool isCTTZTable(const ConstantDataArray &Table, uint64_t Mul, - uint64_t Shift, uint64_t InputBits) { - unsigned Length = Table.getNumElements(); - if (Length < InputBits || Length > InputBits * 2) - return false; - - APInt Mask = APInt::getBitsSetFrom(InputBits, Shift); - unsigned Matched = 0; - - for (unsigned i = 0; i < Length; i++) { - uint64_t Element = Table.getElementAsInteger(i); - if (Element >= InputBits) - continue; - - // Check if \p Element matches a concrete answer. It could fail for some - // elements that are never accessed, so we keep iterating over each element - // from the table. The number of matched elements should be equal to the - // number of potential right answers which is \p InputBits actually. 
- if ((((Mul << Element) & Mask.getZExtValue()) >> Shift) == i) - Matched++; +static bool isCTTZTable(Constant *Table, const APInt &Mul, const APInt &Shift, + const APInt &AndMask, Type *AccessTy, + unsigned InputBits, const APInt &GEPIdxFactor, + const DataLayout &DL) { + for (unsigned Idx = 0; Idx < InputBits; Idx++) { + APInt Index = (APInt(InputBits, 1).shl(Idx) * Mul).lshr(Shift) & AndMask; + ConstantInt *C = dyn_cast_or_null<ConstantInt>( + ConstantFoldLoadFromConst(Table, AccessTy, Index * GEPIdxFactor, DL)); + if (!C || C->getValue() != Idx) + return false; } - return Matched == InputBits; + return true; } // Try to recognize table-based ctz implementation. @@ -495,6 +485,11 @@ static bool isCTTZTable(const ConstantDataArray &Table, uint64_t Mul, // this can be lowered to `cttz` instruction. // There is also a special case when the element is 0. // +// The (x & -x) sets the lowest non-zero bit to 1. The multiply is a de-bruijn +// sequence that contains each pattern of bits in it. The shift extracts +// the top bits after the multiply, and that index into the table should +// represent the number of trailing zeros in the original number. +// // Here are some examples or LLVM IR for a 64-bit target: // // CASE 1: @@ -536,8 +531,8 @@ static bool isCTTZTable(const ConstantDataArray &Table, uint64_t Mul, // i64 %shr // %0 = load i8, i8* %arrayidx, align 1, !tbaa !8 // -// All this can be lowered to @llvm.cttz.i32/64 intrinsic. -static bool tryToRecognizeTableBasedCttz(Instruction &I) { +// All these can be lowered to @llvm.cttz.i32/64 intrinsics. +static bool tryToRecognizeTableBasedCttz(Instruction &I, const DataLayout &DL) { LoadInst *LI = dyn_cast<LoadInst>(&I); if (!LI) return false; @@ -547,53 +542,47 @@ static bool tryToRecognizeTableBasedCttz(Instruction &I) { return false; GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getPointerOperand()); - if (!GEP || !GEP->hasNoUnsignedSignedWrap() || GEP->getNumIndices() != 2) - return false; - - if (!GEP->getSourceElementType()->isArrayTy()) - return false; - - uint64_t ArraySize = GEP->getSourceElementType()->getArrayNumElements(); - if (ArraySize != 32 && ArraySize != 64) + if (!GEP || !GEP->hasNoUnsignedSignedWrap()) return false; GlobalVariable *GVTable = dyn_cast<GlobalVariable>(GEP->getPointerOperand()); if (!GVTable || !GVTable->hasInitializer() || !GVTable->isConstant()) return false; - ConstantDataArray *ConstData = - dyn_cast<ConstantDataArray>(GVTable->getInitializer()); - if (!ConstData) - return false; - - if (!match(GEP->idx_begin()->get(), m_ZeroInt())) + unsigned BW = DL.getIndexTypeSizeInBits(GEP->getType()); + APInt ModOffset(BW, 0); + SmallMapVector<Value *, APInt, 4> VarOffsets; + if (!GEP->collectOffset(DL, BW, VarOffsets, ModOffset) || + VarOffsets.size() != 1 || ModOffset != 0) return false; + auto [GepIdx, GEPScale] = VarOffsets.front(); - Value *Idx2 = std::next(GEP->idx_begin())->get(); Value *X1; - uint64_t MulConst, ShiftConst; - // FIXME: 64-bit targets have `i64` type for the GEP index, so this match will - // probably fail for other (e.g. 32-bit) targets. - if (!match(Idx2, m_ZExtOrSelf( - m_LShr(m_Mul(m_c_And(m_Neg(m_Value(X1)), m_Deferred(X1)), - m_ConstantInt(MulConst)), - m_ConstantInt(ShiftConst))))) + const APInt *MulConst, *ShiftConst, *AndCst = nullptr; + // Check that the gep variable index is ((x & -x) * MulConst) >> ShiftConst. 
+ // This might be extended to the pointer index type, and if the gep index type + // has been replaced with an i8 then a new And (and different ShiftConst) will + // be present. + auto MatchInner = m_LShr( + m_Mul(m_c_And(m_Neg(m_Value(X1)), m_Deferred(X1)), m_APInt(MulConst)), + m_APInt(ShiftConst)); + if (!match(GepIdx, m_CastOrSelf(MatchInner)) && + !match(GepIdx, m_CastOrSelf(m_And(MatchInner, m_APInt(AndCst))))) return false; unsigned InputBits = X1->getType()->getScalarSizeInBits(); - if (InputBits != 32 && InputBits != 64) - return false; - - // Shift should extract top 5..7 bits. - if (InputBits - Log2_32(InputBits) != ShiftConst && - InputBits - Log2_32(InputBits) - 1 != ShiftConst) + if (InputBits != 16 && InputBits != 32 && InputBits != 64 && InputBits != 128) return false; - if (!isCTTZTable(*ConstData, MulConst, ShiftConst, InputBits)) + if (!GEPScale.isIntN(InputBits) || + !isCTTZTable(GVTable->getInitializer(), *MulConst, *ShiftConst, + AndCst ? *AndCst : APInt::getAllOnes(InputBits), AccessType, + InputBits, GEPScale.zextOrTrunc(InputBits), DL)) return false; - auto ZeroTableElem = ConstData->getElementAsInteger(0); - bool DefinedForZero = ZeroTableElem == InputBits; + ConstantInt *ZeroTableElem = cast<ConstantInt>( + ConstantFoldLoadFromConst(GVTable->getInitializer(), AccessType, DL)); + bool DefinedForZero = ZeroTableElem->getZExtValue() == InputBits; IRBuilder<> B(LI); ConstantInt *BoolConst = B.getInt1(!DefinedForZero); @@ -607,8 +596,7 @@ static bool tryToRecognizeTableBasedCttz(Instruction &I) { // If the value in elem 0 isn't the same as InputBits, we still want to // produce the value from the table. auto Cmp = B.CreateICmpEQ(X1, ConstantInt::get(XType, 0)); - auto Select = - B.CreateSelect(Cmp, ConstantInt::get(XType, ZeroTableElem), Cttz); + auto Select = B.CreateSelect(Cmp, B.CreateZExt(ZeroTableElem, XType), Cttz); // NOTE: If the table[0] is 0, but the cttz(0) is defined by the Target // it should be handled as: `cttz(x) & (typeSize - 1)`. 
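The new comment in AggressiveInstCombine above describes the de Bruijn multiply trick that the generalized matcher now recognizes. A minimal sketch of the kind of source pattern this covers, assuming the classic 32-bit variant from Bit Twiddling Hacks (the 0x077CB531 constant and the table below are the well-known reference values, not taken from this patch; the helper name is illustrative):

#include <cstdint>

// (x & -x) isolates the lowest set bit; multiplying by the de Bruijn
// constant places a distinct 5-bit pattern in the top bits, which then
// indexes a table of precomputed trailing-zero counts.
static const unsigned char CttzTable[32] = {
    0, 1, 28, 2, 29, 14, 24, 3, 30, 22, 20, 15, 25, 17, 4, 8,
    31, 27, 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9};

unsigned cttz32(std::uint32_t x) {
  // Table-based idiom that the pass can now rewrite to @llvm.cttz.i32.
  // For x == 0 this returns CttzTable[0] == 0, which the transform
  // preserves via the icmp/select it emits around the intrinsic.
  return CttzTable[((x & -x) * 0x077CB531u) >> 27];
}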
@@ -1477,7 +1465,7 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT, MadeChange |= foldGuardedFunnelShift(I, DT); MadeChange |= tryToRecognizePopCount(I); MadeChange |= tryToFPToSat(I, TTI); - MadeChange |= tryToRecognizeTableBasedCttz(I); + MadeChange |= tryToRecognizeTableBasedCttz(I, DL); MadeChange |= foldConsecutiveLoads(I, DL, TTI, AA, DT); MadeChange |= foldPatternedLoads(I, DL); MadeChange |= foldICmpOrChain(I, DL, TTI, AA, DT); diff --git a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp index 3c24d2e..01da012 100644 --- a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp +++ b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp @@ -13404,7 +13404,7 @@ struct AAAllocationInfoImpl : public AAAllocationInfo { return indicatePessimisticFixpoint(); if (BinSize == 0) { - auto NewAllocationSize = std::optional<TypeSize>(TypeSize(0, false)); + auto NewAllocationSize = std::make_optional<TypeSize>(0, false); if (!changeAllocationSize(NewAllocationSize)) return ChangeStatus::UNCHANGED; return ChangeStatus::CHANGED; @@ -13422,8 +13422,7 @@ struct AAAllocationInfoImpl : public AAAllocationInfo { if (SizeOfBin >= *AllocationSize) return indicatePessimisticFixpoint(); - auto NewAllocationSize = - std::optional<TypeSize>(TypeSize(SizeOfBin * 8, false)); + auto NewAllocationSize = std::make_optional<TypeSize>(SizeOfBin * 8, false); if (!changeAllocationSize(NewAllocationSize)) return ChangeStatus::UNCHANGED; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index 00b877b..fe0f308 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -462,6 +462,13 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) { return ScalarPHI; } + // If SrcVec is a subvector starting at index 0, extract from the + // wider source vector + Value *V; + if (match(SrcVec, + m_Intrinsic<Intrinsic::vector_extract>(m_Value(V), m_Zero()))) + return ExtractElementInst::Create(V, Index); + // TODO come up with a n-ary matcher that subsumes both unary and // binary matchers. UnaryOperator *UO; diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index b992597..5ee3bb1 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -2011,12 +2011,17 @@ Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN, NewPN->addIncoming(NewPhiValues[i], PN->getIncomingBlock(i)); if (IdenticalUsers) { - for (User *U : make_early_inc_range(PN->users())) { + // Collect and deduplicate users up-front to avoid iterator invalidation. 
+ SmallSetVector<Instruction *, 4> ToReplace; + for (User *U : PN->users()) { Instruction *User = cast<Instruction>(U); if (User == &I) continue; - replaceInstUsesWith(*User, NewPN); - eraseInstFromFunction(*User); + ToReplace.insert(User); + } + for (Instruction *I : ToReplace) { + replaceInstUsesWith(*I, NewPN); + eraseInstFromFunction(*I); } OneUse = true; } @@ -2654,9 +2659,18 @@ static Instruction *canonicalizeGEPOfConstGEPI8(GetElementPtrInst &GEP, APInt NewOffset = TypeSize * *C2 + *C1; if (NewOffset.isZero() || (Src->hasOneUse() && GEP.getOperand(1)->hasOneUse())) { + GEPNoWrapFlags Flags = GEPNoWrapFlags::none(); + if (GEP.hasNoUnsignedWrap() && + cast<GEPOperator>(Src)->hasNoUnsignedWrap() && + match(GEP.getOperand(1), m_NUWAddLike(m_Value(), m_Value()))) { + Flags |= GEPNoWrapFlags::noUnsignedWrap(); + if (GEP.isInBounds() && cast<GEPOperator>(Src)->isInBounds()) + Flags |= GEPNoWrapFlags::inBounds(); + } + Value *GEPConst = - IC.Builder.CreatePtrAdd(Base, IC.Builder.getInt(NewOffset)); - return GetElementPtrInst::Create(BaseType, GEPConst, VarIndex); + IC.Builder.CreatePtrAdd(Base, IC.Builder.getInt(NewOffset), "", Flags); + return GetElementPtrInst::Create(BaseType, GEPConst, VarIndex, Flags); } return nullptr; @@ -3184,7 +3198,16 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { // If we are using a wider index than needed for this platform, shrink // it to what we need. If narrower, sign-extend it to what we need. // This explicit cast can make subsequent optimizations more obvious. - *I = Builder.CreateIntCast(*I, NewIndexType, true); + if (IndexTy->getScalarSizeInBits() < + NewIndexType->getScalarSizeInBits()) { + if (GEP.hasNoUnsignedWrap() && GEP.hasNoUnsignedSignedWrap()) + *I = Builder.CreateZExt(*I, NewIndexType, "", /*IsNonNeg=*/true); + else + *I = Builder.CreateSExt(*I, NewIndexType); + } else { + *I = Builder.CreateTrunc(*I, NewIndexType, "", GEP.hasNoUnsignedWrap(), + GEP.hasNoUnsignedSignedWrap()); + } MadeChange = true; } } diff --git a/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp index 4e5a8d1..bcb90d6 100644 --- a/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp @@ -160,6 +160,16 @@ static cl::opt<bool> ClGenerateTagsWithCalls( static cl::opt<bool> ClGlobals("hwasan-globals", cl::desc("Instrument globals"), cl::Hidden, cl::init(false)); +static cl::opt<bool> ClAllGlobals( + "hwasan-all-globals", + cl::desc( + "Instrument globals, even those within user-defined sections. 
Warning: " + "This may break existing code which walks globals via linker-generated " + "symbols, expects certain globals to be contiguous with each other, or " + "makes other assumptions which are invalidated by HWASan " + "instrumentation."), + cl::Hidden, cl::init(false)); + static cl::opt<int> ClMatchAllTag( "hwasan-match-all-tag", cl::desc("don't report bad accesses via pointers with this tag"), @@ -681,11 +691,11 @@ void HWAddressSanitizer::initializeModule() { !CompileKernel && !UsePageAliases && optOr(ClGlobals, NewRuntime); if (!CompileKernel) { - createHwasanCtorComdat(); - if (InstrumentGlobals) instrumentGlobals(); + createHwasanCtorComdat(); + bool InstrumentPersonalityFunctions = optOr(ClInstrumentPersonalityFunctions, NewRuntime); if (InstrumentPersonalityFunctions) @@ -1772,11 +1782,17 @@ void HWAddressSanitizer::instrumentGlobals() { if (GV.hasCommonLinkage()) continue; - // Globals with custom sections may be used in __start_/__stop_ enumeration, - // which would be broken both by adding tags and potentially by the extra - // padding/alignment that we insert. - if (GV.hasSection()) - continue; + if (ClAllGlobals) { + // Avoid instrumenting intrinsic global variables. + if (GV.getSection() == "llvm.metadata") + continue; + } else { + // Globals with custom sections may be used in __start_/__stop_ + // enumeration, which would be broken both by adding tags and potentially + // by the extra padding/alignment that we insert. + if (GV.hasSection()) + continue; + } Globals.push_back(&GV); } diff --git a/llvm/lib/Transforms/Scalar/LICM.cpp b/llvm/lib/Transforms/Scalar/LICM.cpp index 68094c3..c3f80f9 100644 --- a/llvm/lib/Transforms/Scalar/LICM.cpp +++ b/llvm/lib/Transforms/Scalar/LICM.cpp @@ -2508,6 +2508,12 @@ static bool hoistGEP(Instruction &I, Loop &L, ICFLoopSafetyInfo &SafetyInfo, if (!GEP) return false; + // Do not try to hoist a constant GEP out of the loop via reassociation. + // Constant GEPs can often be folded into addressing modes, and reassociating + // them may inhibit CSE of a common base. + if (GEP->hasAllConstantIndices()) + return false; + auto *Src = dyn_cast<GetElementPtrInst>(GEP->getPointerOperand()); if (!Src || !Src->hasOneUse() || !L.contains(Src)) return false; diff --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp index f3e992c..04039b8 100644 --- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp +++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp @@ -1009,7 +1009,8 @@ PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM, // in simplified form, and also needs LCSSA. Running // this pass will simplify all loops that contain inner loops, // regardless of whether anything ends up being flattened. 
- LoopAccessInfoManager LAIM(AR.SE, AR.AA, AR.DT, AR.LI, &AR.TTI, nullptr); + LoopAccessInfoManager LAIM(AR.SE, AR.AA, AR.DT, AR.LI, &AR.TTI, nullptr, + &AR.AC); for (Loop *InnerLoop : LN.getLoops()) { auto *OuterLoop = InnerLoop->getParentLoop(); if (!OuterLoop) diff --git a/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp b/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp index 4f2bfb0..448dc2b 100644 --- a/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp +++ b/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp @@ -551,7 +551,7 @@ PreservedAnalyses LoopVersioningLICMPass::run(Loop &L, LoopAnalysisManager &AM, const Function *F = L.getHeader()->getParent(); OptimizationRemarkEmitter ORE(F); - LoopAccessInfoManager LAIs(*SE, *AA, *DT, LAR.LI, nullptr, nullptr); + LoopAccessInfoManager LAIs(*SE, *AA, *DT, LAR.LI, nullptr, nullptr, &LAR.AC); if (!LoopVersioningLICM(AA, SE, &ORE, LAIs, LAR.LI, &L).run(DT)) return PreservedAnalyses::all(); return getLoopPassPreservedAnalyses(); diff --git a/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp b/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp index 320b792..6ffe841 100644 --- a/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp +++ b/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp @@ -79,8 +79,7 @@ // ld.global.f32 %f4, [%rl6+132]; // much better // // Another improvement enabled by the LowerGEP flag is to lower a GEP with -// multiple indices to either multiple GEPs with a single index or arithmetic -// operations (depending on whether the target uses alias analysis in codegen). +// multiple indices to multiple GEPs with a single index. // Such transformation can have following benefits: // (1) It can always extract constants in the indices of structure type. // (2) After such Lowering, there are more optimization opportunities such as @@ -88,59 +87,33 @@ // // E.g. The following GEPs have multiple indices: // BB1: -// %p = getelementptr [10 x %struct]* %ptr, i64 %i, i64 %j1, i32 3 +// %p = getelementptr [10 x %struct], ptr %ptr, i64 %i, i64 %j1, i32 3 // load %p // ... // BB2: -// %p2 = getelementptr [10 x %struct]* %ptr, i64 %i, i64 %j1, i32 2 +// %p2 = getelementptr [10 x %struct], ptr %ptr, i64 %i, i64 %j1, i32 2 // load %p2 // ... // // We can not do CSE to the common part related to index "i64 %i". Lowering // GEPs can achieve such goals. -// If the target does not use alias analysis in codegen, this pass will -// lower a GEP with multiple indices into arithmetic operations: -// BB1: -// %1 = ptrtoint [10 x %struct]* %ptr to i64 ; CSE opportunity -// %2 = mul i64 %i, length_of_10xstruct ; CSE opportunity -// %3 = add i64 %1, %2 ; CSE opportunity -// %4 = mul i64 %j1, length_of_struct -// %5 = add i64 %3, %4 -// %6 = add i64 %3, struct_field_3 ; Constant offset -// %p = inttoptr i64 %6 to i32* -// load %p -// ... -// BB2: -// %7 = ptrtoint [10 x %struct]* %ptr to i64 ; CSE opportunity -// %8 = mul i64 %i, length_of_10xstruct ; CSE opportunity -// %9 = add i64 %7, %8 ; CSE opportunity -// %10 = mul i64 %j2, length_of_struct -// %11 = add i64 %9, %10 -// %12 = add i64 %11, struct_field_2 ; Constant offset -// %p = inttoptr i64 %12 to i32* -// load %p2 -// ... 
// -// If the target uses alias analysis in codegen, this pass will lower a GEP -// with multiple indices into multiple GEPs with a single index: +// This pass will lower a GEP with multiple indices into multiple GEPs with a +// single index: // BB1: -// %1 = bitcast [10 x %struct]* %ptr to i8* ; CSE opportunity -// %2 = mul i64 %i, length_of_10xstruct ; CSE opportunity -// %3 = getelementptr i8* %1, i64 %2 ; CSE opportunity +// %2 = mul i64 %i, length_of_10xstruct ; CSE opportunity +// %3 = getelementptr i8, ptr %ptr, i64 %2 ; CSE opportunity // %4 = mul i64 %j1, length_of_struct -// %5 = getelementptr i8* %3, i64 %4 -// %6 = getelementptr i8* %5, struct_field_3 ; Constant offset -// %p = bitcast i8* %6 to i32* +// %5 = getelementptr i8, ptr %3, i64 %4 +// %p = getelementptr i8, ptr %5, struct_field_3 ; Constant offset // load %p // ... // BB2: -// %7 = bitcast [10 x %struct]* %ptr to i8* ; CSE opportunity -// %8 = mul i64 %i, length_of_10xstruct ; CSE opportunity -// %9 = getelementptr i8* %7, i64 %8 ; CSE opportunity +// %8 = mul i64 %i, length_of_10xstruct ; CSE opportunity +// %9 = getelementptr i8, ptr %ptr, i64 %8 ; CSE opportunity // %10 = mul i64 %j2, length_of_struct -// %11 = getelementptr i8* %9, i64 %10 -// %12 = getelementptr i8* %11, struct_field_2 ; Constant offset -// %p2 = bitcast i8* %12 to i32* +// %11 = getelementptr i8, ptr %9, i64 %10 +// %p2 = getelementptr i8, ptr %11, struct_field_2 ; Constant offset // load %p2 // ... // @@ -408,16 +381,6 @@ private: void lowerToSingleIndexGEPs(GetElementPtrInst *Variadic, int64_t AccumulativeByteOffset); - /// Lower a GEP with multiple indices into ptrtoint+arithmetics+inttoptr form. - /// Function splitGEP already split the original GEP into a variadic part and - /// a constant offset (i.e., AccumulativeByteOffset). This function lowers the - /// variadic part into a set of arithmetic operations and applies - /// AccumulativeByteOffset to it. - /// \p Variadic The variadic part of the original GEP. - /// \p AccumulativeByteOffset The constant offset. - void lowerToArithmetics(GetElementPtrInst *Variadic, - int64_t AccumulativeByteOffset); - /// Finds the constant offset within each index and accumulates them. If /// LowerGEP is true, it finds in indices of both sequential and structure /// types, otherwise it only finds in sequential indices. The output @@ -951,55 +914,6 @@ void SeparateConstOffsetFromGEP::lowerToSingleIndexGEPs( Variadic->eraseFromParent(); } -void -SeparateConstOffsetFromGEP::lowerToArithmetics(GetElementPtrInst *Variadic, - int64_t AccumulativeByteOffset) { - IRBuilder<> Builder(Variadic); - Type *IntPtrTy = DL->getIntPtrType(Variadic->getType()); - assert(IntPtrTy == DL->getIndexType(Variadic->getType()) && - "Pointer type must match index type for arithmetic-based lowering of " - "split GEPs"); - - Value *ResultPtr = Builder.CreatePtrToInt(Variadic->getOperand(0), IntPtrTy); - gep_type_iterator GTI = gep_type_begin(*Variadic); - // Create ADD/SHL/MUL arithmetic operations for each sequential indices. We - // don't create arithmetics for structure indices, as they are accumulated - // in the constant offset index. - for (unsigned I = 1, E = Variadic->getNumOperands(); I != E; ++I, ++GTI) { - if (GTI.isSequential()) { - Value *Idx = Variadic->getOperand(I); - // Skip zero indices. - if (ConstantInt *CI = dyn_cast<ConstantInt>(Idx)) - if (CI->isZero()) - continue; - - APInt ElementSize = APInt(IntPtrTy->getIntegerBitWidth(), - GTI.getSequentialElementStride(*DL)); - // Scale the index by element size. 
- if (ElementSize != 1) { - if (ElementSize.isPowerOf2()) { - Idx = Builder.CreateShl( - Idx, ConstantInt::get(IntPtrTy, ElementSize.logBase2())); - } else { - Idx = Builder.CreateMul(Idx, ConstantInt::get(IntPtrTy, ElementSize)); - } - } - // Create an ADD for each index. - ResultPtr = Builder.CreateAdd(ResultPtr, Idx); - } - } - - // Create an ADD for the constant offset index. - if (AccumulativeByteOffset != 0) { - ResultPtr = Builder.CreateAdd( - ResultPtr, ConstantInt::get(IntPtrTy, AccumulativeByteOffset)); - } - - ResultPtr = Builder.CreateIntToPtr(ResultPtr, Variadic->getType()); - Variadic->replaceAllUsesWith(ResultPtr); - Variadic->eraseFromParent(); -} - bool SeparateConstOffsetFromGEP::reorderGEP(GetElementPtrInst *GEP, TargetTransformInfo &TTI) { auto PtrGEP = dyn_cast<GetElementPtrInst>(GEP->getPointerOperand()); @@ -1091,8 +1005,8 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) { // Notice that we don't remove struct field indices here. If LowerGEP is // disabled, a structure index is not accumulated and we still use the old // one. If LowerGEP is enabled, a structure index is accumulated in the - // constant offset. LowerToSingleIndexGEPs or lowerToArithmetics will later - // handle the constant offset and won't need a new structure index. + // constant offset. LowerToSingleIndexGEPs will later handle the constant + // offset and won't need a new structure index. gep_type_iterator GTI = gep_type_begin(*GEP); for (unsigned I = 1, E = GEP->getNumOperands(); I != E; ++I, ++GTI) { if (GTI.isSequential()) { @@ -1167,22 +1081,9 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) { GEP->setNoWrapFlags(NewGEPFlags); - // Lowers a GEP to either GEPs with a single index or arithmetic operations. + // Lowers a GEP to GEPs with a single index. if (LowerGEP) { - // As currently BasicAA does not analyze ptrtoint/inttoptr, do not lower to - // arithmetic operations if the target uses alias analysis in codegen. - // Additionally, pointers that aren't integral (and so can't be safely - // converted to integers) or those whose offset size is different from their - // pointer size (which means that doing integer arithmetic on them could - // affect that data) can't be lowered in this way. - unsigned AddrSpace = GEP->getPointerAddressSpace(); - bool PointerHasExtraData = DL->getPointerSizeInBits(AddrSpace) != - DL->getIndexSizeInBits(AddrSpace); - if (TTI.useAA() || DL->isNonIntegralAddressSpace(AddrSpace) || - PointerHasExtraData) - lowerToSingleIndexGEPs(GEP, AccumulativeByteOffset); - else - lowerToArithmetics(GEP, AccumulativeByteOffset); + lowerToSingleIndexGEPs(GEP, AccumulativeByteOffset); return true; } diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp index 571fa11..1eb8996 100644 --- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp +++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp @@ -1249,7 +1249,8 @@ Value *SCEVExpander::tryToReuseLCSSAPhi(const SCEVAddRecExpr *S) { // offset, if the offset is simpler. 
const SCEV *Diff = SE.getMinusSCEV(S, ExitSCEV); const SCEV *Op = Diff; - match(Diff, m_scev_Mul(m_scev_AllOnes(), m_SCEV(Op))); + match(Op, m_scev_Add(m_SCEVConstant(), m_SCEV(Op))); + match(Op, m_scev_Mul(m_scev_AllOnes(), m_SCEV(Op))); match(Op, m_scev_PtrToInt(m_SCEV(Op))); if (!isa<SCEVConstant, SCEVUnknown>(Op)) continue; diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index 94b0ab8..674de57 100644 --- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -198,6 +198,11 @@ static cl::opt<unsigned> MaxSwitchCasesPerResult( "max-switch-cases-per-result", cl::Hidden, cl::init(16), cl::desc("Limit cases to analyze when converting a switch to select")); +static cl::opt<unsigned> MaxJumpThreadingLiveBlocks( + "max-jump-threading-live-blocks", cl::Hidden, cl::init(24), + cl::desc("Limit number of blocks a define in a threaded block is allowed " + "to be live in")); + STATISTIC(NumBitMaps, "Number of switch instructions turned into bitmaps"); STATISTIC(NumLinearMaps, "Number of switch instructions turned into linear mapping"); @@ -3390,8 +3395,27 @@ bool SimplifyCFGOpt::speculativelyExecuteBB(BranchInst *BI, return true; } +using BlocksSet = SmallPtrSet<BasicBlock *, 8>; + +// Return false if number of blocks searched is too much. +static bool findReaching(BasicBlock *BB, BasicBlock *DefBB, + BlocksSet &ReachesNonLocalUses) { + if (BB == DefBB) + return true; + if (!ReachesNonLocalUses.insert(BB).second) + return true; + + if (ReachesNonLocalUses.size() > MaxJumpThreadingLiveBlocks) + return false; + for (BasicBlock *Pred : predecessors(BB)) + if (!findReaching(Pred, DefBB, ReachesNonLocalUses)) + return false; + return true; +} + /// Return true if we can thread a branch across this block. -static bool blockIsSimpleEnoughToThreadThrough(BasicBlock *BB) { +static bool blockIsSimpleEnoughToThreadThrough(BasicBlock *BB, + BlocksSet &NonLocalUseBlocks) { int Size = 0; EphemeralValueTracker EphTracker; @@ -3411,12 +3435,16 @@ static bool blockIsSimpleEnoughToThreadThrough(BasicBlock *BB) { return false; // Don't clone large BB's. } - // We can only support instructions that do not define values that are - // live outside of the current basic block. + // Record blocks with non-local uses of values defined in the current basic + // block. for (User *U : I.users()) { Instruction *UI = cast<Instruction>(U); - if (UI->getParent() != BB || isa<PHINode>(UI)) - return false; + BasicBlock *UsedInBB = UI->getParent(); + if (UsedInBB == BB) { + if (isa<PHINode>(UI)) + return false; + } else + NonLocalUseBlocks.insert(UsedInBB); } // Looks ok, continue checking. @@ -3475,18 +3503,37 @@ foldCondBranchOnValueKnownInPredecessorImpl(BranchInst *BI, DomTreeUpdater *DTU, return false; // Now we know that this block has multiple preds and two succs. - // Check that the block is small enough and values defined in the block are - // not used outside of it. - if (!blockIsSimpleEnoughToThreadThrough(BB)) + // Check that the block is small enough and record which non-local blocks use + // values defined in the block. + + BlocksSet NonLocalUseBlocks; + BlocksSet ReachesNonLocalUseBlocks; + if (!blockIsSimpleEnoughToThreadThrough(BB, NonLocalUseBlocks)) return false; + // Jump-threading can only be done to destinations where no values defined + // in BB are live. + + // Quickly check if both destinations have uses. If so, jump-threading cannot + // be done. 
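// The new findReaching above bounds a backward walk over predecessors: it
// records every block in which a value defined in DefBB may be live and gives
// up once more than MaxJumpThreadingLiveBlocks blocks have been visited. A
// small standalone sketch of the same search over a generic CFG follows; the
// types are illustrative, not LLVM's.
#include <cstddef>
#include <set>
#include <vector>

// Illustrative CFG node: just the indices of its predecessors.
struct BlockNode {
  std::vector<int> Preds;
};

// Walk backwards from Block toward DefBlock, inserting every visited block
// into LiveBlocks. Return false once the visited set exceeds the cap,
// mirroring the MaxJumpThreadingLiveBlocks cut-off in the patch.
bool findReachingSketch(const std::vector<BlockNode> &CFG, int Block,
                        int DefBlock, std::set<int> &LiveBlocks,
                        std::size_t MaxLiveBlocks = 24) {
  if (Block == DefBlock)
    return true;
  if (!LiveBlocks.insert(Block).second)
    return true; // Already visited.
  if (LiveBlocks.size() > MaxLiveBlocks)
    return false; // Too many blocks searched; the caller bails out.
  for (int Pred : CFG[Block].Preds)
    if (!findReachingSketch(CFG, Pred, DefBlock, LiveBlocks, MaxLiveBlocks))
      return false;
  return true;
}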
+ if (NonLocalUseBlocks.contains(BI->getSuccessor(0)) && + NonLocalUseBlocks.contains(BI->getSuccessor(1))) + return false; + + // Search backward from NonLocalUseBlocks to find which blocks + // reach non-local uses. + for (BasicBlock *UseBB : NonLocalUseBlocks) + // Give up if too many blocks are searched. + if (!findReaching(UseBB, BB, ReachesNonLocalUseBlocks)) + return false; + for (const auto &Pair : KnownValues) { - // Okay, we now know that all edges from PredBB should be revectored to - // branch to RealDest. ConstantInt *CB = Pair.first; ArrayRef<BasicBlock *> PredBBs = Pair.second.getArrayRef(); BasicBlock *RealDest = BI->getSuccessor(!CB->getZExtValue()); + // Okay, we now know that all edges from PredBB should be revectored to + // branch to RealDest. if (RealDest == BB) continue; // Skip self loops. @@ -3496,6 +3543,10 @@ foldCondBranchOnValueKnownInPredecessorImpl(BranchInst *BI, DomTreeUpdater *DTU, })) continue; + // Only revector to RealDest if no values defined in BB are live. + if (ReachesNonLocalUseBlocks.contains(RealDest)) + continue; + LLVM_DEBUG({ dbgs() << "Condition " << *Cond << " in " << BB->getName() << " has value " << *Pair.first << " in predecessors:\n"; diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp index 969d225..c47fd942 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp @@ -1665,13 +1665,12 @@ bool LoopVectorizationLegality::isVectorizableEarlyExitLoop() { // Keep a record of all the exiting blocks. SmallVector<const SCEVPredicate *, 4> Predicates; - std::optional<std::pair<BasicBlock *, BasicBlock *>> SingleUncountableEdge; + BasicBlock *SingleUncountableExitingBlock = nullptr; for (BasicBlock *BB : ExitingBlocks) { const SCEV *EC = PSE.getSE()->getPredicatedExitCount(TheLoop, BB, &Predicates); if (isa<SCEVCouldNotCompute>(EC)) { - SmallVector<BasicBlock *, 2> Succs(successors(BB)); - if (Succs.size() != 2) { + if (size(successors(BB)) != 2) { reportVectorizationFailure( "Early exiting block does not have exactly two successors", "Incorrect number of successors from early exiting block", @@ -1679,15 +1678,7 @@ bool LoopVectorizationLegality::isVectorizableEarlyExitLoop() { return false; } - BasicBlock *ExitBlock; - if (!TheLoop->contains(Succs[0])) - ExitBlock = Succs[0]; - else { - assert(!TheLoop->contains(Succs[1])); - ExitBlock = Succs[1]; - } - - if (SingleUncountableEdge) { + if (SingleUncountableExitingBlock) { reportVectorizationFailure( "Loop has too many uncountable exits", "Cannot vectorize early exit loop with more than one early exit", @@ -1695,7 +1686,7 @@ bool LoopVectorizationLegality::isVectorizableEarlyExitLoop() { return false; } - SingleUncountableEdge = {BB, ExitBlock}; + SingleUncountableExitingBlock = BB; } else CountableExitingBlocks.push_back(BB); } @@ -1705,7 +1696,7 @@ bool LoopVectorizationLegality::isVectorizableEarlyExitLoop() { // PSE.getSymbolicMaxBackedgeTakenCount() below. Predicates.clear(); - if (!SingleUncountableEdge) { + if (!SingleUncountableExitingBlock) { LLVM_DEBUG(dbgs() << "LV: Cound not find any uncountable exits"); return false; } @@ -1713,7 +1704,7 @@ bool LoopVectorizationLegality::isVectorizableEarlyExitLoop() { // The only supported early exit loops so far are ones where the early // exiting block is a unique predecessor of the latch block. 
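// For reference, a source-level loop shape that satisfies the constraint
// stated above: the data-dependent (uncountable) exit sits in the unique
// predecessor of the latch, and the only other exit is the countable one from
// the loop header. Purely illustrative; the legality check operates on IR,
// not on C++ source.
#include <cstddef>

int findFirstEqual(const int *A, const int *B, std::size_t N) {
  for (std::size_t I = 0; I < N; ++I) { // countable exit: I == N
    if (A[I] == B[I])                   // uncountable early exit
      return static_cast<int>(I);
  }
  return -1;
}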
BasicBlock *LatchPredBB = LatchBB->getUniquePredecessor(); - if (LatchPredBB != SingleUncountableEdge->first) { + if (LatchPredBB != SingleUncountableExitingBlock) { reportVectorizationFailure("Early exit is not the latch predecessor", "Cannot vectorize early exit loop", "EarlyExitNotLatchPredecessor", ORE, TheLoop); @@ -1766,7 +1757,7 @@ bool LoopVectorizationLegality::isVectorizableEarlyExitLoop() { } // The vectoriser cannot handle loads that occur after the early exit block. - assert(LatchBB->getUniquePredecessor() == SingleUncountableEdge->first && + assert(LatchBB->getUniquePredecessor() == SingleUncountableExitingBlock && "Expected latch predecessor to be the early exiting block"); // TODO: Handle loops that may fault. @@ -1789,7 +1780,7 @@ bool LoopVectorizationLegality::isVectorizableEarlyExitLoop() { LLVM_DEBUG(dbgs() << "LV: Found an early exit loop with symbolic max " "backedge taken count: " << *SymbolicMaxBTC << '\n'); - UncountableEdge = SingleUncountableEdge; + UncountableExitingBB = SingleUncountableExitingBlock; return true; } @@ -1861,7 +1852,8 @@ bool LoopVectorizationLegality::canVectorize(bool UseVPlanNativePath) { return false; } else { if (!isVectorizableEarlyExitLoop()) { - UncountableEdge = std::nullopt; + assert(!hasUncountableEarlyExit() && + "Must be false without vectorizable early-exit loop"); if (DoExtraAnalysis) Result = false; else diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index 850c4a1..2052808 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -1624,7 +1624,7 @@ private: /// presence of a cost for an instruction in the mapping indicates that the /// instruction will be scalarized when vectorizing with the associated /// vectorization factor. The entries are VF-ScalarCostTy pairs. - DenseMap<ElementCount, ScalarCostsTy> InstsToScalarize; + MapVector<ElementCount, ScalarCostsTy> InstsToScalarize; /// Holds the instructions known to be uniform after vectorization. /// The data is collected per VF. diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index 68e7c20..11b4677 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -2530,8 +2530,8 @@ void VPReductionRecipe::execute(VPTransformState &State) { NextInChain = createMinMaxOp(State.Builder, Kind, NewRed, PrevInChain); else NextInChain = State.Builder.CreateBinOp( - (Instruction::BinaryOps)RecurrenceDescriptor::getOpcode(Kind), NewRed, - PrevInChain); + (Instruction::BinaryOps)RecurrenceDescriptor::getOpcode(Kind), + PrevInChain, NewRed); } State.set(this, NextInChain, /*IsScalar*/ true); } @@ -3548,6 +3548,8 @@ void VPInterleaveRecipe::execute(VPTransformState &State) { // Vectorize the interleaved store group. 
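// The assertion added below checks that NeedsMaskForGaps agrees with the mask
// that is actually built for the store group. As a rough standalone model
// (not LLVM's createBitMaskForGaps), a gap mask for an interleave group with
// factor Factor over VF scalar iterations disables the lanes that belong to
// members missing from the group.
#include <cstddef>
#include <vector>

// MemberPresent[J] says whether member J exists in the group; the result has
// one flag per lane of the Factor * VF wide vector.
std::vector<bool> gapMaskSketch(unsigned VF, unsigned Factor,
                                const std::vector<bool> &MemberPresent) {
  std::vector<bool> Mask(static_cast<std::size_t>(VF) * Factor);
  for (unsigned I = 0; I < VF; ++I)
    for (unsigned J = 0; J < Factor; ++J)
      Mask[static_cast<std::size_t>(I) * Factor + J] = MemberPresent[J];
  return Mask;
}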
Value *MaskForGaps = createBitMaskForGaps(State.Builder, State.VF.getKnownMinValue(), *Group); + assert(((MaskForGaps != nullptr) == NeedsMaskForGaps) && + "Mismatch between NeedsMaskForGaps and MaskForGaps"); assert((!MaskForGaps || !State.VF.isScalable()) && "masking gaps for scalable vectors is not yet supported."); ArrayRef<VPValue *> StoredValues = getStoredValues(); diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp index ad8235d..fcbc86f 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -32,11 +32,11 @@ #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/MDBuilder.h" -#include "llvm/IR/PatternMatch.h" #include "llvm/Support/Casting.h" #include "llvm/Support/TypeSize.h" using namespace llvm; +using namespace VPlanPatternMatch; bool VPlanTransforms::tryToConvertVPInstructionsToVPRecipes( VPlanPtr &Plan, @@ -528,13 +528,11 @@ static void removeRedundantCanonicalIVs(VPlan &Plan) { /// Returns true if \p R is dead and can be removed. static bool isDeadRecipe(VPRecipeBase &R) { - using namespace llvm::PatternMatch; // Do remove conditional assume instructions as their conditions may be // flattened. auto *RepR = dyn_cast<VPReplicateRecipe>(&R); - bool IsConditionalAssume = - RepR && RepR->isPredicated() && - match(RepR->getUnderlyingInstr(), m_Intrinsic<Intrinsic::assume>()); + bool IsConditionalAssume = RepR && RepR->isPredicated() && + match(RepR, m_Intrinsic<Intrinsic::assume>()); if (IsConditionalAssume) return true; @@ -625,7 +623,6 @@ static SmallVector<VPUser *> collectUsersRecursively(VPValue *V) { /// original IV's users. This is an optional optimization to reduce the needs of /// vector extracts. static void legalizeAndOptimizeInductions(VPlan &Plan) { - using namespace llvm::VPlanPatternMatch; VPBasicBlock *HeaderVPBB = Plan.getVectorLoopRegion()->getEntryBasicBlock(); bool HasOnlyVectorVFs = !Plan.hasScalarVFOnly(); VPBuilder Builder(HeaderVPBB, HeaderVPBB->getFirstNonPhi()); @@ -727,7 +724,6 @@ static VPWidenInductionRecipe *getOptimizableIVOf(VPValue *VPV) { return nullptr; auto IsWideIVInc = [&]() { - using namespace VPlanPatternMatch; auto &ID = WideIV->getInductionDescriptor(); // Check if VPV increments the induction by the induction step. @@ -771,8 +767,6 @@ static VPValue *optimizeEarlyExitInductionUser(VPlan &Plan, VPTypeAnalysis &TypeInfo, VPBlockBase *PredVPBB, VPValue *Op) { - using namespace VPlanPatternMatch; - VPValue *Incoming, *Mask; if (!match(Op, m_VPInstruction<VPInstruction::ExtractLane>( m_VPInstruction<VPInstruction::FirstActiveLane>( @@ -827,8 +821,6 @@ static VPValue * optimizeLatchExitInductionUser(VPlan &Plan, VPTypeAnalysis &TypeInfo, VPBlockBase *PredVPBB, VPValue *Op, DenseMap<VPValue *, VPValue *> &EndValues) { - using namespace VPlanPatternMatch; - VPValue *Incoming; if (!match(Op, m_VPInstruction<VPInstruction::ExtractLastElement>( m_VPValue(Incoming)))) @@ -986,7 +978,6 @@ static Value *tryToFoldLiveIns(const VPRecipeBase &R, unsigned Opcode, /// Try to simplify recipe \p R. static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) { - using namespace llvm::VPlanPatternMatch; VPlan *Plan = R.getParent()->getPlan(); auto *Def = dyn_cast<VPSingleDefRecipe>(&R); @@ -1269,7 +1260,6 @@ static void narrowToSingleScalarRecipes(VPlan &Plan) { /// Normalize and simplify VPBlendRecipes. Should be run after simplifyRecipes /// to make sure the masks are simplified. 
static void simplifyBlends(VPlan &Plan) { - using namespace llvm::VPlanPatternMatch; for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>( vp_depth_first_shallow(Plan.getVectorLoopRegion()->getEntry()))) { for (VPRecipeBase &R : make_early_inc_range(*VPBB)) { @@ -1393,7 +1383,6 @@ static bool optimizeVectorInductionWidthForTCAndVFUF(VPlan &Plan, // Currently only handle cases where the single user is a header-mask // comparison with the backedge-taken-count. - using namespace VPlanPatternMatch; if (!match( *WideIV->user_begin(), m_Binary<Instruction::ICmp>( @@ -1424,8 +1413,7 @@ static bool optimizeVectorInductionWidthForTCAndVFUF(VPlan &Plan, static bool isConditionTrueViaVFAndUF(VPValue *Cond, VPlan &Plan, ElementCount BestVF, unsigned BestUF, ScalarEvolution &SE) { - using namespace llvm::VPlanPatternMatch; - if (match(Cond, m_Binary<Instruction::Or>(m_VPValue(), m_VPValue()))) + if (match(Cond, m_BinaryOr(m_VPValue(), m_VPValue()))) return any_of(Cond->getDefiningRecipe()->operands(), [&Plan, BestVF, BestUF, &SE](VPValue *C) { return isConditionTrueViaVFAndUF(C, Plan, BestVF, BestUF, SE); @@ -1464,7 +1452,6 @@ static bool simplifyBranchConditionForVFAndUF(VPlan &Plan, ElementCount BestVF, auto *Term = &ExitingVPBB->back(); VPValue *Cond; ScalarEvolution &SE = *PSE.getSE(); - using namespace llvm::VPlanPatternMatch; if (match(Term, m_BranchOnCount(m_VPValue(), m_VPValue())) || match(Term, m_BranchOnCond( m_Not(m_ActiveLaneMask(m_VPValue(), m_VPValue()))))) { @@ -1496,11 +1483,11 @@ static bool simplifyBranchConditionForVFAndUF(VPlan &Plan, ElementCount BestVF, auto *CanIVTy = Plan.getCanonicalIV()->getScalarType(); if (all_of(Header->phis(), IsaPred<VPCanonicalIVPHIRecipe, VPEVLBasedIVPHIRecipe, - VPFirstOrderRecurrencePHIRecipe>)) { + VPFirstOrderRecurrencePHIRecipe, VPPhi>)) { for (VPRecipeBase &HeaderR : make_early_inc_range(Header->phis())) { - auto *HeaderPhiR = cast<VPHeaderPHIRecipe>(&HeaderR); - HeaderPhiR->replaceAllUsesWith(HeaderPhiR->getStartValue()); - HeaderPhiR->eraseFromParent(); + auto *Phi = cast<VPPhiAccessors>(&HeaderR); + HeaderR.getVPSingleValue()->replaceAllUsesWith(Phi->getIncomingValue(0)); + HeaderR.eraseFromParent(); } VPBlockBase *Preheader = VectorRegion->getSinglePredecessor(); @@ -1847,7 +1834,6 @@ void VPlanTransforms::truncateToMinimalBitwidths( if (auto *VPW = dyn_cast<VPRecipeWithIRFlags>(&R)) VPW->dropPoisonGeneratingFlags(); - using namespace llvm::VPlanPatternMatch; if (OldResSizeInBits != NewResSizeInBits && !match(&R, m_Binary<Instruction::ICmp>(m_VPValue(), m_VPValue()))) { // Extend result to original width. @@ -1897,7 +1883,6 @@ void VPlanTransforms::truncateToMinimalBitwidths( } void VPlanTransforms::removeBranchOnConst(VPlan &Plan) { - using namespace llvm::VPlanPatternMatch; for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>( vp_depth_first_shallow(Plan.getEntry()))) { VPValue *Cond; @@ -2143,7 +2128,6 @@ static VPRecipeBase *optimizeMaskToEVL(VPValue *HeaderMask, VPRecipeBase &CurRecipe, VPTypeAnalysis &TypeInfo, VPValue &AllOneMask, VPValue &EVL) { - using namespace llvm::VPlanPatternMatch; auto GetNewMask = [&](VPValue *OrigMask) -> VPValue * { assert(OrigMask && "Unmasked recipe when folding tail"); // HeaderMask will be handled using EVL. 
@@ -2223,7 +2207,6 @@ static void transformRecipestoEVLRecipes(VPlan &Plan, VPValue &EVL) { for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>( vp_depth_first_deep(Plan.getVectorLoopRegion()->getEntry()))) { for (VPRecipeBase &R : *VPBB) { - using namespace VPlanPatternMatch; VPValue *V1, *V2; if (!match(&R, m_VPInstruction<VPInstruction::FirstOrderRecurrenceSplice>( @@ -2309,10 +2292,12 @@ static void transformRecipestoEVLRecipes(VPlan &Plan, VPValue &EVL) { /// ... /// %EVLPhi = EXPLICIT-VECTOR-LENGTH-BASED-IV-PHI [ %StartV, %vector.ph ], /// [ %NextEVLIV, %vector.body ] -/// %AVL = sub original TC, %EVLPhi +/// %AVL = phi [ trip-count, %vector.ph ], [ %NextAVL, %vector.body ] /// %VPEVL = EXPLICIT-VECTOR-LENGTH %AVL /// ... -/// %NextEVLIV = add IVSize (cast i32 %VPEVVL to IVSize), %EVLPhi +/// %OpEVL = cast i32 %VPEVL to IVSize +/// %NextEVLIV = add IVSize %OpEVL, %EVLPhi +/// %NextAVL = sub IVSize nuw %AVL, %OpEVL /// ... /// /// If MaxSafeElements is provided, the function adds the following recipes: @@ -2323,12 +2308,14 @@ static void transformRecipestoEVLRecipes(VPlan &Plan, VPValue &EVL) { /// ... /// %EVLPhi = EXPLICIT-VECTOR-LENGTH-BASED-IV-PHI [ %StartV, %vector.ph ], /// [ %NextEVLIV, %vector.body ] -/// %AVL = sub original TC, %EVLPhi +/// %AVL = phi [ trip-count, %vector.ph ], [ %NextAVL, %vector.body ] /// %cmp = cmp ult %AVL, MaxSafeElements /// %SAFE_AVL = select %cmp, %AVL, MaxSafeElements /// %VPEVL = EXPLICIT-VECTOR-LENGTH %SAFE_AVL /// ... -/// %NextEVLIV = add IVSize (cast i32 %VPEVL to IVSize), %EVLPhi +/// %OpEVL = cast i32 %VPEVL to IVSize +/// %NextEVLIV = add IVSize %OpEVL, %EVLPhi +/// %NextAVL = sub IVSize nuw %AVL, %OpEVL /// ... /// bool VPlanTransforms::tryAddExplicitVectorLength( @@ -2350,9 +2337,12 @@ bool VPlanTransforms::tryAddExplicitVectorLength( auto *EVLPhi = new VPEVLBasedIVPHIRecipe(StartV, DebugLoc()); EVLPhi->insertAfter(CanonicalIVPHI); VPBuilder Builder(Header, Header->getFirstNonPhi()); - // Compute original TC - IV as the AVL (application vector length). - VPValue *AVL = Builder.createNaryOp( - Instruction::Sub, {Plan.getTripCount(), EVLPhi}, DebugLoc(), "avl"); + // Create the AVL (application vector length), starting from TC -> 0 in steps + // of EVL. + VPPhi *AVLPhi = Builder.createScalarPhi( + {Plan.getTripCount()}, DebugLoc::getCompilerGenerated(), "avl"); + VPValue *AVL = AVLPhi; + if (MaxSafeElements) { // Support for MaxSafeDist for correct loop emission. VPValue *AVLSafe = @@ -2379,6 +2369,11 @@ bool VPlanTransforms::tryAddExplicitVectorLength( CanonicalIVIncrement->getDebugLoc(), "index.evl.next"); EVLPhi->addOperand(NextEVLIV); + VPValue *NextAVL = Builder.createOverflowingOp( + Instruction::Sub, {AVLPhi, OpVPEVL}, {/*hasNUW=*/true, /*hasNSW=*/false}, + DebugLoc::getCompilerGenerated(), "avl.next"); + AVLPhi->addOperand(NextAVL); + transformRecipestoEVLRecipes(Plan, *VPEVL); // Replace all uses of VPCanonicalIVPHIRecipe by @@ -2391,7 +2386,6 @@ bool VPlanTransforms::tryAddExplicitVectorLength( } void VPlanTransforms::canonicalizeEVLLoops(VPlan &Plan) { - using namespace llvm::VPlanPatternMatch; // Find EVL loop entries by locating VPEVLBasedIVPHIRecipe. // There should be only one EVL PHI in the entire plan. VPEVLBasedIVPHIRecipe *EVLPhi = nullptr; @@ -2480,7 +2474,6 @@ void VPlanTransforms::dropPoisonGeneratingRecipes( // drop them directly. 
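// Scalar model (illustrative only) of the AVL/EVL scheme documented above for
// tryAddExplicitVectorLength: the AVL is now a phi that starts at the trip
// count and is decremented by the EVL chosen each iteration, instead of being
// recomputed as trip-count minus IV. VLMax and MaxSafeElements are stand-in
// parameters, not values taken from the patch.
#include <algorithm>
#include <cstdint>

uint64_t evlLoopModel(uint64_t TripCount, uint64_t VLMax,
                      uint64_t MaxSafeElements) {
  uint64_t Iterations = 0;
  uint64_t IV = 0;          // canonical induction (index.evl)
  uint64_t AVL = TripCount; // %AVL phi: elements still to process
  while (IV < TripCount) {
    uint64_t SafeAVL = std::min(AVL, MaxSafeElements); // %SAFE_AVL select
    uint64_t EVL = std::min(SafeAVL, VLMax);           // EXPLICIT-VECTOR-LENGTH
    // ... process EVL elements starting at IV ...
    IV += EVL;  // %NextEVLIV
    AVL -= EVL; // %NextAVL (nuw: EVL never exceeds AVL)
    ++Iterations;
  }
  return Iterations;
}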
if (auto *RecWithFlags = dyn_cast<VPRecipeWithIRFlags>(CurRec)) { VPValue *A, *B; - using namespace llvm::VPlanPatternMatch; // Dropping disjoint from an OR may yield incorrect results, as some // analysis may have converted it to an Add implicitly (e.g. SCEV used // for dependence analysis). Instead, replace it with an equivalent Add. @@ -2570,7 +2563,8 @@ void VPlanTransforms::createInterleaveGroups( } bool NeedsMaskForGaps = - IG->requiresScalarEpilogue() && !ScalarEpilogueAllowed; + (IG->requiresScalarEpilogue() && !ScalarEpilogueAllowed) || + (!StoredValues.empty() && !IG->isFull()); Instruction *IRInsertPos = IG->getInsertPos(); auto *InsertPos = @@ -2774,8 +2768,6 @@ void VPlanTransforms::dissolveLoopRegions(VPlan &Plan) { void VPlanTransforms::convertToConcreteRecipes(VPlan &Plan, Type &CanonicalIVTy) { - using namespace llvm::VPlanPatternMatch; - VPTypeAnalysis TypeInfo(&CanonicalIVTy); SmallVector<VPRecipeBase *> ToRemove; for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>( @@ -2852,8 +2844,6 @@ void VPlanTransforms::convertToConcreteRecipes(VPlan &Plan, void VPlanTransforms::handleUncountableEarlyExit( VPBasicBlock *EarlyExitingVPBB, VPBasicBlock *EarlyExitVPBB, VPlan &Plan, VPBasicBlock *HeaderVPBB, VPBasicBlock *LatchVPBB, VFRange &Range) { - using namespace llvm::VPlanPatternMatch; - VPBlockBase *MiddleVPBB = LatchVPBB->getSuccessors()[0]; if (!EarlyExitVPBB->getSinglePredecessor() && EarlyExitVPBB->getPredecessors()[1] == MiddleVPBB) { @@ -2947,8 +2937,6 @@ void VPlanTransforms::handleUncountableEarlyExit( static VPExpressionRecipe * tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx, VFRange &Range) { - using namespace VPlanPatternMatch; - Type *RedTy = Ctx.Types.inferScalarType(Red); VPValue *VecOp = Red->getVecOp(); @@ -2994,8 +2982,6 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx, static VPExpressionRecipe * tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red, VPCostContext &Ctx, VFRange &Range) { - using namespace VPlanPatternMatch; - unsigned Opcode = RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()); if (Opcode != Instruction::Add) return nullptr; @@ -3256,7 +3242,6 @@ static bool isAlreadyNarrow(VPValue *VPV) { void VPlanTransforms::narrowInterleaveGroups(VPlan &Plan, ElementCount VF, unsigned VectorRegWidth) { - using namespace llvm::VPlanPatternMatch; VPRegionBlock *VectorLoop = Plan.getVectorLoopRegion(); if (VF.isScalable() || !VectorLoop) return; diff --git a/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp b/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp index 57d01cb..14ae4f2 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp @@ -79,9 +79,8 @@ bool VPlanVerifier::verifyPhiRecipes(const VPBasicBlock *VPBB) { if (isa<VPActiveLaneMaskPHIRecipe>(RecipeI)) NumActiveLaneMaskPhiRecipes++; - if (IsHeaderVPBB && !isa<VPHeaderPHIRecipe, VPWidenPHIRecipe>(*RecipeI) && - !isa<VPInstruction>(*RecipeI) && - cast<VPInstruction>(RecipeI)->getOpcode() == Instruction::PHI) { + if (IsHeaderVPBB && + !isa<VPHeaderPHIRecipe, VPWidenPHIRecipe, VPPhi>(*RecipeI)) { errs() << "Found non-header PHI recipe in header VPBB"; #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) errs() << ": "; @@ -173,7 +172,8 @@ bool VPlanVerifier::verifyEVLRecipe(const VPInstruction &EVL) const { [&](const VPInstructionWithType *S) { return VerifyEVLUse(*S, 0); }) .Case<VPInstruction>([&](const VPInstruction *I) { if (I->getOpcode() == 
Instruction::PHI || - I->getOpcode() == Instruction::ICmp) + I->getOpcode() == Instruction::ICmp || + I->getOpcode() == Instruction::Sub) return VerifyEVLUse(*I, 1); switch (I->getOpcode()) { case Instruction::Add: diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp index 6252f4f..6345b18 100644 --- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp +++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp @@ -1664,6 +1664,8 @@ static Align computeAlignmentAfterScalarization(Align VectorAlignment, // %1 = getelementptr inbounds i32, i32* %0, i64 0, i64 1 // store i32 %b, i32* %1 bool VectorCombine::foldSingleElementStore(Instruction &I) { + if (!TTI.allowVectorElementIndexingUsingGEP()) + return false; auto *SI = cast<StoreInst>(&I); if (!SI->isSimple() || !isa<VectorType>(SI->getValueOperand()->getType())) return false; @@ -1719,6 +1721,9 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) { /// Try to scalarize vector loads feeding extractelement instructions. bool VectorCombine::scalarizeLoadExtract(Instruction &I) { + if (!TTI.allowVectorElementIndexingUsingGEP()) + return false; + Value *Ptr; if (!match(&I, m_Load(m_Value(Ptr)))) return false; @@ -1827,6 +1832,8 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) { } bool VectorCombine::scalarizeExtExtract(Instruction &I) { + if (!TTI.allowVectorElementIndexingUsingGEP()) + return false; auto *Ext = dyn_cast<ZExtInst>(&I); if (!Ext) return false; |
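// Conceptual model of the scalarization these VectorCombine folds perform
// when TTI.allowVectorElementIndexingUsingGEP() returns true: address a
// single lane directly instead of touching the whole vector. This is a plain
// C++ stand-in with illustrative names; the actual folds emit a GEP plus a
// scalar load or store on IR.
#include <cstddef>

int scalarizedLoadExtract(const int *VecBase, std::size_t Lane) {
  // Stands in for: %gep = getelementptr i32, ptr %VecBase, i64 %Lane
  //                %elt = load i32, ptr %gep
  return VecBase[Lane];
}

void scalarizedSingleElementStore(int *VecBase, std::size_t Lane, int Value) {
  // Stands in for replacing "insertelement + store <N x i32>" with a store of
  // only the element that actually changes.
  VecBase[Lane] = Value;
}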