Diffstat (limited to 'llvm/lib')
192 files changed, 5020 insertions, 2808 deletions
diff --git a/llvm/lib/Analysis/CMakeLists.txt b/llvm/lib/Analysis/CMakeLists.txt index cfde787..16dd6f8 100644 --- a/llvm/lib/Analysis/CMakeLists.txt +++ b/llvm/lib/Analysis/CMakeLists.txt @@ -175,6 +175,7 @@ add_llvm_component_library(LLVMAnalysis LINK_COMPONENTS BinaryFormat Core + FrontendHLSL Object ProfileData Support diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp index 2d52f34..dd98b62 100644 --- a/llvm/lib/Analysis/ConstantFolding.cpp +++ b/llvm/lib/Analysis/ConstantFolding.cpp @@ -2679,11 +2679,12 @@ static Constant *ConstantFoldScalarCall1(StringRef Name, case Intrinsic::nvvm_round_ftz_f: case Intrinsic::nvvm_round_f: case Intrinsic::nvvm_round_d: { - // Use APFloat implementation instead of native libm call, as some - // implementations (e.g. on PPC) do not preserve the sign of negative 0. + // nvvm_round is lowered to PTX cvt.rni, which will round to nearest + // integer, choosing even integer if source is equidistant between two + // integers, so the semantics are closer to "rint" rather than "round". bool IsFTZ = nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID); auto V = IsFTZ ? FTZPreserveSign(APF) : APF; - V.roundToIntegral(APFloat::rmNearestTiesToAway); + V.roundToIntegral(APFloat::rmNearestTiesToEven); return ConstantFP::get(Ty->getContext(), V); } diff --git a/llvm/lib/Analysis/DXILResource.cpp b/llvm/lib/Analysis/DXILResource.cpp index 1959ab6..629fa7cd 100644 --- a/llvm/lib/Analysis/DXILResource.cpp +++ b/llvm/lib/Analysis/DXILResource.cpp @@ -995,18 +995,7 @@ SmallVector<dxil::ResourceInfo *> DXILResourceMap::findByUse(const Value *Key) { //===----------------------------------------------------------------------===// void DXILResourceBindingInfo::populate(Module &M, DXILResourceTypeMap &DRTM) { - struct Binding { - ResourceClass RC; - uint32_t Space; - uint32_t LowerBound; - uint32_t UpperBound; - Value *Name; - Binding(ResourceClass RC, uint32_t Space, uint32_t LowerBound, - uint32_t UpperBound, Value *Name) - : RC(RC), Space(Space), LowerBound(LowerBound), UpperBound(UpperBound), - Name(Name) {} - }; - SmallVector<Binding> Bindings; + hlsl::BindingInfoBuilder Builder; // collect all of the llvm.dx.resource.handlefrombinding calls; // make a note if there is llvm.dx.resource.handlefromimplicitbinding @@ -1036,133 +1025,20 @@ void DXILResourceBindingInfo::populate(Module &M, DXILResourceTypeMap &DRTM) { assert((Size < 0 || (unsigned)LowerBound + Size - 1 <= UINT32_MAX) && "upper bound register overflow"); uint32_t UpperBound = Size < 0 ? UINT32_MAX : LowerBound + Size - 1; - Bindings.emplace_back(RTI.getResourceClass(), Space, LowerBound, - UpperBound, Name); + Builder.trackBinding(RTI.getResourceClass(), Space, LowerBound, + UpperBound, Name); } break; } case Intrinsic::dx_resource_handlefromimplicitbinding: { - ImplicitBinding = true; + HasImplicitBinding = true; break; } } } - // sort all the collected bindings - llvm::stable_sort(Bindings, [](auto &LHS, auto &RHS) { - return std::tie(LHS.RC, LHS.Space, LHS.LowerBound) < - std::tie(RHS.RC, RHS.Space, RHS.LowerBound); - }); - - // remove duplicates - Binding *NewEnd = llvm::unique(Bindings, [](auto &LHS, auto &RHS) { - return std::tie(LHS.RC, LHS.Space, LHS.LowerBound, LHS.UpperBound, - LHS.Name) == std::tie(RHS.RC, RHS.Space, RHS.LowerBound, - RHS.UpperBound, RHS.Name); - }); - if (NewEnd != Bindings.end()) - Bindings.erase(NewEnd); - - // Go over the sorted bindings and build up lists of free register ranges - // for each binding type and used spaces. 
Bindings are sorted by resource - // class, space, and lower bound register slot. - BindingSpaces *BS = &SRVSpaces; - for (const Binding &B : Bindings) { - if (BS->RC != B.RC) - // move to the next resource class spaces - BS = &getBindingSpaces(B.RC); - - RegisterSpace *S = BS->Spaces.empty() ? &BS->Spaces.emplace_back(B.Space) - : &BS->Spaces.back(); - assert(S->Space <= B.Space && "bindings not sorted correctly?"); - if (B.Space != S->Space) - // add new space - S = &BS->Spaces.emplace_back(B.Space); - - // The space is full - there are no free slots left, or the rest of the - // slots are taken by an unbounded array. Set flag to report overlapping - // binding later. - if (S->FreeRanges.empty() || S->FreeRanges.back().UpperBound < UINT32_MAX) { - OverlappingBinding = true; - continue; - } - - // adjust the last free range lower bound, split it in two, or remove it - BindingRange &LastFreeRange = S->FreeRanges.back(); - if (LastFreeRange.LowerBound == B.LowerBound) { - if (B.UpperBound < UINT32_MAX) - LastFreeRange.LowerBound = B.UpperBound + 1; - else - S->FreeRanges.pop_back(); - } else if (LastFreeRange.LowerBound < B.LowerBound) { - LastFreeRange.UpperBound = B.LowerBound - 1; - if (B.UpperBound < UINT32_MAX) - S->FreeRanges.emplace_back(B.UpperBound + 1, UINT32_MAX); - } else { - OverlappingBinding = true; - if (B.UpperBound < UINT32_MAX) - LastFreeRange.LowerBound = - std::max(LastFreeRange.LowerBound, B.UpperBound + 1); - else - S->FreeRanges.pop_back(); - } - } -} - -// returns std::nulopt if binding could not be found in given space -std::optional<uint32_t> -DXILResourceBindingInfo::findAvailableBinding(dxil::ResourceClass RC, - uint32_t Space, int32_t Size) { - BindingSpaces &BS = getBindingSpaces(RC); - RegisterSpace &RS = BS.getOrInsertSpace(Space); - return RS.findAvailableBinding(Size); -} - -DXILResourceBindingInfo::RegisterSpace & -DXILResourceBindingInfo::BindingSpaces::getOrInsertSpace(uint32_t Space) { - for (auto *I = Spaces.begin(); I != Spaces.end(); ++I) { - if (I->Space == Space) - return *I; - if (I->Space < Space) - continue; - return *Spaces.insert(I, Space); - } - return Spaces.emplace_back(Space); -} - -std::optional<uint32_t> -DXILResourceBindingInfo::RegisterSpace::findAvailableBinding(int32_t Size) { - assert((Size == -1 || Size > 0) && "invalid size"); - - if (FreeRanges.empty()) - return std::nullopt; - - // unbounded array - if (Size == -1) { - BindingRange &Last = FreeRanges.back(); - if (Last.UpperBound != UINT32_MAX) - // this space is already occupied by an unbounded array - return std::nullopt; - uint32_t RegSlot = Last.LowerBound; - FreeRanges.pop_back(); - return RegSlot; - } - - // single resource or fixed-size array - for (BindingRange &R : FreeRanges) { - // compare the size as uint64_t to prevent overflow for range (0, - // UINT32_MAX) - if ((uint64_t)R.UpperBound - R.LowerBound + 1 < (uint64_t)Size) - continue; - uint32_t RegSlot = R.LowerBound; - // This might create a range where (LowerBound == UpperBound + 1). When - // that happens, the next time this function is called the range will - // skipped over by the check above (at this point Size is always > 0). 
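Aside: the free-range bookkeeping that this patch moves out of DXILResource.cpp and into hlsl::BindingInfoBuilder is easy to state standalone. A minimal sketch under the same invariants as the removed code (free ranges kept sorted and disjoint, an unbounded array only ever satisfiable from the trailing open range); the names below are illustrative, not the builder's API:

#include <cstdint>
#include <optional>
#include <vector>

struct Range {
  uint32_t Lo, Hi; // inclusive register slots
};

// Allocate Size consecutive slots (or everything up to UINT32_MAX when
// Size == -1, i.e. an unbounded array) from sorted disjoint free ranges.
std::optional<uint32_t> allocate(std::vector<Range> &FreeRanges, int32_t Size) {
  if (Size == -1) {
    // Unbounded array: only satisfiable if the trailing range is still open.
    if (FreeRanges.empty() || FreeRanges.back().Hi != UINT32_MAX)
      return std::nullopt;
    uint32_t Slot = FreeRanges.back().Lo;
    FreeRanges.pop_back();
    return Slot;
  }
  for (Range &R : FreeRanges) {
    // Compare as uint64_t so the full range (0, UINT32_MAX) cannot overflow.
    if ((uint64_t)R.Hi - R.Lo + 1 < (uint64_t)Size)
      continue;
    uint32_t Slot = R.Lo;
    R.Lo += Size; // may leave an empty range (Lo == Hi + 1), skipped later
    return Slot;
  }
  return std::nullopt;
}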
- R.LowerBound += Size; - return RegSlot; - } - - return std::nullopt; + Bindings = Builder.calculateBindingInfo( + [this](auto, auto) { this->HasOverlappingBinding = true; }); } //===----------------------------------------------------------------------===// diff --git a/llvm/lib/Analysis/Loads.cpp b/llvm/lib/Analysis/Loads.cpp index 393f264..6fc81d787 100644 --- a/llvm/lib/Analysis/Loads.cpp +++ b/llvm/lib/Analysis/Loads.cpp @@ -342,7 +342,7 @@ bool llvm::isDereferenceableAndAlignedInLoop( : SE.getConstantMaxBackedgeTakenCount(L); } const auto &[AccessStart, AccessEnd] = getStartAndEndForAccess( - L, PtrScev, LI->getType(), BECount, MaxBECount, &SE, nullptr); + L, PtrScev, LI->getType(), BECount, MaxBECount, &SE, nullptr, &DT, AC); if (isa<SCEVCouldNotCompute>(AccessStart) || isa<SCEVCouldNotCompute>(AccessEnd)) return false; diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp index 14be385..a553533 100644 --- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp +++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp @@ -23,6 +23,8 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AliasSetTracker.h" +#include "llvm/Analysis/AssumeBundleQueries.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopIterator.h" @@ -208,28 +210,46 @@ static const SCEV *mulSCEVOverflow(const SCEV *A, const SCEV *B, /// Return true, if evaluating \p AR at \p MaxBTC cannot wrap, because \p AR at /// \p MaxBTC is guaranteed inbounds of the accessed object. -static bool evaluatePtrAddRecAtMaxBTCWillNotWrap(const SCEVAddRecExpr *AR, - const SCEV *MaxBTC, - const SCEV *EltSize, - ScalarEvolution &SE, - const DataLayout &DL) { +static bool +evaluatePtrAddRecAtMaxBTCWillNotWrap(const SCEVAddRecExpr *AR, + const SCEV *MaxBTC, const SCEV *EltSize, + ScalarEvolution &SE, const DataLayout &DL, + DominatorTree *DT, AssumptionCache *AC) { auto *PointerBase = SE.getPointerBase(AR->getStart()); auto *StartPtr = dyn_cast<SCEVUnknown>(PointerBase); if (!StartPtr) return false; + const Loop *L = AR->getLoop(); bool CheckForNonNull, CheckForFreed; - uint64_t DerefBytes = StartPtr->getValue()->getPointerDereferenceableBytes( + Value *StartPtrV = StartPtr->getValue(); + uint64_t DerefBytes = StartPtrV->getPointerDereferenceableBytes( DL, CheckForNonNull, CheckForFreed); - if (CheckForNonNull || CheckForFreed) + if (DerefBytes && (CheckForNonNull || CheckForFreed)) return false; const SCEV *Step = AR->getStepRecurrence(SE); + Type *WiderTy = SE.getWiderType(MaxBTC->getType(), Step->getType()); + const SCEV *DerefBytesSCEV = SE.getConstant(WiderTy, DerefBytes); + + // Check if we have a suitable dereferenceable assumption we can use.
+ if (!StartPtrV->canBeFreed()) { + RetainedKnowledge DerefRK = getKnowledgeValidInContext( + StartPtrV, {Attribute::Dereferenceable}, *AC, + L->getLoopPredecessor()->getTerminator(), DT); + if (DerefRK) { + DerefBytesSCEV = SE.getUMaxExpr( + DerefBytesSCEV, SE.getConstant(WiderTy, DerefRK.ArgValue)); + } + } + + if (DerefBytesSCEV->isZero()) + return false; + bool IsKnownNonNegative = SE.isKnownNonNegative(Step); if (!IsKnownNonNegative && !SE.isKnownNegative(Step)) return false; - Type *WiderTy = SE.getWiderType(MaxBTC->getType(), Step->getType()); Step = SE.getNoopOrSignExtend(Step, WiderTy); MaxBTC = SE.getNoopOrZeroExtend(MaxBTC, WiderTy); @@ -256,8 +276,7 @@ static bool evaluatePtrAddRecAtMaxBTCWillNotWrap(const SCEVAddRecExpr *AR, const SCEV *EndBytes = addSCEVNoOverflow(StartOffset, OffsetEndBytes, SE); if (!EndBytes) return false; - return SE.isKnownPredicate(CmpInst::ICMP_ULE, EndBytes, - SE.getConstant(WiderTy, DerefBytes)); + return SE.isKnownPredicate(CmpInst::ICMP_ULE, EndBytes, DerefBytesSCEV); } // For negative steps check if @@ -265,15 +284,15 @@ static bool evaluatePtrAddRecAtMaxBTCWillNotWrap(const SCEVAddRecExpr *AR, // * StartOffset <= DerefBytes. assert(SE.isKnownNegative(Step) && "must be known negative"); return SE.isKnownPredicate(CmpInst::ICMP_SGE, StartOffset, OffsetEndBytes) && - SE.isKnownPredicate(CmpInst::ICMP_ULE, StartOffset, - SE.getConstant(WiderTy, DerefBytes)); + SE.isKnownPredicate(CmpInst::ICMP_ULE, StartOffset, DerefBytesSCEV); } std::pair<const SCEV *, const SCEV *> llvm::getStartAndEndForAccess( const Loop *Lp, const SCEV *PtrExpr, Type *AccessTy, const SCEV *BTC, const SCEV *MaxBTC, ScalarEvolution *SE, DenseMap<std::pair<const SCEV *, Type *>, - std::pair<const SCEV *, const SCEV *>> *PointerBounds) { + std::pair<const SCEV *, const SCEV *>> *PointerBounds, + DominatorTree *DT, AssumptionCache *AC) { std::pair<const SCEV *, const SCEV *> *PtrBoundsPair; if (PointerBounds) { auto [Iter, Ins] = PointerBounds->insert( @@ -308,8 +327,8 @@ std::pair<const SCEV *, const SCEV *> llvm::getStartAndEndForAccess( // sets ScEnd to the maximum unsigned value for the type. Note that LAA // separately checks that accesses cannot wrap, so unsigned max // represents an upper bound.
- if (evaluatePtrAddRecAtMaxBTCWillNotWrap(AR, MaxBTC, EltSizeSCEV, *SE, - DL)) { + if (evaluatePtrAddRecAtMaxBTCWillNotWrap(AR, MaxBTC, EltSizeSCEV, *SE, DL, + DT, AC)) { ScEnd = AR->evaluateAtIteration(MaxBTC, *SE); } else { ScEnd = SE->getAddExpr( @@ -356,9 +375,9 @@ void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, const SCEV *PtrExpr, bool NeedsFreeze) { const SCEV *SymbolicMaxBTC = PSE.getSymbolicMaxBackedgeTakenCount(); const SCEV *BTC = PSE.getBackedgeTakenCount(); - const auto &[ScStart, ScEnd] = - getStartAndEndForAccess(Lp, PtrExpr, AccessTy, BTC, SymbolicMaxBTC, - PSE.getSE(), &DC.getPointerBounds()); + const auto &[ScStart, ScEnd] = getStartAndEndForAccess( + Lp, PtrExpr, AccessTy, BTC, SymbolicMaxBTC, PSE.getSE(), + &DC.getPointerBounds(), DC.getDT(), DC.getAC()); assert(!isa<SCEVCouldNotCompute>(ScStart) && !isa<SCEVCouldNotCompute>(ScEnd) && "must be able to compute both start and end expressions"); @@ -1961,13 +1980,15 @@ bool MemoryDepChecker::areAccessesCompletelyBeforeOrAfter(const SCEV *Src, const SCEV *BTC = PSE.getBackedgeTakenCount(); const SCEV *SymbolicMaxBTC = PSE.getSymbolicMaxBackedgeTakenCount(); ScalarEvolution &SE = *PSE.getSE(); - const auto &[SrcStart_, SrcEnd_] = getStartAndEndForAccess( - InnermostLoop, Src, SrcTy, BTC, SymbolicMaxBTC, &SE, &PointerBounds); + const auto &[SrcStart_, SrcEnd_] = + getStartAndEndForAccess(InnermostLoop, Src, SrcTy, BTC, SymbolicMaxBTC, + &SE, &PointerBounds, DT, AC); if (isa<SCEVCouldNotCompute>(SrcStart_) || isa<SCEVCouldNotCompute>(SrcEnd_)) return false; - const auto &[SinkStart_, SinkEnd_] = getStartAndEndForAccess( - InnermostLoop, Sink, SinkTy, BTC, SymbolicMaxBTC, &SE, &PointerBounds); + const auto &[SinkStart_, SinkEnd_] = + getStartAndEndForAccess(InnermostLoop, Sink, SinkTy, BTC, SymbolicMaxBTC, + &SE, &PointerBounds, DT, AC); if (isa<SCEVCouldNotCompute>(SinkStart_) || isa<SCEVCouldNotCompute>(SinkEnd_)) return false; @@ -3002,7 +3023,7 @@ LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE, const TargetTransformInfo *TTI, const TargetLibraryInfo *TLI, AAResults *AA, DominatorTree *DT, LoopInfo *LI, - bool AllowPartial) + AssumptionCache *AC, bool AllowPartial) : PSE(std::make_unique<PredicatedScalarEvolution>(*SE, *L)), PtrRtChecking(nullptr), TheLoop(L), AllowPartial(AllowPartial) { unsigned MaxTargetVectorWidthInBits = std::numeric_limits<unsigned>::max(); @@ -3012,8 +3033,8 @@ LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE, MaxTargetVectorWidthInBits = TTI->getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) * 2; - DepChecker = std::make_unique<MemoryDepChecker>(*PSE, L, SymbolicStrides, - MaxTargetVectorWidthInBits); + DepChecker = std::make_unique<MemoryDepChecker>( + *PSE, AC, DT, L, SymbolicStrides, MaxTargetVectorWidthInBits); PtrRtChecking = std::make_unique<RuntimePointerChecking>(*DepChecker, SE); if (canAnalyzeLoop()) CanVecMem = analyzeLoop(AA, LI, TLI, DT); @@ -3082,7 +3103,7 @@ const LoopAccessInfo &LoopAccessInfoManager::getInfo(Loop &L, // or if it was created with a different value of AllowPartial. 
if (Inserted || It->second->hasAllowPartial() != AllowPartial) It->second = std::make_unique<LoopAccessInfo>(&L, &SE, TTI, TLI, &AA, &DT, - &LI, AllowPartial); + &LI, AC, AllowPartial); return *It->second; } @@ -3125,7 +3146,8 @@ LoopAccessInfoManager LoopAccessAnalysis::run(Function &F, auto &LI = FAM.getResult<LoopAnalysis>(F); auto &TTI = FAM.getResult<TargetIRAnalysis>(F); auto &TLI = FAM.getResult<TargetLibraryAnalysis>(F); - return LoopAccessInfoManager(SE, AA, DT, LI, &TTI, &TLI); + auto &AC = FAM.getResult<AssumptionAnalysis>(F); + return LoopAccessInfoManager(SE, AA, DT, LI, &TTI, &TLI, &AC); } AnalysisKey LoopAccessAnalysis::Key; diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 0990a0d..477e477 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -2682,6 +2682,20 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, return getAddExpr(NewOps, PreservedFlags); } } + + // Try to push the constant operand into a ZExt: A + zext (-A + B) -> zext + // (B), if trunc (A) + -A + B does not unsigned-wrap. + const SCEVAddExpr *InnerAdd; + if (match(B, m_scev_ZExt(m_scev_Add(InnerAdd)))) { + const SCEV *NarrowA = getTruncateExpr(A, InnerAdd->getType()); + if (NarrowA == getNegativeSCEV(InnerAdd->getOperand(0)) && + getZeroExtendExpr(NarrowA, B->getType()) == A && + hasFlags(StrengthenNoWrapFlags(this, scAddExpr, {NarrowA, InnerAdd}, + SCEV::FlagAnyWrap), + SCEV::FlagNUW)) { + return getZeroExtendExpr(getAddExpr(NarrowA, InnerAdd), B->getType()); + } + } } // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y) diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp index 55ba52a..c7eb2ec 100644 --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -1486,6 +1486,10 @@ void TargetTransformInfo::collectKernelLaunchBounds( return TTIImpl->collectKernelLaunchBounds(F, LB); } +bool TargetTransformInfo::allowVectorElementIndexingUsingGEP() const { + return TTIImpl->allowVectorElementIndexingUsingGEP(); +} + TargetTransformInfoImplBase::~TargetTransformInfoImplBase() = default; TargetIRAnalysis::TargetIRAnalysis() : TTICallback(&getDefaultTTI) {} diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp index e9cf2ee..b3b4c37 100644 --- a/llvm/lib/Analysis/VectorUtils.cpp +++ b/llvm/lib/Analysis/VectorUtils.cpp @@ -81,7 +81,6 @@ bool llvm::isTriviallyVectorizable(Intrinsic::ID ID) { case Intrinsic::exp: case Intrinsic::exp10: case Intrinsic::exp2: - case Intrinsic::ldexp: case Intrinsic::log: case Intrinsic::log10: case Intrinsic::log2: @@ -109,8 +108,6 @@ bool llvm::isTriviallyVectorizable(Intrinsic::ID ID) { case Intrinsic::canonicalize: case Intrinsic::fptosi_sat: case Intrinsic::fptoui_sat: - case Intrinsic::lround: - case Intrinsic::llround: case Intrinsic::lrint: case Intrinsic::llrint: case Intrinsic::ucmp: @@ -192,8 +189,6 @@ bool llvm::isVectorIntrinsicWithOverloadTypeAtArg( switch (ID) { case Intrinsic::fptosi_sat: case Intrinsic::fptoui_sat: - case Intrinsic::lround: - case Intrinsic::llround: case Intrinsic::lrint: case Intrinsic::llrint: case Intrinsic::vp_lrint: @@ -208,7 +203,6 @@ bool llvm::isVectorIntrinsicWithOverloadTypeAtArg( case Intrinsic::vp_is_fpclass: return OpdIdx == 0; case Intrinsic::powi: - case Intrinsic::ldexp: return OpdIdx == -1 || OpdIdx == 1; default: return OpdIdx == -1; @@ -1123,7 +1117,7 @@ Constant * llvm::createBitMaskForGaps(IRBuilderBase 
&Builder, unsigned VF, const InterleaveGroup<Instruction> &Group) { // All 1's means mask is not needed. - if (Group.getNumMembers() == Group.getFactor()) + if (Group.isFull()) return nullptr; // TODO: support reversed access. @@ -1669,7 +1663,7 @@ void InterleavedAccessInfo::analyzeInterleaving( // Case 1: A full group. Can Skip the checks; For full groups, if the wide // load would wrap around the address space we would do a memory access at // nullptr even without the transformation. - if (Group->getNumMembers() == Group->getFactor()) + if (Group->isFull()) continue; // Case 2: If first and last members of the group don't wrap this implies @@ -1704,7 +1698,7 @@ void InterleavedAccessInfo::analyzeInterleaving( // Case 1: A full group. Can Skip the checks; For full groups, if the wide // store would wrap around the address space we would do a memory access at // nullptr even without the transformation. - if (Group->getNumMembers() == Group->getFactor()) + if (Group->isFull()) continue; // Interleave-store-group with gaps is implemented using masked wide store. diff --git a/llvm/lib/CGData/StableFunctionMapRecord.cpp b/llvm/lib/CGData/StableFunctionMapRecord.cpp index 4e4fcef..423e068 100644 --- a/llvm/lib/CGData/StableFunctionMapRecord.cpp +++ b/llvm/lib/CGData/StableFunctionMapRecord.cpp @@ -160,14 +160,18 @@ void StableFunctionMapRecord::deserialize(const unsigned char *&Ptr, for (unsigned I = 0; I < NumFuncs; ++I) { auto Hash = endian::readNext<stable_hash, endianness::little, unaligned>(Ptr); - auto FunctionNameId = + [[maybe_unused]] auto FunctionNameId = endian::readNext<uint32_t, endianness::little, unaligned>(Ptr); - assert(FunctionMap->getNameForId(FunctionNameId) && - "FunctionNameId out of range"); - auto ModuleNameId = + [[maybe_unused]] auto ModuleNameId = endian::readNext<uint32_t, endianness::little, unaligned>(Ptr); - assert(FunctionMap->getNameForId(ModuleNameId) && - "ModuleNameId out of range"); + // Only validate IDs if we've read the names + if (ReadStableFunctionMapNames) { + assert(FunctionMap->getNameForId(FunctionNameId) && + "FunctionNameId out of range"); + assert(FunctionMap->getNameForId(ModuleNameId) && + "ModuleNameId out of range"); + } + auto InstCount = endian::readNext<uint32_t, endianness::little, unaligned>(Ptr); diff --git a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp index 6166271..1641c3e 100644 --- a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp +++ b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp @@ -1654,6 +1654,88 @@ void AsmPrinter::emitStackUsage(const MachineFunction &MF) { *StackUsageStream << "static\n"; } +/// Extracts a generalized numeric type identifier of a Function's type from +/// type metadata. Returns null if metadata cannot be found. +static ConstantInt *extractNumericCGTypeId(const Function &F) { + SmallVector<MDNode *, 2> Types; + F.getMetadata(LLVMContext::MD_type, Types); + for (const auto &Type : Types) { + if (Type->hasGeneralizedMDString()) { + MDString *MDGeneralizedTypeId = cast<MDString>(Type->getOperand(1)); + uint64_t TypeIdVal = llvm::MD5Hash(MDGeneralizedTypeId->getString()); + IntegerType *Int64Ty = Type::getInt64Ty(F.getContext()); + return ConstantInt::get(Int64Ty, TypeIdVal); + } + } + return nullptr; +} + +/// Emits .callgraph section. 
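The function below writes one record per function into the .callgraph section. An informal view of that record, reconstructed from its emitInt64/emitSymbolValue calls rather than from any documented ABI (field names here are descriptive only):

// Informal .callgraph record layout (scalars emitted as 64-bit integers,
// symbol values at pointer size). TypeId is only present when
// FunctionKind == INDIRECT_TARGET_KNOWN_TID.
//
//   uint64  FormatVersion       // CallGraphSectionFormatVersion::V_0
//   ptr     FunctionEntryPc     // func_begin symbol value
//   uint64  FunctionKind        // NOT_INDIRECT_TARGET / *_KNOWN_TID / *_UNKNOWN_TID
//   [uint64 TypeId]             // MD5 hash of the generalized type-id string
//   uint64  NumCallSites
//   { uint64 TypeId; ptr CallSitePc; } x NumCallSites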
+void AsmPrinter::emitCallGraphSection(const MachineFunction &MF, + FunctionInfo &FuncInfo) { + if (!MF.getTarget().Options.EmitCallGraphSection) + return; + + // Switch to the call graph section for the function + MCSection *FuncCGSection = + getObjFileLowering().getCallGraphSection(*getCurrentSection()); + assert(FuncCGSection && "null callgraph section"); + OutStreamer->pushSection(); + OutStreamer->switchSection(FuncCGSection); + + // Emit format version number. + OutStreamer->emitInt64(CallGraphSectionFormatVersion::V_0); + + // Emit function's self information, which is composed of: + // 1) FunctionEntryPc + // 2) FunctionKind: Whether the function is indirect target, and if so, + // whether its type id is known. + // 3) FunctionTypeId: Emit only when the function is an indirect target + // and its type id is known. + + // Emit function entry pc. + const MCSymbol *FunctionSymbol = getFunctionBegin(); + OutStreamer->emitSymbolValue(FunctionSymbol, TM.getProgramPointerSize()); + + // If this function has external linkage or has its address taken and + // it is not a callback, then anything could call it. + const Function &F = MF.getFunction(); + bool IsIndirectTarget = + !F.hasLocalLinkage() || F.hasAddressTaken(nullptr, + /*IgnoreCallbackUses=*/true, + /*IgnoreAssumeLikeCalls=*/true, + /*IgnoreLLVMUsed=*/false); + + // FIXME: FunctionKind takes a few values but emitted as a 64-bit value. + // Can be optimized to occupy 2 bits instead. + // Emit function kind, and type id if available. + if (!IsIndirectTarget) { + OutStreamer->emitInt64( + static_cast<uint64_t>(FunctionInfo::FunctionKind::NOT_INDIRECT_TARGET)); + } else { + if (const auto *TypeId = extractNumericCGTypeId(F)) { + OutStreamer->emitInt64(static_cast<uint64_t>( + FunctionInfo::FunctionKind::INDIRECT_TARGET_KNOWN_TID)); + OutStreamer->emitInt64(TypeId->getZExtValue()); + } else { + OutStreamer->emitInt64(static_cast<uint64_t>( + FunctionInfo::FunctionKind::INDIRECT_TARGET_UNKNOWN_TID)); + } + } + + // Emit callsite labels, where each element is a pair of type id and + // indirect callsite pc. + const auto &CallSiteLabels = FuncInfo.CallSiteLabels; + OutStreamer->emitInt64(CallSiteLabels.size()); + for (const auto &[TypeId, Label] : CallSiteLabels) { + OutStreamer->emitInt64(TypeId); + OutStreamer->emitSymbolValue(Label, TM.getProgramPointerSize()); + } + FuncInfo.CallSiteLabels.clear(); + + OutStreamer->popSection(); +} + void AsmPrinter::emitPCSectionsLabel(const MachineFunction &MF, const MDNode &MD) { MCSymbol *S = MF.getContext().createTempSymbol("pcsection"); @@ -1784,6 +1866,23 @@ static StringRef getMIMnemonic(const MachineInstr &MI, MCStreamer &Streamer) { return Name; } +void AsmPrinter::emitIndirectCalleeLabels( + FunctionInfo &FuncInfo, + const MachineFunction::CallSiteInfoMap &CallSitesInfoMap, + const MachineInstr &MI) { + // Only indirect calls have type identifiers set. + const auto &CallSiteInfo = CallSitesInfoMap.find(&MI); + if (CallSiteInfo == CallSitesInfoMap.end()) + return; + + for (ConstantInt *CalleeTypeId : CallSiteInfo->second.CalleeTypeIds) { + MCSymbol *S = MF->getContext().createTempSymbol(); + OutStreamer->emitLabel(S); + uint64_t CalleeTypeIdVal = CalleeTypeId->getZExtValue(); + FuncInfo.CallSiteLabels.emplace_back(CalleeTypeIdVal, S); + } +} + /// EmitFunctionBody - This method emits the body and trailer for a /// function. 
void AsmPrinter::emitFunctionBody() { @@ -1830,6 +1929,8 @@ void AsmPrinter::emitFunctionBody() { MBBSectionRanges[MF->front().getSectionID()] = MBBSectionRange{CurrentFnBegin, nullptr}; + FunctionInfo FuncInfo; + const auto &CallSitesInfoMap = MF->getCallSitesInfo(); for (auto &MBB : *MF) { // Print a label for the basic block. emitBasicBlockStart(MBB); @@ -1963,6 +2064,9 @@ void AsmPrinter::emitFunctionBody() { break; } + if (TM.Options.EmitCallGraphSection && MI.isCall()) + emitIndirectCalleeLabels(FuncInfo, CallSitesInfoMap, MI); + // If there is a post-instruction symbol, emit a label for it here. if (MCSymbol *S = MI.getPostInstrSymbol()) OutStreamer->emitLabel(S); @@ -2142,6 +2246,9 @@ void AsmPrinter::emitFunctionBody() { // Emit section containing stack size metadata. emitStackSizeSection(*MF); + // Emit section containing call graph metadata. + emitCallGraphSection(*MF, FuncInfo); + // Emit .su file containing function stack size information. emitStackUsage(*MF); @@ -2841,6 +2948,7 @@ void AsmPrinter::SetupMachineFunction(MachineFunction &MF) { F.hasFnAttribute("xray-instruction-threshold") || needFuncLabels(MF, *this) || NeedsLocalForSize || MF.getTarget().Options.EmitStackSizeSection || + MF.getTarget().Options.EmitCallGraphSection || MF.getTarget().Options.BBAddrMap) { CurrentFnBegin = createTempSymbol("func_begin"); if (NeedsLocalForSize) diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp index 416c56d..f16283b 100644 --- a/llvm/lib/CodeGen/CodeGenPrepare.cpp +++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp @@ -2769,6 +2769,29 @@ bool CodeGenPrepare::optimizeCallInst(CallInst *CI, ModifyDT &ModifiedDT) { return optimizeGatherScatterInst(II, II->getArgOperand(0)); case Intrinsic::masked_scatter: return optimizeGatherScatterInst(II, II->getArgOperand(1)); + case Intrinsic::masked_load: + // Treat v1X masked load as load X type. + if (auto *VT = dyn_cast<FixedVectorType>(II->getType())) { + if (VT->getNumElements() == 1) { + Value *PtrVal = II->getArgOperand(0); + unsigned AS = PtrVal->getType()->getPointerAddressSpace(); + if (optimizeMemoryInst(II, PtrVal, VT->getElementType(), AS)) + return true; + } + } + return false; + case Intrinsic::masked_store: + // Treat v1X masked store as store X type. + if (auto *VT = + dyn_cast<FixedVectorType>(II->getArgOperand(0)->getType())) { + if (VT->getNumElements() == 1) { + Value *PtrVal = II->getArgOperand(1); + unsigned AS = PtrVal->getType()->getPointerAddressSpace(); + if (optimizeMemoryInst(II, PtrVal, VT->getElementType(), AS)) + return true; + } + } + return false; } SmallVector<Value *, 2> PtrOps; diff --git a/llvm/lib/CodeGen/MachineFunction.cpp b/llvm/lib/CodeGen/MachineFunction.cpp index 60d42e0..ec40f6a 100644 --- a/llvm/lib/CodeGen/MachineFunction.cpp +++ b/llvm/lib/CodeGen/MachineFunction.cpp @@ -698,6 +698,26 @@ bool MachineFunction::needsFrameMoves() const { !F.getParent()->debug_compile_units().empty(); } +MachineFunction::CallSiteInfo::CallSiteInfo(const CallBase &CB) { + // Numeric callee_type ids are only for indirect calls. 
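The numeric ids collected by the constructor below are MD5 hashes of generalized type-id strings taken from !callee_type metadata. That mapping in isolation, as a sketch (the input string is illustrative; llvm::MD5Hash is the same helper the code uses):

#include "llvm/ADT/StringRef.h"
#include "llvm/Support/MD5.h"
#include <cstdint>

// Maps a generalized type-id string to the 64-bit id stored in
// CalleeTypeIds and later emitted into the .callgraph section.
uint64_t numericTypeId(llvm::StringRef GeneralizedTypeIdStr) {
  // e.g. "_ZTSFiiE.generalized" (illustrative) -> stable 64-bit hash
  return llvm::MD5Hash(GeneralizedTypeIdStr);
}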
+ if (!CB.isIndirectCall()) + return; + + MDNode *CalleeTypeList = CB.getMetadata(LLVMContext::MD_callee_type); + if (!CalleeTypeList) + return; + + for (const MDOperand &Op : CalleeTypeList->operands()) { + MDNode *TypeMD = cast<MDNode>(Op); + MDString *TypeIdStr = cast<MDString>(TypeMD->getOperand(1)); + // Compute numeric type id from generalized type id string + uint64_t TypeIdVal = MD5Hash(TypeIdStr->getString()); + IntegerType *Int64Ty = Type::getInt64Ty(CB.getContext()); + CalleeTypeIds.push_back( + ConstantInt::get(Int64Ty, TypeIdVal, /*IsSigned=*/false)); + } +} + namespace llvm { template<> diff --git a/llvm/lib/CodeGen/MachineScheduler.cpp b/llvm/lib/CodeGen/MachineScheduler.cpp index 9d5c39c..c6fa8f4 100644 --- a/llvm/lib/CodeGen/MachineScheduler.cpp +++ b/llvm/lib/CodeGen/MachineScheduler.cpp @@ -3676,8 +3676,8 @@ void GenericScheduler::initialize(ScheduleDAGMI *dag) { TopCand.SU = nullptr; BotCand.SU = nullptr; - TopCluster = nullptr; - BotCluster = nullptr; + TopClusterID = InvalidClusterId; + BotClusterID = InvalidClusterId; } /// Initialize the per-region scheduling policy. @@ -3988,10 +3988,14 @@ bool GenericScheduler::tryCandidate(SchedCandidate &Cand, // This is a best effort to set things up for a post-RA pass. Optimizations // like generating loads of multiple registers should ideally be done within // the scheduler pass by combining the loads during DAG postprocessing. - const ClusterInfo *CandCluster = Cand.AtTop ? TopCluster : BotCluster; - const ClusterInfo *TryCandCluster = TryCand.AtTop ? TopCluster : BotCluster; - if (tryGreater(TryCandCluster && TryCandCluster->contains(TryCand.SU), - CandCluster && CandCluster->contains(Cand.SU), TryCand, Cand, + unsigned CandZoneCluster = Cand.AtTop ? TopClusterID : BotClusterID; + unsigned TryCandZoneCluster = TryCand.AtTop ? 
TopClusterID : BotClusterID; + bool CandIsClusterSucc = + isTheSameCluster(CandZoneCluster, Cand.SU->ParentClusterIdx); + bool TryCandIsClusterSucc = + isTheSameCluster(TryCandZoneCluster, TryCand.SU->ParentClusterIdx); + + if (tryGreater(TryCandIsClusterSucc, CandIsClusterSucc, TryCand, Cand, Cluster)) return TryCand.Reason != NoCand; @@ -4251,24 +4255,30 @@ void GenericScheduler::reschedulePhysReg(SUnit *SU, bool isTop) { void GenericScheduler::schedNode(SUnit *SU, bool IsTopNode) { if (IsTopNode) { SU->TopReadyCycle = std::max(SU->TopReadyCycle, Top.getCurrCycle()); - TopCluster = DAG->getCluster(SU->ParentClusterIdx); - LLVM_DEBUG(if (TopCluster) { - dbgs() << " Top Cluster: "; - for (auto *N : *TopCluster) - dbgs() << N->NodeNum << '\t'; - dbgs() << '\n'; + TopClusterID = SU->ParentClusterIdx; + LLVM_DEBUG({ + if (TopClusterID != InvalidClusterId) { + ClusterInfo *TopCluster = DAG->getCluster(TopClusterID); + dbgs() << " Top Cluster: "; + for (auto *N : *TopCluster) + dbgs() << N->NodeNum << '\t'; + dbgs() << '\n'; + } }); Top.bumpNode(SU); if (SU->hasPhysRegUses) reschedulePhysReg(SU, true); } else { SU->BotReadyCycle = std::max(SU->BotReadyCycle, Bot.getCurrCycle()); - BotCluster = DAG->getCluster(SU->ParentClusterIdx); - LLVM_DEBUG(if (BotCluster) { - dbgs() << " Bot Cluster: "; - for (auto *N : *BotCluster) - dbgs() << N->NodeNum << '\t'; - dbgs() << '\n'; + BotClusterID = SU->ParentClusterIdx; + LLVM_DEBUG({ + if (BotClusterID != InvalidClusterId) { + ClusterInfo *BotCluster = DAG->getCluster(BotClusterID); + dbgs() << " Bot Cluster: "; + for (auto *N : *BotCluster) + dbgs() << N->NodeNum << '\t'; + dbgs() << '\n'; + } }); Bot.bumpNode(SU); if (SU->hasPhysRegDefs) @@ -4306,8 +4316,8 @@ void PostGenericScheduler::initialize(ScheduleDAGMI *Dag) { if (!Bot.HazardRec) { Bot.HazardRec = DAG->TII->CreateTargetMIHazardRecognizer(Itin, DAG); } - TopCluster = nullptr; - BotCluster = nullptr; + TopClusterID = InvalidClusterId; + BotClusterID = InvalidClusterId; } void PostGenericScheduler::initPolicy(MachineBasicBlock::iterator Begin, @@ -4373,10 +4383,14 @@ bool PostGenericScheduler::tryCandidate(SchedCandidate &Cand, return TryCand.Reason != NoCand; // Keep clustered nodes together. - const ClusterInfo *CandCluster = Cand.AtTop ? TopCluster : BotCluster; - const ClusterInfo *TryCandCluster = TryCand.AtTop ? TopCluster : BotCluster; - if (tryGreater(TryCandCluster && TryCandCluster->contains(TryCand.SU), - CandCluster && CandCluster->contains(Cand.SU), TryCand, Cand, + unsigned CandZoneCluster = Cand.AtTop ? TopClusterID : BotClusterID; + unsigned TryCandZoneCluster = TryCand.AtTop ? TopClusterID : BotClusterID; + bool CandIsClusterSucc = + isTheSameCluster(CandZoneCluster, Cand.SU->ParentClusterIdx); + bool TryCandIsClusterSucc = + isTheSameCluster(TryCandZoneCluster, TryCand.SU->ParentClusterIdx); + + if (tryGreater(TryCandIsClusterSucc, CandIsClusterSucc, TryCand, Cand, Cluster)) return TryCand.Reason != NoCand; // Avoid critical resource consumption and balance the schedule. 
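Both schedulers now reduce the cluster check to an id comparison. A sketch of the semantics assumed of isTheSameCluster in the calls above, inferred from its use here (the real helper lives with the scheduling-DAG code; this standalone version only illustrates the intended behavior):

#include <limits>

constexpr unsigned InvalidClusterId = std::numeric_limits<unsigned>::max();

// A candidate counts as a "cluster successor" only when the zone has a
// valid cluster id and the node's ParentClusterIdx matches it.
bool isTheSameCluster(unsigned ZoneClusterID, unsigned ParentClusterIdx) {
  return ZoneClusterID != InvalidClusterId && ZoneClusterID == ParentClusterIdx;
}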
@@ -4575,11 +4589,11 @@ SUnit *PostGenericScheduler::pickNode(bool &IsTopNode) { void PostGenericScheduler::schedNode(SUnit *SU, bool IsTopNode) { if (IsTopNode) { SU->TopReadyCycle = std::max(SU->TopReadyCycle, Top.getCurrCycle()); - TopCluster = DAG->getCluster(SU->ParentClusterIdx); + TopClusterID = SU->ParentClusterIdx; Top.bumpNode(SU); } else { SU->BotReadyCycle = std::max(SU->BotReadyCycle, Bot.getCurrCycle()); - BotCluster = DAG->getCluster(SU->ParentClusterIdx); + BotClusterID = SU->ParentClusterIdx; Bot.bumpNode(SU); } } diff --git a/llvm/lib/CodeGen/ModuloSchedule.cpp b/llvm/lib/CodeGen/ModuloSchedule.cpp index 0f742c4..21bf052 100644 --- a/llvm/lib/CodeGen/ModuloSchedule.cpp +++ b/llvm/lib/CodeGen/ModuloSchedule.cpp @@ -423,7 +423,7 @@ void ModuloScheduleExpander::generateExistingPhis( // potentially define two values. unsigned MaxPhis = PrologStage + 2; if (!InKernel && (int)PrologStage <= LoopValStage) - MaxPhis = std::max((int)MaxPhis - (int)LoopValStage, 1); + MaxPhis = std::max((int)MaxPhis - LoopValStage, 1); unsigned NumPhis = std::min(NumStages, MaxPhis); Register NewReg; diff --git a/llvm/lib/CodeGen/RegAllocBase.cpp b/llvm/lib/CodeGen/RegAllocBase.cpp index 69b9291..2400a1f 100644 --- a/llvm/lib/CodeGen/RegAllocBase.cpp +++ b/llvm/lib/CodeGen/RegAllocBase.cpp @@ -178,10 +178,8 @@ void RegAllocBase::cleanupFailedVReg(Register FailedReg, MCRegister PhysReg, for (MCRegAliasIterator Aliases(PhysReg, TRI, true); Aliases.isValid(); ++Aliases) { for (MachineOperand &MO : MRI->reg_operands(*Aliases)) { - if (MO.readsReg()) { + if (MO.readsReg()) MO.setIsUndef(true); - LIS->removeAllRegUnitsForPhysReg(MO.getReg()); - } } } } diff --git a/llvm/lib/CodeGen/RegisterCoalescer.cpp b/llvm/lib/CodeGen/RegisterCoalescer.cpp index 2d7987a..7ede564 100644 --- a/llvm/lib/CodeGen/RegisterCoalescer.cpp +++ b/llvm/lib/CodeGen/RegisterCoalescer.cpp @@ -306,7 +306,12 @@ class RegisterCoalescer : private LiveRangeEdit::Delegate { /// number if it is not zero. If DstReg is a physical register and the /// existing subregister number of the def / use being updated is not zero, /// make sure to set it to the correct physical subregister. - void updateRegDefsUses(Register SrcReg, Register DstReg, unsigned SubIdx); + /// + /// If \p SubregToRegSrcInst is not empty, we are coalescing a + /// `DstReg = SUBREG_TO_REG SrcReg`, which should introduce an + /// implicit-def of DstReg on instructions that define SrcReg. + void updateRegDefsUses(Register SrcReg, Register DstReg, unsigned SubIdx, + ArrayRef<MachineInstr *> SubregToRegSrcInst = {}); /// If the given machine operand reads only undefined lanes add an undef /// flag. @@ -1443,6 +1448,7 @@ bool RegisterCoalescer::reMaterializeTrivialDef(const CoalescerPair &CP, // CopyMI may have implicit operands, save them so that we can transfer them // over to the newly materialized instruction after CopyMI is removed. 
+ LaneBitmask NewMIImplicitOpsMask; SmallVector<MachineOperand, 4> ImplicitOps; ImplicitOps.reserve(CopyMI->getNumOperands() - CopyMI->getDesc().getNumOperands()); @@ -1457,6 +1463,9 @@ bool RegisterCoalescer::reMaterializeTrivialDef(const CoalescerPair &CP, (MO.getSubReg() == 0 && MO.getReg() == DstOperand.getReg())) && "unexpected implicit virtual register def"); ImplicitOps.push_back(MO); + if (MO.isDef() && MO.getReg().isVirtual() && + MRI->shouldTrackSubRegLiveness(DstReg)) + NewMIImplicitOpsMask |= MRI->getMaxLaneMaskForVReg(MO.getReg()); } } @@ -1499,14 +1508,11 @@ bool RegisterCoalescer::reMaterializeTrivialDef(const CoalescerPair &CP, } else { assert(MO.getReg() == NewMI.getOperand(0).getReg()); - // We're only expecting another def of the main output, so the range - // should get updated with the regular output range. - // - // FIXME: The range updating below probably needs updating to look at - // the super register if subranges are tracked. - assert(!MRI->shouldTrackSubRegLiveness(DstReg) && - "subrange update for implicit-def of super register may not be " - "properly handled"); + // If lanemasks need to be tracked, compute the lanemask of the NewMI + // implicit def operands to prevent subranges for the super-regs from + // being removed by code later on in this function. + if (MRI->shouldTrackSubRegLiveness(MO.getReg())) + NewMIImplicitOpsMask |= MRI->getMaxLaneMaskForVReg(MO.getReg()); } } } @@ -1606,7 +1612,8 @@ bool RegisterCoalescer::reMaterializeTrivialDef(const CoalescerPair &CP, CurrIdx.getRegSlot(NewMI.getOperand(0).isEarlyClobber()); VNInfo::Allocator &Alloc = LIS->getVNInfoAllocator(); for (LiveInterval::SubRange &SR : DstInt.subranges()) { - if ((SR.LaneMask & DstMask).none()) { + if ((SR.LaneMask & DstMask).none() && + (SR.LaneMask & NewMIImplicitOpsMask).none()) { LLVM_DEBUG(dbgs() << "Removing undefined SubRange " << PrintLaneMask(SR.LaneMask) << " : " << SR << "\n"); @@ -1870,11 +1877,14 @@ void RegisterCoalescer::addUndefFlag(const LiveInterval &Int, SlotIndex UseIdx, } } -void RegisterCoalescer::updateRegDefsUses(Register SrcReg, Register DstReg, - unsigned SubIdx) { +void RegisterCoalescer::updateRegDefsUses( + Register SrcReg, Register DstReg, unsigned SubIdx, + ArrayRef<MachineInstr *> SubregToRegSrcInsts) { bool DstIsPhys = DstReg.isPhysical(); LiveInterval *DstInt = DstIsPhys ? nullptr : &LIS->getInterval(DstReg); + // Coalescing a COPY may expose reads of 'undef' subregisters. + // If so, then explicitly propagate 'undef' to those operands. if (DstInt && DstInt->hasSubRanges() && DstReg != SrcReg) { for (MachineOperand &MO : MRI->reg_operands(DstReg)) { if (MO.isUndef()) continue; @@ -1891,6 +1901,15 @@ void RegisterCoalescer::updateRegDefsUses(Register SrcReg, Register DstReg, } } + // If DstInt already has a subrange for the unused lanes, then we shouldn't + // create duplicate subranges when we update the interval for unused lanes. + LaneBitmask DstIntLaneMask; + if (DstInt && MRI->shouldTrackSubRegLiveness(DstReg)) { + for (LiveInterval::SubRange &SR : DstInt->subranges()) + DstIntLaneMask |= SR.LaneMask; + } + + // Go through all instructions to replace uses of 'SrcReg' by 'DstReg'.
SmallPtrSet<MachineInstr *, 8> Visited; for (MachineRegisterInfo::reg_instr_iterator I = MRI->reg_instr_begin(SrcReg), E = MRI->reg_instr_end(); @@ -1914,6 +1933,80 @@ void RegisterCoalescer::updateRegDefsUses(Register SrcReg, Register DstReg, if (DstInt && !Reads && SubIdx && !UseMI->isDebugInstr()) Reads = DstInt->liveAt(LIS->getInstructionIndex(*UseMI)); + bool RequiresImplicitRedef = false; + if (!SubregToRegSrcInsts.empty()) { + // We can only add an implicit-def and undef if the sub registers match, + // e.g. + // %0:gr32 = INSTX + // %0.sub8:gr32 = INSTY // top 24 bits of %0 still defined + // %1:gr64 = SUBREG_TO_REG 0, %0, %subreg.sub32 + // + // This cannot be transformed into: + // %1.sub32:gr64 = INSTX + // undef %1.sub8:gr64 = INSTY , implicit-def %1 + // + // Because that would thrash the top 24 bits of %1.sub32. + if (is_contained(SubregToRegSrcInsts, UseMI) && + all_of(UseMI->defs(), + [&SubIdx, &SrcReg](const MachineOperand &MO) -> bool { + if (MO.getReg() != SrcReg || !MO.getSubReg() || MO.isUndef()) + return true; + return SubIdx == MO.getSubReg(); + })) { + // Add implicit-def of super-register to express that the whole + // register is defined by the instruction. + MachineInstrBuilder MIB(*MF, UseMI); + MIB.addReg(DstReg, RegState::ImplicitDefine); + RequiresImplicitRedef = true; + } + + // If the coalesced instruction doesn't fully define the register, we need + // to preserve the original super register liveness for SUBREG_TO_REG. + // + // We pretended SUBREG_TO_REG was a regular copy for coalescing purposes, + // but it introduces liveness for other subregisters. Downstream users may + // have been relying on those bits, so we need to ensure their liveness is + // captured with a def of other lanes. + if (DstInt && MRI->shouldTrackSubRegLiveness(DstReg)) { + // First check if there is sufficient granularity in terms of subranges. + LaneBitmask DstMask = MRI->getMaxLaneMaskForVReg(DstInt->reg()); + LaneBitmask UsedLanes = TRI->getSubRegIndexLaneMask(SubIdx); + LaneBitmask UnusedLanes = DstMask & ~UsedLanes; + if ((UnusedLanes & ~DstIntLaneMask).any()) { + BumpPtrAllocator &Allocator = LIS->getVNInfoAllocator(); + DstInt->createSubRangeFrom(Allocator, UnusedLanes, *DstInt); + DstIntLaneMask |= UnusedLanes; + } + + // After duplicating the live ranges for the low/hi bits, we + // need to update the subranges of the DstReg interval such that + // for a case like this: + // + // entry: + // 16B %1:gpr32 = INSTRUCTION (<=> UseMI) + // : + // if.then: + // 32B %1:gpr32 = MOVIMM32 .. + // 48B %0:gpr64 = SUBREG_TO_REG 0, %1, sub32 + // + // Only the MOVIMM32 requires a def of the top lanes and any intervals + // for the top 32-bits of the def at 16B should be removed. + for (LiveInterval::SubRange &SR : DstInt->subranges()) { + if (!Writes || RequiresImplicitRedef || + (SR.LaneMask & UnusedLanes).none()) + continue; + + assert((SR.LaneMask & UnusedLanes) == SR.LaneMask && + "Unexpected lanemask. Subrange needs finer granularity"); + + SlotIndex UseIdx = LIS->getInstructionIndex(*UseMI).getRegSlot(false); + auto SegmentI = SR.find(UseIdx); + if (SegmentI != SR.end()) + SR.removeSegment(SegmentI, true); + } + } + } + // Replace SrcReg with DstReg in all UseMI operands. for (unsigned Op : Ops) { MachineOperand &MO = UseMI->getOperand(Op); // Adjust <undef> flags in case of sub-register joins. We don't want to // turn a full def into a read-modify-write sub-register def and vice // versa.
if (SubIdx && MO.isDef()) - MO.setIsUndef(!Reads); + MO.setIsUndef(!Reads || RequiresImplicitRedef); // A subreg use of a partially undef (super) register may be a complete // undef use now, and then has to be marked that way. @@ -2025,6 +2118,30 @@ void RegisterCoalescer::setUndefOnPrunedSubRegUses(LiveInterval &LI, LIS->shrinkToUses(&LI); } +/// For a given use of value \p Idx, collect into \p Instrs the def in the +/// current block, or otherwise all possible defs in preceding blocks. +static bool FindDefInBlock(SmallPtrSetImpl<MachineBasicBlock *> &VisitedBlocks, + SmallVector<MachineInstr *> &Instrs, + LiveIntervals *LIS, LiveInterval &SrcInt, + MachineBasicBlock *MBB, VNInfo *Idx) { + if (!Idx->isPHIDef()) { + MachineInstr *Def = LIS->getInstructionFromIndex(Idx->def); + assert(Def && "Unable to find a def for SUBREG_TO_REG source operand"); + Instrs.push_back(Def); + return true; + } + + bool Any = false; + if (VisitedBlocks.count(MBB)) + return false; + VisitedBlocks.insert(MBB); + for (MachineBasicBlock *Pred : MBB->predecessors()) { + Any |= FindDefInBlock(VisitedBlocks, Instrs, LIS, SrcInt, Pred, + SrcInt.getVNInfoBefore(LIS->getMBBEndIdx(Pred))); + } + return Any; +} + bool RegisterCoalescer::joinCopy( MachineInstr *CopyMI, bool &Again, SmallPtrSetImpl<MachineInstr *> &CurrentErasedInstrs) { @@ -2156,6 +2273,35 @@ bool RegisterCoalescer::joinCopy( }); } + SmallVector<MachineInstr *> SubregToRegSrcInsts; + if (CopyMI->isSubregToReg()) { + // For the case where the copy instruction is a SUBREG_TO_REG, e.g. + // + // %0:gpr32 = movimm32 .. + // %1:gpr64 = SUBREG_TO_REG 0, %0, sub32 + // ... + // %0:gpr32 = COPY <something> + // + // After joining live ranges, the original `movimm32` will need an + // implicit-def to make it explicit that the entire register is written, + // i.e. + // + // undef %0.sub32:gpr64 = movimm32 ..., implicit-def %0 + // ... + // undef %0.sub32:gpr64 = COPY <something> // Note that this does not + // // require an implicit-def, + // // because it has nothing to + // // do with the SUBREG_TO_REG. + LiveInterval &SrcInt = + LIS->getInterval(CP.isFlipped() ? CP.getDstReg() : CP.getSrcReg()); + SlotIndex SubregToRegSlotIdx = LIS->getInstructionIndex(*CopyMI); + SmallPtrSet<MachineBasicBlock *, 8> VisitedBlocks; + if (!FindDefInBlock(VisitedBlocks, SubregToRegSrcInsts, LIS, SrcInt, + CopyMI->getParent(), + SrcInt.Query(SubregToRegSlotIdx).valueIn())) + llvm_unreachable("SUBREG_TO_REG src requires a def"); + } + ShrinkMask = LaneBitmask::getNone(); ShrinkMainRange = false; @@ -2225,9 +2371,12 @@ bool RegisterCoalescer::joinCopy( // Rewrite all SrcReg operands to DstReg. // Also update DstReg operands to include DstIdx if it is set. - if (CP.getDstIdx()) + if (CP.getDstIdx()) { + assert(SubregToRegSrcInsts.empty() && "can this happen?"); updateRegDefsUses(CP.getDstReg(), CP.getDstReg(), CP.getDstIdx()); - updateRegDefsUses(CP.getSrcReg(), CP.getDstReg(), CP.getSrcIdx()); + } + updateRegDefsUses(CP.getSrcReg(), CP.getDstReg(), CP.getSrcIdx(), + SubregToRegSrcInsts); // Shrink subregister ranges if necessary.
if (ShrinkMask.any()) { diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index a43020e..11e869a 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -331,6 +331,11 @@ namespace { return CombineTo(N, To, 2, AddTo); } + SDValue CombineTo(SDNode *N, SmallVectorImpl<SDValue> *To, + bool AddTo = true) { + return CombineTo(N, To->data(), To->size(), AddTo); + } + void CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO); private: @@ -541,6 +546,7 @@ namespace { SDValue visitEXTRACT_VECTOR_ELT(SDNode *N); SDValue visitBUILD_VECTOR(SDNode *N); SDValue visitCONCAT_VECTORS(SDNode *N); + SDValue visitVECTOR_INTERLEAVE(SDNode *N); SDValue visitEXTRACT_SUBVECTOR(SDNode *N); SDValue visitVECTOR_SHUFFLE(SDNode *N); SDValue visitSCALAR_TO_VECTOR(SDNode *N); @@ -2021,6 +2027,7 @@ SDValue DAGCombiner::visit(SDNode *N) { case ISD::EXTRACT_VECTOR_ELT: return visitEXTRACT_VECTOR_ELT(N); case ISD::BUILD_VECTOR: return visitBUILD_VECTOR(N); case ISD::CONCAT_VECTORS: return visitCONCAT_VECTORS(N); + case ISD::VECTOR_INTERLEAVE: return visitVECTOR_INTERLEAVE(N); case ISD::EXTRACT_SUBVECTOR: return visitEXTRACT_SUBVECTOR(N); case ISD::VECTOR_SHUFFLE: return visitVECTOR_SHUFFLE(N); case ISD::SCALAR_TO_VECTOR: return visitSCALAR_TO_VECTOR(N); @@ -4100,18 +4107,17 @@ SDValue DAGCombiner::visitSUB(SDNode *N) { // (sub x, ([v]select (uge x, y), y, 0)) -> (umin x, (sub x, y)) if (N1.hasOneUse() && hasUMin(VT)) { SDValue Y; - if (sd_match(N1, m_Select(m_SetCC(m_Specific(N0), m_Value(Y), - m_SpecificCondCode(ISD::SETULT)), - m_Zero(), m_Deferred(Y))) || - sd_match(N1, m_Select(m_SetCC(m_Specific(N0), m_Value(Y), - m_SpecificCondCode(ISD::SETUGE)), - m_Deferred(Y), m_Zero())) || - sd_match(N1, m_VSelect(m_SetCC(m_Specific(N0), m_Value(Y), - m_SpecificCondCode(ISD::SETULT)), - m_Zero(), m_Deferred(Y))) || - sd_match(N1, m_VSelect(m_SetCC(m_Specific(N0), m_Value(Y), - m_SpecificCondCode(ISD::SETUGE)), - m_Deferred(Y), m_Zero()))) + auto MS0 = m_Specific(N0); + auto MVY = m_Value(Y); + auto MZ = m_Zero(); + auto MCC1 = m_SpecificCondCode(ISD::SETULT); + auto MCC2 = m_SpecificCondCode(ISD::SETUGE); + + if (sd_match(N1, m_SelectCCLike(MS0, MVY, MZ, m_Deferred(Y), MCC1)) || + sd_match(N1, m_SelectCCLike(MS0, MVY, m_Deferred(Y), MZ, MCC2)) || + sd_match(N1, m_VSelect(m_SetCC(MS0, MVY, MCC1), MZ, m_Deferred(Y))) || + sd_match(N1, m_VSelect(m_SetCC(MS0, MVY, MCC2), m_Deferred(Y), MZ))) + return DAG.getNode(ISD::UMIN, DL, VT, N0, DAG.getNode(ISD::SUB, DL, VT, N0, Y)); } @@ -10616,6 +10622,19 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { return DAG.getVScale(DL, VT, C0 << C1); } + SDValue X; + APInt VS0; + + // fold (shl (X * vscale(VS0)), C1) -> (X * vscale(VS0 << C1)) + if (N1C && sd_match(N0, m_Mul(m_Value(X), m_VScale(m_ConstInt(VS0))))) { + SDNodeFlags Flags; + Flags.setNoUnsignedWrap(N->getFlags().hasNoUnsignedWrap() && + N0->getFlags().hasNoUnsignedWrap()); + + SDValue VScale = DAG.getVScale(DL, VT, VS0 << N1C->getAPIntValue()); + return DAG.getNode(ISD::MUL, DL, VT, X, VScale, Flags); + } + // Fold (shl step_vector(C0), C1) to (step_vector(C0 << C1)). 
APInt ShlVal; if (N0.getOpcode() == ISD::STEP_VECTOR && @@ -25282,6 +25301,28 @@ static SDValue combineConcatVectorOfShuffleAndItsOperands( return DAG.getVectorShuffle(VT, dl, ShufOps[0], ShufOps[1], Mask); } +static SDValue combineConcatVectorOfSplats(SDNode *N, SelectionDAG &DAG, + const TargetLowering &TLI, + bool LegalTypes, + bool LegalOperations) { + EVT VT = N->getValueType(0); + + // Post-legalization we can only create wider SPLAT_VECTOR operations if both + // the type and operation are legal. The Hexagon target has custom + // legalization for SPLAT_VECTOR that splits the operation into two parts and + // concatenates them. Therefore, custom lowering must also be rejected in + // order to avoid an infinite loop. + if ((LegalTypes && !TLI.isTypeLegal(VT)) || + (LegalOperations && !TLI.isOperationLegal(ISD::SPLAT_VECTOR, VT))) + return SDValue(); + + SDValue Op0 = N->getOperand(0); + if (!llvm::all_equal(N->op_values()) || Op0.getOpcode() != ISD::SPLAT_VECTOR) + return SDValue(); + + return DAG.getNode(ISD::SPLAT_VECTOR, SDLoc(N), VT, Op0.getOperand(0)); +} + SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) { // If we only have one input vector, we don't need to do any concatenation. if (N->getNumOperands() == 1) @@ -25405,6 +25446,10 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) { return DAG.getBuildVector(VT, SDLoc(N), Opnds); } + if (SDValue V = + combineConcatVectorOfSplats(N, DAG, TLI, LegalTypes, LegalOperations)) + return V; + // Fold CONCAT_VECTORS of only bitcast scalars (or undef) to BUILD_VECTOR. // FIXME: Add support for concat_vectors(bitcast(vec0),bitcast(vec1),...). if (SDValue V = combineConcatVectorOfScalars(N, DAG)) @@ -25473,6 +25518,21 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) { return SDValue(); } +SDValue DAGCombiner::visitVECTOR_INTERLEAVE(SDNode *N) { + // Check to see if all operands are identical. + if (!llvm::all_equal(N->op_values())) + return SDValue(); + + // Check to see if the identical operand is a splat. + if (!DAG.isSplatValue(N->getOperand(0))) + return SDValue(); + + // interleave splat(X), splat(X).... --> splat(X), splat(X).... + SmallVector<SDValue, 4> Ops; + Ops.append(N->op_values().begin(), N->op_values().end()); + return CombineTo(N, &Ops); +} + // Helper that peeks through INSERT_SUBVECTOR/CONCAT_VECTORS to find // if the subvector can be sourced for free. static SDValue getSubVectorSrc(SDValue V, unsigned Index, EVT SubVT) { @@ -28965,13 +29025,27 @@ SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1, ((N1C->isAllOnes() && CC == ISD::SETGT) || (N1C->isZero() && CC == ISD::SETLT)) && !TLI.shouldAvoidTransformToShift(VT, CmpOpVT.getScalarSizeInBits() - 1)) { - SDValue ASR = DAG.getNode( - ISD::SRA, DL, CmpOpVT, N0, - DAG.getConstant(CmpOpVT.getScalarSizeInBits() - 1, DL, CmpOpVT)); - return DAG.getNode(ISD::XOR, DL, VT, DAG.getSExtOrTrunc(ASR, DL, VT), + SDValue ASHR = + DAG.getNode(ISD::SRA, DL, CmpOpVT, N0, + DAG.getShiftAmountConstant( + CmpOpVT.getScalarSizeInBits() - 1, CmpOpVT, DL)); + return DAG.getNode(ISD::XOR, DL, VT, DAG.getSExtOrTrunc(ASHR, DL, VT), DAG.getSExtOrTrunc(CC == ISD::SETLT ?
N3 : N2, DL, VT)); } + // Fold sign pattern select_cc setgt X, -1, 1, -1 -> or (ashr X, BW-1), 1 + if (CC == ISD::SETGT && N1C && N2C && N3C && N1C->isAllOnes() && + N2C->isOne() && N3C->isAllOnes() && + !TLI.shouldAvoidTransformToShift(CmpOpVT, + CmpOpVT.getScalarSizeInBits() - 1)) { + SDValue ASHR = + DAG.getNode(ISD::SRA, DL, CmpOpVT, N0, + DAG.getShiftAmountConstant( + CmpOpVT.getScalarSizeInBits() - 1, CmpOpVT, DL)); + return DAG.getNode(ISD::OR, DL, VT, DAG.getSExtOrTrunc(ASHR, DL, VT), + DAG.getConstant(1, DL, VT)); + } + if (SDValue S = PerformMinMaxFpToSatCombine(N0, N1, N2, N3, CC, DAG)) return S; if (SDValue S = PerformUMinFpToSatCombine(N0, N1, N2, N3, CC, DAG)) diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp index 583a85a..a5bd97a 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp @@ -2217,8 +2217,9 @@ SDValue DAGTypeLegalizer::PromoteIntOp_BITCAST(SDNode *N) { switch (getTypeAction(InVT)) { case TargetLowering::TypePromoteInteger: { - // TODO: Handle big endian - if (OutVT.isVector() && DAG.getDataLayout().isLittleEndian()) { + // TODO: Handle big endian & vector input type. + if (OutVT.isVector() && !InVT.isVector() && + DAG.getDataLayout().isLittleEndian()) { EVT EltVT = OutVT.getVectorElementType(); TypeSize EltSize = EltVT.getSizeInBits(); TypeSize NInSize = NInVT.getSizeInBits(); diff --git a/llvm/lib/CodeGen/SelectionDAG/ScheduleDAGSDNodes.cpp b/llvm/lib/CodeGen/SelectionDAG/ScheduleDAGSDNodes.cpp index 6a2e782..31e7855 100644 --- a/llvm/lib/CodeGen/SelectionDAG/ScheduleDAGSDNodes.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/ScheduleDAGSDNodes.cpp @@ -888,7 +888,8 @@ EmitSchedule(MachineBasicBlock::iterator &InsertPos) { } if (MI->isCandidateForAdditionalCallInfo()) { - if (DAG->getTarget().Options.EmitCallSiteInfo) + if (DAG->getTarget().Options.EmitCallSiteInfo || + DAG->getTarget().Options.EmitCallGraphSection) MF.addCallSiteInfo(MI, DAG->getCallSiteInfo(Node)); if (auto CalledGlobal = DAG->getCalledGlobal(Node)) diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index 02d1100..f41b6eb 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -12782,7 +12782,7 @@ bool SDNode::areOnlyUsersOf(ArrayRef<const SDNode *> Nodes, const SDNode *N) { return Seen; } -/// isOperand - Return true if this node is an operand of N. +/// Return true if the referenced return value is an operand of N. bool SDValue::isOperandOf(const SDNode *N) const { return is_contained(N->op_values(), *this); } diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp index 48d6b99..e3f2a193 100644 --- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -6482,8 +6482,8 @@ SDValue TargetLowering::buildSDIVPow2WithCMov( Created.push_back(CMov.getNode()); // Divide by pow2. - SDValue SRA = - DAG.getNode(ISD::SRA, DL, VT, CMov, DAG.getConstant(Lg2, DL, VT)); + SDValue SRA = DAG.getNode(ISD::SRA, DL, VT, CMov, + DAG.getShiftAmountConstant(Lg2, VT, DL)); // If we're dividing by a positive value, we're done. Otherwise, we must // negate the result. 
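The new sign-selection fold's arithmetic, checked on plain integers: an arithmetic shift by BW-1 smears the sign bit to 0 or -1, and OR-ing in 1 maps that to +1 or -1, exactly the select_cc setgt X, -1, 1, -1 pattern. A standalone sanity check (assuming the usual arithmetic behavior of >> on negative signed values, which C++20 guarantees):

#include <cassert>
#include <cstdint>

int32_t signSelect(int32_t X) {
  int32_t Ashr = X >> 31; // 0 when X > -1, otherwise -1 (all ones)
  return Ashr | 1;        // 0 | 1 -> 1, -1 | 1 -> -1
}

int main() {
  assert(signSelect(42) == 1);
  assert(signSelect(0) == 1);
  assert(signSelect(-7) == -1);
  return 0;
}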
diff --git a/llvm/lib/CodeGen/TailDuplicator.cpp b/llvm/lib/CodeGen/TailDuplicator.cpp index a88c57f..5d720fb 100644 --- a/llvm/lib/CodeGen/TailDuplicator.cpp +++ b/llvm/lib/CodeGen/TailDuplicator.cpp @@ -604,12 +604,21 @@ bool TailDuplicator::shouldTailDuplicate(bool IsSimple, bool HasComputedGoto = false; if (!TailBB.empty()) { HasIndirectbr = TailBB.back().isIndirectBranch(); - HasComputedGoto = TailBB.terminatorIsComputedGoto(); + HasComputedGoto = TailBB.terminatorIsComputedGotoWithSuccessors(); } if (HasIndirectbr && PreRegAlloc) MaxDuplicateCount = TailDupIndirectBranchSize; + // Allow higher limits when the block has computed gotos and we are running + // after register allocation. NB. This basically unfactors computed gotos that + // were factored early on in the compilation process to speed up edge based + // data flow. If we do not unfactor them again, it can seriously pessimize code + // with many computed jumps in the source code, such as interpreters. + // Therefore we do not restrict the computed gotos. + if (HasComputedGoto && !PreRegAlloc) + MaxDuplicateCount = std::max(MaxDuplicateCount, 10u); + // Check the instructions in the block to determine whether tail-duplication // is invalid or unlikely to be profitable. unsigned InstrCount = 0; @@ -663,12 +672,7 @@ bool TailDuplicator::shouldTailDuplicate(bool IsSimple, // Duplicating a BB which has both multiple predecessors and successors // may cause a huge amount of PHI nodes. If we want to remove this limitation, // we have to address https://github.com/llvm/llvm-project/issues/78578. - // NB. This basically unfactors computed gotos that were factored early on in - // the compilation process to speed up edge based data flow. If we do not - // unfactor them again, it can seriously pessimize code with many computed - // jumps in the source code, such as interpreters. Therefore we do not - // restrict the computed gotos. - if (!HasComputedGoto && TailBB.pred_size() > TailDupPredSize && + if (PreRegAlloc && TailBB.pred_size() > TailDupPredSize && TailBB.succ_size() > TailDupSuccSize) { // If TailBB or any of its successors contains a phi, we may have to add a // large number of additional phis with additional incoming values. diff --git a/llvm/lib/CodeGen/TargetInstrInfo.cpp b/llvm/lib/CodeGen/TargetInstrInfo.cpp index 18d6bbc..705e046e 100644 --- a/llvm/lib/CodeGen/TargetInstrInfo.cpp +++ b/llvm/lib/CodeGen/TargetInstrInfo.cpp @@ -1406,7 +1406,7 @@ void TargetInstrInfo::reassociateOps( const MCInstrDesc &MCID, Register DestReg) { return MachineInstrBuilder( MF, MF.CreateMachineInstr(MCID, MIMD.getDL(), /*NoImpl=*/true)) - .setPCSections(MIMD.getPCSections()) + .copyMIMetadata(MIMD) .addReg(DestReg, RegState::Define); }; diff --git a/llvm/lib/DWARFLinker/Classic/DWARFLinker.cpp b/llvm/lib/DWARFLinker/Classic/DWARFLinker.cpp index 222dc88..6ddb12b 100644 --- a/llvm/lib/DWARFLinker/Classic/DWARFLinker.cpp +++ b/llvm/lib/DWARFLinker/Classic/DWARFLinker.cpp @@ -413,6 +413,117 @@ static bool isTlsAddressCode(uint8_t DW_OP_Code) { DW_OP_Code == dwarf::DW_OP_GNU_push_tls_address; } +static void constructSeqOffsettoOrigRowMapping( + CompileUnit &Unit, const DWARFDebugLine::LineTable &LT, + DenseMap<uint64_t, unsigned> &SeqOffToOrigRow) { + + // Use std::map for ordered iteration. + std::map<uint64_t, unsigned> LineTableMapping; + + // First, trust the sequences that the DWARF parser did identify.
+ for (const DWARFDebugLine::Sequence &Seq : LT.Sequences) + LineTableMapping[Seq.StmtSeqOffset] = Seq.FirstRowIndex; + + // Second, manually find sequence boundaries and match them to the + // sorted attributes to handle sequences the parser might have missed. + auto StmtAttrs = Unit.getStmtSeqListAttributes(); + llvm::sort(StmtAttrs, [](const PatchLocation &A, const PatchLocation &B) { + return A.get() < B.get(); + }); + + std::vector<unsigned> SeqStartRows; + SeqStartRows.push_back(0); + for (auto [I, Row] : llvm::enumerate(ArrayRef(LT.Rows).drop_back())) + if (Row.EndSequence) + SeqStartRows.push_back(I + 1); + + // While the LineTableMapping parsed from the CU is the ground truth, + // e.g. + // + // SeqOff Row + // 0x08 9 + // 0x14 15 + // + // the StmtAttrs and SeqStartRows may not match it perfectly, e.g. + // + // StmtAttrs SeqStartRows + // 0x04 3 + // 0x08 5 + // 0x10 9 + // 0x12 11 + // 0x14 15 + // + // In this case, we don't want to assign 5 to 0x08, since we know 0x08 + // maps to 9. If we did a dummy 1:1 mapping, 0x10 would be mapped to 9, + // which is incorrect. The expected behavior is to ignore 5 and realign + // the table based on the result from the line table: + // + // StmtAttrs SeqStartRows + // 0x04 3 + // -- 5 + // 0x08 9 <- LineTableMapping ground truth + // 0x10 11 + // 0x12 -- + // 0x14 15 <- LineTableMapping ground truth + + ArrayRef StmtAttrsRef(StmtAttrs); + ArrayRef SeqStartRowsRef(SeqStartRows); + + // Dummy last element to make sure StmtAttrsRef and SeqStartRowsRef always + // run out first. + constexpr uint64_t DummyKey = UINT64_MAX; + constexpr unsigned DummyVal = UINT32_MAX; + LineTableMapping[DummyKey] = DummyVal; + + for (auto [NextSeqOff, NextRow] : LineTableMapping) { + // Explicit captures to avoid capturing structured bindings and keep C++17 + // happy. + auto StmtAttrSmallerThanNext = [N = NextSeqOff](const PatchLocation &SA) { + return SA.get() < N; + }; + auto SeqStartSmallerThanNext = [N = NextRow](const unsigned &Row) { + return Row < N; + }; + // While both StmtAttrs and SeqStartRows point to values not yet in + // the LineTableMapping, we do a dummy one-to-one mapping and + // advance both pointers. + while (!StmtAttrsRef.empty() && !SeqStartRowsRef.empty() && + StmtAttrSmallerThanNext(StmtAttrsRef.front()) && + SeqStartSmallerThanNext(SeqStartRowsRef.front())) { + SeqOffToOrigRow[StmtAttrsRef.consume_front().get()] = + SeqStartRowsRef.consume_front(); + } + // Now one of the pointers points to a value at or past Next in the + // LineTableMapping; we move the pointers to re-align with the + // LineTableMapping. + StmtAttrsRef = StmtAttrsRef.drop_while(StmtAttrSmallerThanNext); + SeqStartRowsRef = SeqStartRowsRef.drop_while(SeqStartSmallerThanNext); + // Use the LineTableMapping's result as the ground truth and move + // on. + if (NextSeqOff != DummyKey) { + SeqOffToOrigRow[NextSeqOff] = NextRow; + } + // Move the pointers only if they point at Next. + // It is possible that they point to later entries in LineTableMapping. + // Therefore we only increment the pointers after we validate they are + // pointing to the `Next` entry. e.g.
+ // + // LineTableMapping + // SeqOff Row + // 0x08 9 <- NextSeqOff/NextRow + // 0x14 15 + // + // StmtAttrs SeqStartRows + // 0x14 13 <- StmtAttrsRef.front() / SeqStartRowsRef.front() + // 0x16 15 + // -- 17 + if (!StmtAttrsRef.empty() && StmtAttrsRef.front().get() == NextSeqOff) + StmtAttrsRef.consume_front(); + if (!SeqStartRowsRef.empty() && SeqStartRowsRef.front() == NextRow) + SeqStartRowsRef.consume_front(); + } +} + std::pair<bool, std::optional<int64_t>> DWARFLinker::getVariableRelocAdjustment(AddressesMap &RelocMgr, const DWARFDie &DIE) { @@ -2297,8 +2408,12 @@ void DWARFLinker::DIECloner::generateLineTableForUnit(CompileUnit &Unit) { // Create a map of stmt sequence offsets to original row indices. DenseMap<uint64_t, unsigned> SeqOffToOrigRow; - for (const DWARFDebugLine::Sequence &Seq : LT->Sequences) - SeqOffToOrigRow[Seq.StmtSeqOffset] = Seq.FirstRowIndex; + // The DWARF parser's discovery of sequences can be incomplete. To + // ensure all DW_AT_LLVM_stmt_sequence attributes can be patched, we + // build a map from both the parser's results and a manual + // reconstruction. + if (!LT->Rows.empty()) + constructSeqOffsettoOrigRowMapping(Unit, *LT, SeqOffToOrigRow); // Create a map of original row indices to new row indices. DenseMap<size_t, size_t> OrigRowToNewRow; diff --git a/llvm/lib/ExecutionEngine/RuntimeDyld/RuntimeDyld.cpp b/llvm/lib/ExecutionEngine/RuntimeDyld/RuntimeDyld.cpp index a8559e7..6de6cc7 100644 --- a/llvm/lib/ExecutionEngine/RuntimeDyld/RuntimeDyld.cpp +++ b/llvm/lib/ExecutionEngine/RuntimeDyld/RuntimeDyld.cpp @@ -891,7 +891,7 @@ RuntimeDyldImpl::emitSection(const ObjectFile &Obj, // Align DataSize to stub alignment if we have any stubs (PaddingSize will // have been increased above to account for this). if (StubBufSize > 0) - DataSize &= -(uint64_t)getStubAlignment().value(); + DataSize &= -getStubAlignment().value(); } LLVM_DEBUG(dbgs() << "emitSection SectionID: " << SectionID << " Name: " diff --git a/llvm/lib/Frontend/HLSL/CMakeLists.txt b/llvm/lib/Frontend/HLSL/CMakeLists.txt index 5343469..3d22577 100644 --- a/llvm/lib/Frontend/HLSL/CMakeLists.txt +++ b/llvm/lib/Frontend/HLSL/CMakeLists.txt @@ -1,5 +1,6 @@ add_llvm_component_library(LLVMFrontendHLSL CBuffer.cpp + HLSLBinding.cpp HLSLResource.cpp HLSLRootSignature.cpp RootSignatureMetadata.cpp diff --git a/llvm/lib/Frontend/HLSL/HLSLBinding.cpp b/llvm/lib/Frontend/HLSL/HLSLBinding.cpp new file mode 100644 index 0000000..d581311 --- /dev/null +++ b/llvm/lib/Frontend/HLSL/HLSLBinding.cpp @@ -0,0 +1,142 @@ +//===- HLSLBinding.cpp - Representation for resource bindings in HLSL -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Frontend/HLSL/HLSLBinding.h" +#include "llvm/ADT/STLExtras.h" + +using namespace llvm; +using namespace hlsl; + +std::optional<uint32_t> +BindingInfo::findAvailableBinding(dxil::ResourceClass RC, uint32_t Space, + int32_t Size) { + BindingSpaces &BS = getBindingSpaces(RC); + RegisterSpace &RS = BS.getOrInsertSpace(Space); + return RS.findAvailableBinding(Size); +} + +BindingInfo::RegisterSpace & +BindingInfo::BindingSpaces::getOrInsertSpace(uint32_t Space) { + for (auto It = Spaces.begin(), End = Spaces.end(); It != End; ++It) { + if (It->Space == Space) + return *It; + if (It->Space < Space) + continue; + return *Spaces.insert(It, Space); + } + return Spaces.emplace_back(Space); +} + +std::optional<uint32_t> +BindingInfo::RegisterSpace::findAvailableBinding(int32_t Size) { + assert((Size == -1 || Size > 0) && "invalid size"); + + if (FreeRanges.empty()) + return std::nullopt; + + // unbounded array + if (Size == -1) { + BindingRange &Last = FreeRanges.back(); + if (Last.UpperBound != ~0u) + // this space is already occupied by an unbounded array + return std::nullopt; + uint32_t RegSlot = Last.LowerBound; + FreeRanges.pop_back(); + return RegSlot; + } + + // single resource or fixed-size array + for (BindingRange &R : FreeRanges) { + // compare the size as uint64_t to prevent overflow for range (0, ~0u) + if ((uint64_t)R.UpperBound - R.LowerBound + 1 < (uint64_t)Size) + continue; + uint32_t RegSlot = R.LowerBound; + // This might create a range where (LowerBound == UpperBound + 1). When + // that happens, the next time this function is called the range will be + // skipped over by the check above (at this point Size is always > 0). + R.LowerBound += Size; + return RegSlot; + } + + return std::nullopt; +} + +BindingInfo BindingInfoBuilder::calculateBindingInfo( + llvm::function_ref<void(const BindingInfoBuilder &Builder, + const Binding &Overlapping)> + ReportOverlap) { + // sort all the collected bindings + llvm::stable_sort(Bindings); + + // remove duplicates + Binding *NewEnd = llvm::unique(Bindings); + if (NewEnd != Bindings.end()) + Bindings.erase(NewEnd); + + BindingInfo Info; + + // Go over the sorted bindings and build up lists of free register ranges + // for each binding type and used spaces. Bindings are sorted by resource + // class, space, and lower bound register slot. + BindingInfo::BindingSpaces *BS = + &Info.getBindingSpaces(dxil::ResourceClass::SRV); + for (const Binding &B : Bindings) { + if (BS->RC != B.RC) + // move to the next resource class's spaces + BS = &Info.getBindingSpaces(B.RC); + + BindingInfo::RegisterSpace *S = BS->Spaces.empty() + ? &BS->Spaces.emplace_back(B.Space) + : &BS->Spaces.back(); + assert(S->Space <= B.Space && "bindings not sorted correctly?"); + if (B.Space != S->Space) + // add new space + S = &BS->Spaces.emplace_back(B.Space); + + // The space is full - there are no free slots left, or the rest of the + // slots are taken by an unbounded array. Report the overlapping binding + // to the caller.
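// Worked example of the free-range bookkeeping below (illustrative values):
// a space starts with the single free range [0, ~0u]. Placing a binding at
// [4, 7] splits it into [0, 3] and [8, ~0u]; a following binding [8, 15]
// shrinks the tail to [16, ~0u]; an unbounded binding [16, ~0u] then pops the
// tail entirely. Any later binding in this space sees either no free ranges
// or a bounded last range and is reported through ReportOverlap.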
+ if (S->FreeRanges.empty() || S->FreeRanges.back().UpperBound < ~0u) { + ReportOverlap(*this, B); + continue; + } + // adjust the last free range lower bound, split it in two, or remove it + BindingInfo::BindingRange &LastFreeRange = S->FreeRanges.back(); + if (LastFreeRange.LowerBound == B.LowerBound) { + if (B.UpperBound < ~0u) + LastFreeRange.LowerBound = B.UpperBound + 1; + else + S->FreeRanges.pop_back(); + } else if (LastFreeRange.LowerBound < B.LowerBound) { + LastFreeRange.UpperBound = B.LowerBound - 1; + if (B.UpperBound < ~0u) + S->FreeRanges.emplace_back(B.UpperBound + 1, ~0u); + } else { + // We don't have room here. Report the overlapping binding to the caller + // and mark any extra space this binding would use as unavailable. + ReportOverlap(*this, B); + if (B.UpperBound < ~0u) + LastFreeRange.LowerBound = + std::max(LastFreeRange.LowerBound, B.UpperBound + 1); + else + S->FreeRanges.pop_back(); + } + } + + return Info; +} + +const BindingInfoBuilder::Binding &BindingInfoBuilder::findOverlapping( + const BindingInfoBuilder::Binding &ReportedBinding) const { + for (const BindingInfoBuilder::Binding &Other : Bindings) + if (ReportedBinding.LowerBound <= Other.UpperBound && + Other.LowerBound <= ReportedBinding.UpperBound) + return Other; + + llvm_unreachable("Searching for overlap for binding that does not overlap"); +} diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp index 53f5934..48ff1ca 100644 --- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp +++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp @@ -13,15 +13,21 @@ #include "llvm/Frontend/HLSL/RootSignatureMetadata.h" #include "llvm/Frontend/HLSL/RootSignatureValidations.h" -#include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Metadata.h" #include "llvm/Support/ScopedPrinter.h" +using namespace llvm; + namespace llvm { namespace hlsl { namespace rootsig { +char GenericRSMetadataError::ID; +char InvalidRSMetadataFormat::ID; +char InvalidRSMetadataValue::ID; +template <typename T> char RootSignatureValidationError<T>::ID; + static std::optional<uint32_t> extractMdIntValue(MDNode *Node, unsigned int OpId) { if (auto *CI = @@ -45,19 +51,6 @@ static std::optional<StringRef> extractMdStringValue(MDNode *Node, return NodeText->getString(); } -static bool reportError(LLVMContext *Ctx, Twine Message, - DiagnosticSeverity Severity = DS_Error) { - Ctx->diagnose(DiagnosticInfoGeneric(Message, Severity)); - return true; -} - -static bool reportValueError(LLVMContext *Ctx, Twine ParamName, - uint32_t Value) { - Ctx->diagnose(DiagnosticInfoGeneric( - "Invalid value for " + ParamName + ": " + Twine(Value), DS_Error)); - return true; -} - static const EnumEntry<dxil::ResourceClass> ResourceClassNames[] = { {"CBV", dxil::ResourceClass::CBuffer}, {"SRV", dxil::ResourceClass::SRV}, @@ -120,7 +113,7 @@ MDNode *MetadataBuilder::BuildRootFlags(const dxbc::RootFlags &Flags) { IRBuilder<> Builder(Ctx); Metadata *Operands[] = { MDString::get(Ctx, "RootFlags"), - ConstantAsMetadata::get(Builder.getInt32(llvm::to_underlying(Flags))), + ConstantAsMetadata::get(Builder.getInt32(to_underlying(Flags))), }; return MDNode::get(Ctx, Operands); } @@ -130,7 +123,7 @@ MDNode *MetadataBuilder::BuildRootConstants(const RootConstants &Constants) { Metadata *Operands[] = { MDString::get(Ctx, "RootConstants"), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Constants.Visibility))), + Builder.getInt32(to_underlying(Constants.Visibility))), 
ConstantAsMetadata::get(Builder.getInt32(Constants.Reg.Number)), ConstantAsMetadata::get(Builder.getInt32(Constants.Space)), ConstantAsMetadata::get(Builder.getInt32(Constants.Num32BitConstants)), @@ -140,18 +133,18 @@ MDNode *MetadataBuilder::BuildRootConstants(const RootConstants &Constants) { MDNode *MetadataBuilder::BuildRootDescriptor(const RootDescriptor &Descriptor) { IRBuilder<> Builder(Ctx); - std::optional<StringRef> ResName = getResourceName( - dxil::ResourceClass(llvm::to_underlying(Descriptor.Type))); + std::optional<StringRef> ResName = + getResourceName(dxil::ResourceClass(to_underlying(Descriptor.Type))); assert(ResName && "Provided an invalid Resource Class"); - llvm::SmallString<7> Name({"Root", *ResName}); + SmallString<7> Name({"Root", *ResName}); Metadata *Operands[] = { MDString::get(Ctx, Name), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Descriptor.Visibility))), + Builder.getInt32(to_underlying(Descriptor.Visibility))), ConstantAsMetadata::get(Builder.getInt32(Descriptor.Reg.Number)), ConstantAsMetadata::get(Builder.getInt32(Descriptor.Space)), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Descriptor.Flags))), + Builder.getInt32(to_underlying(Descriptor.Flags))), }; return MDNode::get(Ctx, Operands); } @@ -162,7 +155,7 @@ MDNode *MetadataBuilder::BuildDescriptorTable(const DescriptorTable &Table) { // Set the mandatory arguments TableOperands.push_back(MDString::get(Ctx, "DescriptorTable")); TableOperands.push_back(ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Table.Visibility)))); + Builder.getInt32(to_underlying(Table.Visibility)))); // Remaining operands are references to the table's clauses. The in-memory // representation of the Root Elements created from parsing will ensure that @@ -182,7 +175,7 @@ MDNode *MetadataBuilder::BuildDescriptorTableClause( const DescriptorTableClause &Clause) { IRBuilder<> Builder(Ctx); std::optional<StringRef> ResName = - getResourceName(dxil::ResourceClass(llvm::to_underlying(Clause.Type))); + getResourceName(dxil::ResourceClass(to_underlying(Clause.Type))); assert(ResName && "Provided an invalid Resource Class"); Metadata *Operands[] = { MDString::get(Ctx, *ResName), @@ -190,8 +183,7 @@ MDNode *MetadataBuilder::BuildDescriptorTableClause( ConstantAsMetadata::get(Builder.getInt32(Clause.Reg.Number)), ConstantAsMetadata::get(Builder.getInt32(Clause.Space)), ConstantAsMetadata::get(Builder.getInt32(Clause.Offset)), - ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Clause.Flags))), + ConstantAsMetadata::get(Builder.getInt32(to_underlying(Clause.Flags))), }; return MDNode::get(Ctx, Operands); } @@ -200,108 +192,102 @@ MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) { IRBuilder<> Builder(Ctx); Metadata *Operands[] = { MDString::get(Ctx, "StaticSampler"), + ConstantAsMetadata::get(Builder.getInt32(to_underlying(Sampler.Filter))), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Sampler.Filter))), + Builder.getInt32(to_underlying(Sampler.AddressU))), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Sampler.AddressU))), + Builder.getInt32(to_underlying(Sampler.AddressV))), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Sampler.AddressV))), + Builder.getInt32(to_underlying(Sampler.AddressW))), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Sampler.AddressW))), - ConstantAsMetadata::get(llvm::ConstantFP::get(llvm::Type::getFloatTy(Ctx), - Sampler.MipLODBias)), + 
ConstantFP::get(Type::getFloatTy(Ctx), Sampler.MipLODBias)), ConstantAsMetadata::get(Builder.getInt32(Sampler.MaxAnisotropy)), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Sampler.CompFunc))), + Builder.getInt32(to_underlying(Sampler.CompFunc))), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Sampler.BorderColor))), + Builder.getInt32(to_underlying(Sampler.BorderColor))), ConstantAsMetadata::get( - llvm::ConstantFP::get(llvm::Type::getFloatTy(Ctx), Sampler.MinLOD)), + ConstantFP::get(Type::getFloatTy(Ctx), Sampler.MinLOD)), ConstantAsMetadata::get( - llvm::ConstantFP::get(llvm::Type::getFloatTy(Ctx), Sampler.MaxLOD)), + ConstantFP::get(Type::getFloatTy(Ctx), Sampler.MaxLOD)), ConstantAsMetadata::get(Builder.getInt32(Sampler.Reg.Number)), ConstantAsMetadata::get(Builder.getInt32(Sampler.Space)), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Sampler.Visibility))), + Builder.getInt32(to_underlying(Sampler.Visibility))), }; return MDNode::get(Ctx, Operands); } -bool MetadataParser::parseRootFlags(LLVMContext *Ctx, - mcdxbc::RootSignatureDesc &RSD, - MDNode *RootFlagNode) { - +Error MetadataParser::parseRootFlags(mcdxbc::RootSignatureDesc &RSD, + MDNode *RootFlagNode) { if (RootFlagNode->getNumOperands() != 2) - return reportError(Ctx, "Invalid format for RootFlag Element"); + return make_error<InvalidRSMetadataFormat>("RootFlag Element"); if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1)) RSD.Flags = *Val; else - return reportError(Ctx, "Invalid value for RootFlag"); + return make_error<InvalidRSMetadataValue>("RootFlag"); - return false; + return Error::success(); } -bool MetadataParser::parseRootConstants(LLVMContext *Ctx, - mcdxbc::RootSignatureDesc &RSD, - MDNode *RootConstantNode) { - +Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD, + MDNode *RootConstantNode) { if (RootConstantNode->getNumOperands() != 5) - return reportError(Ctx, "Invalid format for RootConstants Element"); + return make_error<InvalidRSMetadataFormat>("RootConstants Element"); dxbc::RTS0::v1::RootParameterHeader Header; // The parameter offset doesn't matter here - we recalculate it during // serialization Header.ParameterOffset = 0; - Header.ParameterType = - llvm::to_underlying(dxbc::RootParameterType::Constants32Bit); + Header.ParameterType = to_underlying(dxbc::RootParameterType::Constants32Bit); if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1)) Header.ShaderVisibility = *Val; else - return reportError(Ctx, "Invalid value for ShaderVisibility"); + return make_error<InvalidRSMetadataValue>("ShaderVisibility"); dxbc::RTS0::v1::RootConstants Constants; if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2)) Constants.ShaderRegister = *Val; else - return reportError(Ctx, "Invalid value for ShaderRegister"); + return make_error<InvalidRSMetadataValue>("ShaderRegister"); if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3)) Constants.RegisterSpace = *Val; else - return reportError(Ctx, "Invalid value for RegisterSpace"); + return make_error<InvalidRSMetadataValue>("RegisterSpace"); if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4)) Constants.Num32BitValues = *Val; else - return reportError(Ctx, "Invalid value for Num32BitValues"); + return make_error<InvalidRSMetadataValue>("Num32BitValues"); RSD.ParametersContainer.addParameter(Header, Constants); - return false; + return Error::success(); } -bool MetadataParser::parseRootDescriptors( - 
LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, - MDNode *RootDescriptorNode, RootSignatureElementKind ElementKind) { +Error MetadataParser::parseRootDescriptors( + mcdxbc::RootSignatureDesc &RSD, MDNode *RootDescriptorNode, + RootSignatureElementKind ElementKind) { assert(ElementKind == RootSignatureElementKind::SRV || ElementKind == RootSignatureElementKind::UAV || ElementKind == RootSignatureElementKind::CBV && "parseRootDescriptors should only be called with RootDescriptor " "element kind."); if (RootDescriptorNode->getNumOperands() != 5) - return reportError(Ctx, "Invalid format for Root Descriptor Element"); + return make_error<InvalidRSMetadataFormat>("Root Descriptor Element"); dxbc::RTS0::v1::RootParameterHeader Header; switch (ElementKind) { case RootSignatureElementKind::SRV: - Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::SRV); + Header.ParameterType = to_underlying(dxbc::RootParameterType::SRV); break; case RootSignatureElementKind::UAV: - Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::UAV); + Header.ParameterType = to_underlying(dxbc::RootParameterType::UAV); break; case RootSignatureElementKind::CBV: - Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::CBV); + Header.ParameterType = to_underlying(dxbc::RootParameterType::CBV); break; default: llvm_unreachable("invalid Root Descriptor kind"); @@ -311,40 +297,38 @@ bool MetadataParser::parseRootDescriptors( if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1)) Header.ShaderVisibility = *Val; else - return reportError(Ctx, "Invalid value for ShaderVisibility"); + return make_error<InvalidRSMetadataValue>("ShaderVisibility"); dxbc::RTS0::v2::RootDescriptor Descriptor; if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2)) Descriptor.ShaderRegister = *Val; else - return reportError(Ctx, "Invalid value for ShaderRegister"); + return make_error<InvalidRSMetadataValue>("ShaderRegister"); if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 3)) Descriptor.RegisterSpace = *Val; else - return reportError(Ctx, "Invalid value for RegisterSpace"); + return make_error<InvalidRSMetadataValue>("RegisterSpace"); if (RSD.Version == 1) { RSD.ParametersContainer.addParameter(Header, Descriptor); - return false; + return Error::success(); } assert(RSD.Version > 1); if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 4)) Descriptor.Flags = *Val; else - return reportError(Ctx, "Invalid value for Root Descriptor Flags"); + return make_error<InvalidRSMetadataValue>("Root Descriptor Flags"); RSD.ParametersContainer.addParameter(Header, Descriptor); - return false; + return Error::success(); } -bool MetadataParser::parseDescriptorRange(LLVMContext *Ctx, - mcdxbc::DescriptorTable &Table, - MDNode *RangeDescriptorNode) { - +Error MetadataParser::parseDescriptorRange(mcdxbc::DescriptorTable &Table, + MDNode *RangeDescriptorNode) { if (RangeDescriptorNode->getNumOperands() != 6) - return reportError(Ctx, "Invalid format for Descriptor Range"); + return make_error<InvalidRSMetadataFormat>("Descriptor Range"); dxbc::RTS0::v2::DescriptorRange Range; @@ -352,162 +336,161 @@ bool MetadataParser::parseDescriptorRange(LLVMContext *Ctx, extractMdStringValue(RangeDescriptorNode, 0); if (!ElementText.has_value()) - return reportError(Ctx, "Descriptor Range, first element is not a string."); + return make_error<InvalidRSMetadataFormat>("Descriptor Range"); Range.RangeType = StringSwitch<uint32_t>(*ElementText) - .Case("CBV", 
llvm::to_underlying(dxbc::DescriptorRangeType::CBV)) - .Case("SRV", llvm::to_underlying(dxbc::DescriptorRangeType::SRV)) - .Case("UAV", llvm::to_underlying(dxbc::DescriptorRangeType::UAV)) - .Case("Sampler", - llvm::to_underlying(dxbc::DescriptorRangeType::Sampler)) + .Case("CBV", to_underlying(dxbc::DescriptorRangeType::CBV)) + .Case("SRV", to_underlying(dxbc::DescriptorRangeType::SRV)) + .Case("UAV", to_underlying(dxbc::DescriptorRangeType::UAV)) + .Case("Sampler", to_underlying(dxbc::DescriptorRangeType::Sampler)) .Default(~0U); if (Range.RangeType == ~0U) - return reportError(Ctx, "Invalid Descriptor Range type: " + *ElementText); + return make_error<GenericRSMetadataError>("Invalid Descriptor Range type.", + RangeDescriptorNode); if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 1)) Range.NumDescriptors = *Val; else - return reportError(Ctx, "Invalid value for Number of Descriptor in Range"); + return make_error<GenericRSMetadataError>("Number of Descriptors in Range", + RangeDescriptorNode); if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 2)) Range.BaseShaderRegister = *Val; else - return reportError(Ctx, "Invalid value for BaseShaderRegister"); + return make_error<InvalidRSMetadataValue>("BaseShaderRegister"); if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 3)) Range.RegisterSpace = *Val; else - return reportError(Ctx, "Invalid value for RegisterSpace"); + return make_error<InvalidRSMetadataValue>("RegisterSpace"); if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 4)) Range.OffsetInDescriptorsFromTableStart = *Val; else - return reportError(Ctx, - "Invalid value for OffsetInDescriptorsFromTableStart"); + return make_error<InvalidRSMetadataValue>( + "OffsetInDescriptorsFromTableStart"); if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 5)) Range.Flags = *Val; else - return reportError(Ctx, "Invalid value for Descriptor Range Flags"); + return make_error<InvalidRSMetadataValue>("Descriptor Range Flags"); Table.Ranges.push_back(Range); - return false; + return Error::success(); } -bool MetadataParser::parseDescriptorTable(LLVMContext *Ctx, - mcdxbc::RootSignatureDesc &RSD, - MDNode *DescriptorTableNode) { +Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD, + MDNode *DescriptorTableNode) { const unsigned int NumOperands = DescriptorTableNode->getNumOperands(); if (NumOperands < 2) - return reportError(Ctx, "Invalid format for Descriptor Table"); + return make_error<InvalidRSMetadataFormat>("Descriptor Table"); dxbc::RTS0::v1::RootParameterHeader Header; if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1)) Header.ShaderVisibility = *Val; else - return reportError(Ctx, "Invalid value for ShaderVisibility"); + return make_error<InvalidRSMetadataValue>("ShaderVisibility"); mcdxbc::DescriptorTable Table; Header.ParameterType = - llvm::to_underlying(dxbc::RootParameterType::DescriptorTable); + to_underlying(dxbc::RootParameterType::DescriptorTable); for (unsigned int I = 2; I < NumOperands; I++) { MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I)); if (Element == nullptr) - return reportError(Ctx, "Missing Root Element Metadata Node."); + return make_error<GenericRSMetadataError>( + "Missing Root Element Metadata Node.", DescriptorTableNode); - if (parseDescriptorRange(Ctx, Table, Element)) - return true; + if (auto Err = parseDescriptorRange(Table, Element)) + return Err; + }
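// The switch from bool-plus-diagnostics to llvm::Error in these parsers also
// changes the consuming side: callers now handle or propagate an Error rather
// than test a flag. A minimal sketch of a consumer that turns the (possibly
// joined) error chain back into LLVMContext diagnostics (hypothetical helper,
// not part of this patch):

#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/Error.h"

// Drain every error in the chain into context diagnostics, mirroring the old
// "reportError returns true" convention.
static bool diagnoseRSErrors(llvm::LLVMContext &Ctx, llvm::Error Err) {
  llvm::handleAllErrors(std::move(Err), [&](const llvm::ErrorInfoBase &EIB) {
    Ctx.diagnose(llvm::DiagnosticInfoGeneric(EIB.message(), llvm::DS_Error));
  });
  return true;
}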
RSD.ParametersContainer.addParameter(Header, Table); - return false; + return Error::success(); } -bool MetadataParser::parseStaticSampler(LLVMContext *Ctx, - mcdxbc::RootSignatureDesc &RSD, - MDNode *StaticSamplerNode) { +Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD, + MDNode *StaticSamplerNode) { if (StaticSamplerNode->getNumOperands() != 14) - return reportError(Ctx, "Invalid format for Static Sampler"); + return make_error<InvalidRSMetadataFormat>("Static Sampler"); dxbc::RTS0::v1::StaticSampler Sampler; if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 1)) Sampler.Filter = *Val; else - return reportError(Ctx, "Invalid value for Filter"); + return make_error<InvalidRSMetadataValue>("Filter"); if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 2)) Sampler.AddressU = *Val; else - return reportError(Ctx, "Invalid value for AddressU"); + return make_error<InvalidRSMetadataValue>("AddressU"); if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 3)) Sampler.AddressV = *Val; else - return reportError(Ctx, "Invalid value for AddressV"); + return make_error<InvalidRSMetadataValue>("AddressV"); if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 4)) Sampler.AddressW = *Val; else - return reportError(Ctx, "Invalid value for AddressW"); + return make_error<InvalidRSMetadataValue>("AddressW"); if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 5)) Sampler.MipLODBias = *Val; else - return reportError(Ctx, "Invalid value for MipLODBias"); + return make_error<InvalidRSMetadataValue>("MipLODBias"); if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 6)) Sampler.MaxAnisotropy = *Val; else - return reportError(Ctx, "Invalid value for MaxAnisotropy"); + return make_error<InvalidRSMetadataValue>("MaxAnisotropy"); if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 7)) Sampler.ComparisonFunc = *Val; else - return reportError(Ctx, "Invalid value for ComparisonFunc "); + return make_error<InvalidRSMetadataValue>("ComparisonFunc"); if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 8)) Sampler.BorderColor = *Val; else - return reportError(Ctx, "Invalid value for ComparisonFunc "); + return make_error<InvalidRSMetadataValue>("BorderColor"); if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 9)) Sampler.MinLOD = *Val; else - return reportError(Ctx, "Invalid value for MinLOD"); + return make_error<InvalidRSMetadataValue>("MinLOD"); if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 10)) Sampler.MaxLOD = *Val; else - return reportError(Ctx, "Invalid value for MaxLOD"); + return make_error<InvalidRSMetadataValue>("MaxLOD"); if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 11)) Sampler.ShaderRegister = *Val; else - return reportError(Ctx, "Invalid value for ShaderRegister"); + return make_error<InvalidRSMetadataValue>("ShaderRegister"); if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 12)) Sampler.RegisterSpace = *Val; else - return reportError(Ctx, "Invalid value for RegisterSpace"); + return make_error<InvalidRSMetadataValue>("RegisterSpace"); if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 13)) Sampler.ShaderVisibility = *Val; else - return reportError(Ctx, "Invalid value for ShaderVisibility"); + return make_error<InvalidRSMetadataValue>("ShaderVisibility"); RSD.StaticSamplers.push_back(Sampler); - return false; + return
Error::success(); } -bool MetadataParser::parseRootSignatureElement(LLVMContext *Ctx, - mcdxbc::RootSignatureDesc &RSD, - MDNode *Element) { +Error MetadataParser::parseRootSignatureElement(mcdxbc::RootSignatureDesc &RSD, + MDNode *Element) { std::optional<StringRef> ElementText = extractMdStringValue(Element, 0); if (!ElementText.has_value()) - return reportError(Ctx, "Invalid format for Root Element"); + return make_error<InvalidRSMetadataFormat>("Root Element"); RootSignatureElementKind ElementKind = StringSwitch<RootSignatureElementKind>(*ElementText) @@ -523,79 +506,109 @@ bool MetadataParser::parseRootSignatureElement(LLVMContext *Ctx, switch (ElementKind) { case RootSignatureElementKind::RootFlags: - return parseRootFlags(Ctx, RSD, Element); + return parseRootFlags(RSD, Element); case RootSignatureElementKind::RootConstants: - return parseRootConstants(Ctx, RSD, Element); + return parseRootConstants(RSD, Element); case RootSignatureElementKind::CBV: case RootSignatureElementKind::SRV: case RootSignatureElementKind::UAV: - return parseRootDescriptors(Ctx, RSD, Element, ElementKind); + return parseRootDescriptors(RSD, Element, ElementKind); case RootSignatureElementKind::DescriptorTable: - return parseDescriptorTable(Ctx, RSD, Element); + return parseDescriptorTable(RSD, Element); case RootSignatureElementKind::StaticSamplers: - return parseStaticSampler(Ctx, RSD, Element); + return parseStaticSampler(RSD, Element); case RootSignatureElementKind::Error: - return reportError(Ctx, "Invalid Root Signature Element: " + *ElementText); + return make_error<GenericRSMetadataError>("Invalid Root Signature Element", + Element); } llvm_unreachable("Unhandled RootSignatureElementKind enum."); } -bool MetadataParser::validateRootSignature( - LLVMContext *Ctx, const llvm::mcdxbc::RootSignatureDesc &RSD) { - if (!llvm::hlsl::rootsig::verifyVersion(RSD.Version)) { - return reportValueError(Ctx, "Version", RSD.Version); +Error MetadataParser::validateRootSignature( + const mcdxbc::RootSignatureDesc &RSD) { + Error DeferredErrs = Error::success(); + if (!hlsl::rootsig::verifyVersion(RSD.Version)) { + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "Version", RSD.Version)); } - if (!llvm::hlsl::rootsig::verifyRootFlag(RSD.Flags)) { - return reportValueError(Ctx, "RootFlags", RSD.Flags); + if (!hlsl::rootsig::verifyRootFlag(RSD.Flags)) { + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "RootFlags", RSD.Flags)); } for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) { if (!dxbc::isValidShaderVisibility(Info.Header.ShaderVisibility)) - return reportValueError(Ctx, "ShaderVisibility", - Info.Header.ShaderVisibility); + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "ShaderVisibility", Info.Header.ShaderVisibility)); assert(dxbc::isValidParameterType(Info.Header.ParameterType) && "Invalid value for ParameterType"); switch (Info.Header.ParameterType) { - case llvm::to_underlying(dxbc::RootParameterType::CBV): - case llvm::to_underlying(dxbc::RootParameterType::UAV): - case llvm::to_underlying(dxbc::RootParameterType::SRV): { + case to_underlying(dxbc::RootParameterType::CBV): + case to_underlying(dxbc::RootParameterType::UAV): + case to_underlying(dxbc::RootParameterType::SRV): { const dxbc::RTS0::v2::RootDescriptor &Descriptor = RSD.ParametersContainer.getRootDescriptor(Info.Location); - if 
(!llvm::hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister)) - return reportValueError(Ctx, "ShaderRegister", - Descriptor.ShaderRegister); - - if (!llvm::hlsl::rootsig::verifyRegisterSpace(Descriptor.RegisterSpace)) - return reportValueError(Ctx, "RegisterSpace", Descriptor.RegisterSpace); + if (!hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "ShaderRegister", Descriptor.ShaderRegister)); + + if (!hlsl::rootsig::verifyRegisterSpace(Descriptor.RegisterSpace)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "RegisterSpace", Descriptor.RegisterSpace)); if (RSD.Version > 1) { - if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(RSD.Version, - Descriptor.Flags)) - return reportValueError(Ctx, "RootDescriptorFlag", Descriptor.Flags); + if (!hlsl::rootsig::verifyRootDescriptorFlag(RSD.Version, + Descriptor.Flags)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "RootDescriptorFlag", Descriptor.Flags)); } break; } - case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): { + case to_underlying(dxbc::RootParameterType::DescriptorTable): { const mcdxbc::DescriptorTable &Table = RSD.ParametersContainer.getDescriptorTable(Info.Location); for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) { - if (!llvm::hlsl::rootsig::verifyRangeType(Range.RangeType)) - return reportValueError(Ctx, "RangeType", Range.RangeType); - - if (!llvm::hlsl::rootsig::verifyRegisterSpace(Range.RegisterSpace)) - return reportValueError(Ctx, "RegisterSpace", Range.RegisterSpace); - - if (!llvm::hlsl::rootsig::verifyNumDescriptors(Range.NumDescriptors)) - return reportValueError(Ctx, "NumDescriptors", Range.NumDescriptors); - - if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag( + if (!hlsl::rootsig::verifyRangeType(Range.RangeType)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "RangeType", Range.RangeType)); + + if (!hlsl::rootsig::verifyRegisterSpace(Range.RegisterSpace)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "RegisterSpace", Range.RegisterSpace)); + + if (!hlsl::rootsig::verifyNumDescriptors(Range.NumDescriptors)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "NumDescriptors", Range.NumDescriptors)); + + if (!hlsl::rootsig::verifyDescriptorRangeFlag( RSD.Version, Range.RangeType, Range.Flags)) - return reportValueError(Ctx, "DescriptorFlag", Range.Flags); + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "DescriptorFlag", Range.Flags)); } break; } @@ -603,65 +616,108 @@ bool MetadataParser::validateRootSignature( } for (const dxbc::RTS0::v1::StaticSampler &Sampler : RSD.StaticSamplers) { - if (!llvm::hlsl::rootsig::verifySamplerFilter(Sampler.Filter)) - return reportValueError(Ctx, "Filter", Sampler.Filter); - - if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressU)) - return reportValueError(Ctx, "AddressU", Sampler.AddressU); - - if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressV)) - return reportValueError(Ctx, "AddressV", Sampler.AddressV); - - if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressW)) - return reportValueError(Ctx, "AddressW", Sampler.AddressW); - - if 
(!llvm::hlsl::rootsig::verifyMipLODBias(Sampler.MipLODBias)) - return reportValueError(Ctx, "MipLODBias", Sampler.MipLODBias); - - if (!llvm::hlsl::rootsig::verifyMaxAnisotropy(Sampler.MaxAnisotropy)) - return reportValueError(Ctx, "MaxAnisotropy", Sampler.MaxAnisotropy); - - if (!llvm::hlsl::rootsig::verifyComparisonFunc(Sampler.ComparisonFunc)) - return reportValueError(Ctx, "ComparisonFunc", Sampler.ComparisonFunc); - - if (!llvm::hlsl::rootsig::verifyBorderColor(Sampler.BorderColor)) - return reportValueError(Ctx, "BorderColor", Sampler.BorderColor); - - if (!llvm::hlsl::rootsig::verifyLOD(Sampler.MinLOD)) - return reportValueError(Ctx, "MinLOD", Sampler.MinLOD); - - if (!llvm::hlsl::rootsig::verifyLOD(Sampler.MaxLOD)) - return reportValueError(Ctx, "MaxLOD", Sampler.MaxLOD); - - if (!llvm::hlsl::rootsig::verifyRegisterValue(Sampler.ShaderRegister)) - return reportValueError(Ctx, "ShaderRegister", Sampler.ShaderRegister); - - if (!llvm::hlsl::rootsig::verifyRegisterSpace(Sampler.RegisterSpace)) - return reportValueError(Ctx, "RegisterSpace", Sampler.RegisterSpace); + if (!hlsl::rootsig::verifySamplerFilter(Sampler.Filter)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "Filter", Sampler.Filter)); + + if (!hlsl::rootsig::verifyAddress(Sampler.AddressU)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "AddressU", Sampler.AddressU)); + + if (!hlsl::rootsig::verifyAddress(Sampler.AddressV)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "AddressV", Sampler.AddressV)); + + if (!hlsl::rootsig::verifyAddress(Sampler.AddressW)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "AddressW", Sampler.AddressW)); + + if (!hlsl::rootsig::verifyMipLODBias(Sampler.MipLODBias)) + DeferredErrs = joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<float>>( + "MipLODBias", Sampler.MipLODBias)); + + if (!hlsl::rootsig::verifyMaxAnisotropy(Sampler.MaxAnisotropy)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "MaxAnisotropy", Sampler.MaxAnisotropy)); + + if (!hlsl::rootsig::verifyComparisonFunc(Sampler.ComparisonFunc)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "ComparisonFunc", Sampler.ComparisonFunc)); + + if (!hlsl::rootsig::verifyBorderColor(Sampler.BorderColor)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "BorderColor", Sampler.BorderColor)); + + if (!hlsl::rootsig::verifyLOD(Sampler.MinLOD)) + DeferredErrs = joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<float>>( + "MinLOD", Sampler.MinLOD)); + + if (!hlsl::rootsig::verifyLOD(Sampler.MaxLOD)) + DeferredErrs = joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<float>>( + "MaxLOD", Sampler.MaxLOD)); + + if (!hlsl::rootsig::verifyRegisterValue(Sampler.ShaderRegister)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "ShaderRegister", Sampler.ShaderRegister)); + + if (!hlsl::rootsig::verifyRegisterSpace(Sampler.RegisterSpace)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "RegisterSpace", 
Sampler.RegisterSpace)); if (!dxbc::isValidShaderVisibility(Sampler.ShaderVisibility)) - return reportValueError(Ctx, "ShaderVisibility", - Sampler.ShaderVisibility); + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "ShaderVisibility", Sampler.ShaderVisibility)); } - return false; + return DeferredErrs; } -bool MetadataParser::ParseRootSignature(LLVMContext *Ctx, - mcdxbc::RootSignatureDesc &RSD) { - bool HasError = false; - - // Loop through the Root Elements of the root signature. +Expected<mcdxbc::RootSignatureDesc> +MetadataParser::ParseRootSignature(uint32_t Version) { + Error DeferredErrs = Error::success(); + mcdxbc::RootSignatureDesc RSD; + RSD.Version = Version; for (const auto &Operand : Root->operands()) { MDNode *Element = dyn_cast<MDNode>(Operand); if (Element == nullptr) - return reportError(Ctx, "Missing Root Element Metadata Node."); + return joinErrors(std::move(DeferredErrs), + make_error<GenericRSMetadataError>( + "Missing Root Element Metadata Node.", nullptr)); - HasError = HasError || parseRootSignatureElement(Ctx, RSD, Element) || - validateRootSignature(Ctx, RSD); + if (auto Err = parseRootSignatureElement(RSD, Element)) + DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err)); } - return HasError; + if (auto Err = validateRootSignature(RSD)) + DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err)); + + if (DeferredErrs) + return std::move(DeferredErrs); + + return std::move(RSD); } } // namespace rootsig } // namespace hlsl diff --git a/llvm/lib/Frontend/Offloading/CMakeLists.txt b/llvm/lib/Frontend/Offloading/CMakeLists.txt index 8e1ede9..9747dbd 100644 --- a/llvm/lib/Frontend/Offloading/CMakeLists.txt +++ b/llvm/lib/Frontend/Offloading/CMakeLists.txt @@ -1,6 +1,7 @@ add_llvm_component_library(LLVMFrontendOffloading Utility.cpp OffloadWrapper.cpp + PropertySet.cpp ADDITIONAL_HEADER_DIRS ${LLVM_MAIN_INCLUDE_DIR}/llvm/Frontend diff --git a/llvm/lib/Frontend/Offloading/PropertySet.cpp b/llvm/lib/Frontend/Offloading/PropertySet.cpp new file mode 100644 index 0000000..a70290d --- /dev/null +++ b/llvm/lib/Frontend/Offloading/PropertySet.cpp @@ -0,0 +1,102 @@ +///===- llvm/Frontend/Offloading/PropertySet.cpp --------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Frontend/Offloading/PropertySet.h" +#include "llvm/Support/Base64.h" +#include "llvm/Support/JSON.h" +#include "llvm/Support/MemoryBufferRef.h" + +using namespace llvm; +using namespace llvm::offloading; + +void llvm::offloading::writePropertiesToJSON( + const PropertySetRegistry &PSRegistry, raw_ostream &Out) { + json::OStream J(Out); + J.object([&] { + for (const auto &[CategoryName, PropSet] : PSRegistry) { + // Copy the set: C++17 lambdas cannot capture a structured binding. + auto PropSetCapture = PropSet; + J.attributeObject(CategoryName, [&] { + for (const auto &[PropName, PropVal] : PropSetCapture) { + switch (PropVal.index()) { + case 0: + J.attribute(PropName, std::get<uint32_t>(PropVal)); + break; + case 1: + J.attribute(PropName, encodeBase64(std::get<ByteArray>(PropVal))); + break; + default: + llvm_unreachable("unsupported property type"); + } + } + }); + } + }); +} + +// Note: createStringError has an overload that takes a format string, +// but it uses llvm::format instead of llvm::formatv, which does +// not work with json::Value. This is a helper function to use +// llvm::formatv with createStringError. +template <typename... Ts> auto createStringErrorV(Ts &&...Args) { + return createStringError(formatv(std::forward<Ts>(Args)...)); +} + +Expected<PropertyValue> +readPropertyValueFromJSON(const json::Value &PropValueVal) { + if (std::optional<uint64_t> Val = PropValueVal.getAsUINT64()) + return PropertyValue(static_cast<uint32_t>(*Val)); + + if (std::optional<StringRef> Val = PropValueVal.getAsString()) { + std::vector<char> Decoded; + if (Error E = decodeBase64(*Val, Decoded)) + return createStringErrorV("unable to base64 decode the string {0}: {1}", + Val, toString(std::move(E))); + return PropertyValue(ByteArray(Decoded.begin(), Decoded.end())); + } + + return createStringErrorV("expected a uint64 or a string, got {0}", + PropValueVal); +} + +Expected<PropertySetRegistry> +llvm::offloading::readPropertiesFromJSON(MemoryBufferRef Buf) { + PropertySetRegistry Res; + Expected<json::Value> V = json::parse(Buf.getBuffer()); + if (Error E = V.takeError()) + return E; + + const json::Object *O = V->getAsObject(); + if (!O) + return createStringErrorV( + "error while deserializing property set registry: " + "expected JSON object, got {0}", + *V); + + for (const auto &[CategoryName, Value] : *O) { + const json::Object *PropSetVal = Value.getAsObject(); + if (!PropSetVal) + return createStringErrorV("error while deserializing property set {0}: " + "expected JSON object, got {1}", + CategoryName.str(), Value); + + PropertySet &PropSet = Res[CategoryName.str()]; + for (const auto &[PropName, PropValueVal] : *PropSetVal) { + Expected<PropertyValue> Prop = readPropertyValueFromJSON(PropValueVal); + if (Error E = Prop.takeError()) + return createStringErrorV( + "error while deserializing property {0} in property set {1}: {2}", + PropName.str(), CategoryName.str(), toString(std::move(E))); + + auto [It, Inserted] = + PropSet.try_emplace(PropName.str(), std::move(*Prop)); + assert(Inserted && "Property already exists in PropertySet"); + (void)Inserted; + } + } + return Res; +} diff --git a/llvm/lib/IR/DebugInfoMetadata.cpp b/llvm/lib/IR/DebugInfoMetadata.cpp index f16963d..f1d4549 100644 --- a/llvm/lib/IR/DebugInfoMetadata.cpp +++ b/llvm/lib/IR/DebugInfoMetadata.cpp @@ -1012,7 +1012,7 @@ DIDerivedType *DIDerivedType::getImpl( std::optional<DIDerivedType::PtrAuthData>
DIDerivedType::getPtrAuthData() const { return getTag() == dwarf::DW_TAG_LLVM_ptrauth_type - ? std::optional<PtrAuthData>(PtrAuthData(SubclassData32)) + ? std::make_optional<PtrAuthData>(SubclassData32) : std::nullopt; } diff --git a/llvm/lib/LTO/LTO.cpp b/llvm/lib/LTO/LTO.cpp index 73e79c0..0323b4d 100644 --- a/llvm/lib/LTO/LTO.cpp +++ b/llvm/lib/LTO/LTO.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "llvm/LTO/LTO.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/StableHashing.h" @@ -742,18 +743,19 @@ Error LTO::add(std::unique_ptr<InputFile> Input, Conf.VisibilityScheme = Config::ELF; } - const SymbolResolution *ResI = Res.begin(); - for (unsigned I = 0; I != Input->Mods.size(); ++I) - if (Error Err = addModule(*Input, I, ResI, Res.end())) + ArrayRef<SymbolResolution> InputRes = Res; + for (unsigned I = 0; I != Input->Mods.size(); ++I) { + if (auto Err = addModule(*Input, InputRes, I, Res).moveInto(Res)) return Err; + } - assert(ResI == Res.end()); + assert(Res.empty()); return Error::success(); } -Error LTO::addModule(InputFile &Input, unsigned ModI, - const SymbolResolution *&ResI, - const SymbolResolution *ResE) { +Expected<ArrayRef<SymbolResolution>> +LTO::addModule(InputFile &Input, ArrayRef<SymbolResolution> InputRes, + unsigned ModI, ArrayRef<SymbolResolution> Res) { Expected<BitcodeLTOInfo> LTOInfo = Input.Mods[ModI].getLTOInfo(); if (!LTOInfo) return LTOInfo.takeError(); @@ -782,28 +784,32 @@ Error LTO::addModule(InputFile &Input, unsigned ModI, bool IsThinLTO = LTOInfo->IsThinLTO && (LTOMode != LTOK_UnifiedRegular); auto ModSyms = Input.module_symbols(ModI); - addModuleToGlobalRes(ModSyms, {ResI, ResE}, + addModuleToGlobalRes(ModSyms, Res, IsThinLTO ? ThinLTO.ModuleMap.size() + 1 : 0, LTOInfo->HasSummary); if (IsThinLTO) - return addThinLTO(BM, ModSyms, ResI, ResE); + return addThinLTO(BM, ModSyms, Res); RegularLTO.EmptyCombinedModule = false; - Expected<RegularLTOState::AddedModule> ModOrErr = - addRegularLTO(BM, ModSyms, ResI, ResE); + auto ModOrErr = addRegularLTO(Input, InputRes, BM, ModSyms, Res); if (!ModOrErr) return ModOrErr.takeError(); + Res = ModOrErr->second; - if (!LTOInfo->HasSummary) - return linkRegularLTO(std::move(*ModOrErr), /*LivenessFromIndex=*/false); + if (!LTOInfo->HasSummary) { + if (Error Err = linkRegularLTO(std::move(ModOrErr->first), + /*LivenessFromIndex=*/false)) + return Err; + return Res; + } // Regular LTO module summaries are added to a dummy module that represents // the combined regular LTO module. if (Error Err = BM.readSummary(ThinLTO.CombinedIndex, "")) return Err; - RegularLTO.ModsWithSummaries.push_back(std::move(*ModOrErr)); - return Error::success(); + RegularLTO.ModsWithSummaries.push_back(std::move(ModOrErr->first)); + return Res; } // Checks whether the given global value is in a non-prevailing comdat @@ -839,10 +845,11 @@ handleNonPrevailingComdat(GlobalValue &GV, // Add a regular LTO object to the link. // The resulting module needs to be linked into the combined LTO module with // linkRegularLTO. 
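// The LTO refactor below replaces the ResI/ResE pointer-pair convention with
// ArrayRef slices: each consumer takes the remaining resolutions and returns
// the unconsumed tail. A self-contained sketch of the pattern (generic
// illustration, not LTO-specific):

#include "llvm/ADT/ArrayRef.h"
#include <cassert>

// Consume N elements from the front of Res and hand back the remaining tail,
// so a caller can thread one slice through a sequence of consumers.
static llvm::ArrayRef<int> consumeN(llvm::ArrayRef<int> Res, unsigned N) {
  for (unsigned I = 0; I != N; ++I) {
    assert(!Res.empty() && "ran out of elements");
    (void)Res.consume_front(); // each element would be used here
  }
  return Res;
}

int main() {
  int Vals[] = {1, 2, 3, 4, 5};
  llvm::ArrayRef<int> Res(Vals);
  Res = consumeN(Res, 2);
  Res = consumeN(Res, 3);
  assert(Res.empty() && "every element must be consumed exactly once");
}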
-Expected<LTO::RegularLTOState::AddedModule> -LTO::addRegularLTO(BitcodeModule BM, ArrayRef<InputFile::Symbol> Syms, - const SymbolResolution *&ResI, - const SymbolResolution *ResE) { +Expected< + std::pair<LTO::RegularLTOState::AddedModule, ArrayRef<SymbolResolution>>> +LTO::addRegularLTO(InputFile &Input, ArrayRef<SymbolResolution> InputRes, + BitcodeModule BM, ArrayRef<InputFile::Symbol> Syms, + ArrayRef<SymbolResolution> Res) { RegularLTOState::AddedModule Mod; Expected<std::unique_ptr<Module>> MOrErr = BM.getLazyModule(RegularLTO.Ctx, /*ShouldLazyLoadMetadata*/ true, @@ -855,13 +862,34 @@ LTO::addRegularLTO(BitcodeModule BM, ArrayRef<InputFile::Symbol> Syms, if (Error Err = M.materializeMetadata()) return std::move(Err); - // If cfi.functions is present and we are in regular LTO mode, LowerTypeTests - // will rename local functions in the merged module as "<function name>.1". - // This causes linking errors, since other parts of the module expect the - // original function name. - if (LTOMode == LTOK_UnifiedRegular) + if (LTOMode == LTOK_UnifiedRegular) { + // cfi.functions metadata is intended to be used with ThinLTO and may + // trigger invalid IR transformations if they are present when doing regular + // LTO, so delete it. if (NamedMDNode *CfiFunctionsMD = M.getNamedMetadata("cfi.functions")) M.eraseNamedMetadata(CfiFunctionsMD); + } else if (NamedMDNode *AliasesMD = M.getNamedMetadata("aliases")) { + // Delete aliases entries for non-prevailing symbols on the ThinLTO side of + // this input file. + DenseSet<StringRef> Prevailing; + for (auto [I, R] : zip(Input.symbols(), InputRes)) + if (R.Prevailing && !I.getIRName().empty()) + Prevailing.insert(I.getIRName()); + std::vector<MDNode *> AliasGroups; + for (MDNode *AliasGroup : AliasesMD->operands()) { + std::vector<Metadata *> Aliases; + for (Metadata *Alias : AliasGroup->operands()) { + if (isa<MDString>(Alias) && + Prevailing.count(cast<MDString>(Alias)->getString())) + Aliases.push_back(Alias); + } + if (Aliases.size() > 1) + AliasGroups.push_back(MDTuple::get(RegularLTO.Ctx, Aliases)); + } + AliasesMD->clearOperands(); + for (MDNode *G : AliasGroups) + AliasesMD->addOperand(G); + } UpgradeDebugInfo(M); @@ -899,22 +927,22 @@ LTO::addRegularLTO(BitcodeModule BM, ArrayRef<InputFile::Symbol> Syms, std::set<const Comdat *> NonPrevailingComdats; SmallSet<StringRef, 2> NonPrevailingAsmSymbols; for (const InputFile::Symbol &Sym : Syms) { - assert(ResI != ResE); - SymbolResolution Res = *ResI++; + assert(!Res.empty()); + const SymbolResolution &R = Res.consume_front(); assert(MsymI != MsymE); ModuleSymbolTable::Symbol Msym = *MsymI++; Skip(); if (GlobalValue *GV = dyn_cast_if_present<GlobalValue *>(Msym)) { - if (Res.Prevailing) { + if (R.Prevailing) { if (Sym.isUndefined()) continue; Mod.Keep.push_back(GV); // For symbols re-defined with linker -wrap and -defsym options, // set the linkage to weak to inhibit IPO. The linkage will be // restored by the linker. - if (Res.LinkerRedefined) + if (R.LinkerRedefined) GV->setLinkage(GlobalValue::WeakAnyLinkage); GlobalValue::LinkageTypes OriginalLinkage = GV->getLinkage(); @@ -938,7 +966,7 @@ LTO::addRegularLTO(BitcodeModule BM, ArrayRef<InputFile::Symbol> Syms, } // Set the 'local' flag based on the linker resolution for this symbol. 
- if (Res.FinalDefinitionInLinkageUnit) { + if (R.FinalDefinitionInLinkageUnit) { GV->setDSOLocal(true); if (GV->hasDLLImportStorageClass()) GV->setDLLStorageClass(GlobalValue::DLLStorageClassTypes:: @@ -947,7 +975,7 @@ LTO::addRegularLTO(BitcodeModule BM, ArrayRef<InputFile::Symbol> Syms, } else if (auto *AS = dyn_cast_if_present<ModuleSymbolTable::AsmSymbol *>(Msym)) { // Collect non-prevailing symbols. - if (!Res.Prevailing) + if (!R.Prevailing) NonPrevailingAsmSymbols.insert(AS->first); } else { llvm_unreachable("unknown symbol type"); @@ -965,7 +993,7 @@ LTO::addRegularLTO(BitcodeModule BM, ArrayRef<InputFile::Symbol> Syms, CommonRes.Alignment = std::max(Align(SymAlignValue), CommonRes.Alignment); } - CommonRes.Prevailing |= Res.Prevailing; + CommonRes.Prevailing |= R.Prevailing; } } @@ -991,7 +1019,7 @@ LTO::addRegularLTO(BitcodeModule BM, ArrayRef<InputFile::Symbol> Syms, } assert(MsymI == MsymE); - return std::move(Mod); + return std::make_pair(std::move(Mod), Res); } Error LTO::linkRegularLTO(RegularLTOState::AddedModule Mod, @@ -1032,19 +1060,19 @@ Error LTO::linkRegularLTO(RegularLTOState::AddedModule Mod, } // Add a ThinLTO module to the link. -Error LTO::addThinLTO(BitcodeModule BM, ArrayRef<InputFile::Symbol> Syms, - const SymbolResolution *&ResI, - const SymbolResolution *ResE) { - const SymbolResolution *ResITmp = ResI; +Expected<ArrayRef<SymbolResolution>> +LTO::addThinLTO(BitcodeModule BM, ArrayRef<InputFile::Symbol> Syms, + ArrayRef<SymbolResolution> Res) { + ArrayRef<SymbolResolution> ResTmp = Res; for (const InputFile::Symbol &Sym : Syms) { - assert(ResITmp != ResE); - SymbolResolution Res = *ResITmp++; + assert(!ResTmp.empty()); + const SymbolResolution &R = ResTmp.consume_front(); if (!Sym.getIRName().empty()) { auto GUID = GlobalValue::getGUIDAssumingExternalLinkage( GlobalValue::getGlobalIdentifier(Sym.getIRName(), GlobalValue::ExternalLinkage, "")); - if (Res.Prevailing) + if (R.Prevailing) ThinLTO.PrevailingModuleForGUID[GUID] = BM.getModuleIdentifier(); } } @@ -1059,14 +1087,14 @@ Error LTO::addThinLTO(BitcodeModule BM, ArrayRef<InputFile::Symbol> Syms, LLVM_DEBUG(dbgs() << "Module " << BM.getModuleIdentifier() << "\n"); for (const InputFile::Symbol &Sym : Syms) { - assert(ResI != ResE); - SymbolResolution Res = *ResI++; + assert(!Res.empty()); + const SymbolResolution &R = Res.consume_front(); if (!Sym.getIRName().empty()) { auto GUID = GlobalValue::getGUIDAssumingExternalLinkage( GlobalValue::getGlobalIdentifier(Sym.getIRName(), GlobalValue::ExternalLinkage, "")); - if (Res.Prevailing) { + if (R.Prevailing) { assert(ThinLTO.PrevailingModuleForGUID[GUID] == BM.getModuleIdentifier()); @@ -1074,7 +1102,7 @@ Error LTO::addThinLTO(BitcodeModule BM, ArrayRef<InputFile::Symbol> Syms, // switch the linkage to `weak` to prevent IPOs from happening. // Find the summary in the module for this very GV and record the new // linkage so that we can switch it when we import the GV. - if (Res.LinkerRedefined) + if (R.LinkerRedefined) if (auto S = ThinLTO.CombinedIndex.findSummaryInModule( GUID, BM.getModuleIdentifier())) S->setLinkage(GlobalValue::WeakAnyLinkage); @@ -1082,7 +1110,7 @@ Error LTO::addThinLTO(BitcodeModule BM, ArrayRef<InputFile::Symbol> Syms, // If the linker resolved the symbol to a local definition then mark it // as local in the summary for the module we are adding. 
- if (Res.FinalDefinitionInLinkageUnit) { + if (R.FinalDefinitionInLinkageUnit) { if (auto S = ThinLTO.CombinedIndex.findSummaryInModule( GUID, BM.getModuleIdentifier())) { S->setDSOLocal(true); @@ -1110,7 +1138,7 @@ Error LTO::addThinLTO(BitcodeModule BM, ArrayRef<InputFile::Symbol> Syms, } } - return Error::success(); + return Res; } unsigned LTO::getMaxTasks() const { diff --git a/llvm/lib/MC/MCAssembler.cpp b/llvm/lib/MC/MCAssembler.cpp index 8500fd1..5c8e904 100644 --- a/llvm/lib/MC/MCAssembler.cpp +++ b/llvm/lib/MC/MCAssembler.cpp @@ -59,7 +59,8 @@ STATISTIC(EmittedFillFragments, "Number of emitted assembler fragments - fill"); STATISTIC(EmittedNopsFragments, "Number of emitted assembler fragments - nops"); STATISTIC(EmittedOrgFragments, "Number of emitted assembler fragments - org"); -STATISTIC(evaluateFixup, "Number of evaluated fixups"); +STATISTIC(Fixups, "Number of fixups"); +STATISTIC(FixupEvalForRelax, "Number of fixup evaluations for relaxation"); STATISTIC(ObjectBytes, "Number of emitted object file bytes"); STATISTIC(RelaxationSteps, "Number of assembler layout and relaxation steps"); STATISTIC(RelaxedInstructions, "Number of relaxed instructions"); @@ -140,9 +141,9 @@ bool MCAssembler::isThumbFunc(const MCSymbol *Symbol) const { bool MCAssembler::evaluateFixup(const MCFragment &F, MCFixup &Fixup, MCValue &Target, uint64_t &Value, - bool RecordReloc, - MutableArrayRef<char> Contents) const { - ++stats::evaluateFixup; + bool RecordReloc, uint8_t *Data) const { + if (RecordReloc) + ++stats::Fixups; // FIXME: This code has some duplication with recordRelocation. We should // probably merge the two into a single callback that tries to evaluate a @@ -185,7 +186,7 @@ bool MCAssembler::evaluateFixup(const MCFragment &F, MCFixup &Fixup, if (IsResolved && mc::isRelocRelocation(Fixup.getKind())) IsResolved = false; - getBackend().applyFixup(F, Fixup, Target, Contents, Value, IsResolved); + getBackend().applyFixup(F, Fixup, Target, Data, Value, IsResolved); return true; } @@ -703,21 +704,25 @@ void MCAssembler::layout() { for (MCFixup &Fixup : F.getFixups()) { uint64_t FixedValue; MCValue Target; + assert(mc::isRelocRelocation(Fixup.getKind()) || + Fixup.getOffset() <= F.getFixedSize()); + auto *Data = + reinterpret_cast<uint8_t *>(Contents.data() + Fixup.getOffset()); evaluateFixup(F, Fixup, Target, FixedValue, - /*RecordReloc=*/true, Contents); + /*RecordReloc=*/true, Data); } - if (F.getVarFixups().size()) { - // In the variable part, fixup offsets are relative to the fixed part's - // start. Extend the variable contents to the left to account for the - // fixed part size. - Contents = MutableArrayRef(F.getParent()->ContentStorage) - .slice(F.VarContentStart - Contents.size(), F.getSize()); - for (MCFixup &Fixup : F.getVarFixups()) { - uint64_t FixedValue; - MCValue Target; - evaluateFixup(F, Fixup, Target, FixedValue, - /*RecordReloc=*/true, Contents); - } + // In the variable part, fixup offsets are relative to the fixed part's + // start. 
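// Worked example (illustrative): with a fixed part of 4 bytes and a variable
// fixup recorded at offset 6, the assert below admits the fixup and the patch
// location becomes VarContents.data() + 2, i.e. the offset is rebased from
// the fragment start to the variable part's start before indexing.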
+ for (MCFixup &Fixup : F.getVarFixups()) { + uint64_t FixedValue; + MCValue Target; + assert(mc::isRelocRelocation(Fixup.getKind()) || + (Fixup.getOffset() >= F.getFixedSize() && + Fixup.getOffset() <= F.getSize())); + auto *Data = reinterpret_cast<uint8_t *>( + F.getVarContents().data() + (Fixup.getOffset() - F.getFixedSize())); + evaluateFixup(F, Fixup, Target, FixedValue, + /*RecordReloc=*/true, Data); } } } @@ -735,7 +740,7 @@ void MCAssembler::Finish() { bool MCAssembler::fixupNeedsRelaxation(const MCFragment &F, const MCFixup &Fixup) const { - assert(getBackendPtr() && "Expected assembler backend"); + ++stats::FixupEvalForRelax; MCValue Target; uint64_t Value; bool Resolved = evaluateFixup(F, const_cast<MCFixup &>(Fixup), Target, Value, diff --git a/llvm/lib/MC/MCCodeView.cpp b/llvm/lib/MC/MCCodeView.cpp index 7d528a5..3a5f01c 100644 --- a/llvm/lib/MC/MCCodeView.cpp +++ b/llvm/lib/MC/MCCodeView.cpp @@ -695,5 +695,7 @@ void CodeViewContext::encodeDefRange(const MCAssembler &Asm, } Frag.setVarContents(Contents); + assert(Fixups.size() < 256 && "Store fixups outside of MCFragment's VarFixup " + "storage if the number ever exceeds 256"); Frag.setVarFixups(Fixups); } diff --git a/llvm/lib/MC/MCObjectFileInfo.cpp b/llvm/lib/MC/MCObjectFileInfo.cpp index 0069d12..393eed1 100644 --- a/llvm/lib/MC/MCObjectFileInfo.cpp +++ b/llvm/lib/MC/MCObjectFileInfo.cpp @@ -537,6 +537,8 @@ void MCObjectFileInfo::initELFMCObjectFileInfo(const Triple &T, bool Large) { EHFrameSection = Ctx->getELFSection(".eh_frame", EHSectionType, EHSectionFlags); + CallGraphSection = Ctx->getELFSection(".callgraph", ELF::SHT_PROGBITS, 0); + StackSizesSection = Ctx->getELFSection(".stack_sizes", ELF::SHT_PROGBITS, 0); PseudoProbeSection = Ctx->getELFSection(".pseudo_probe", DebugSecType, 0); @@ -1121,6 +1123,24 @@ MCSection *MCObjectFileInfo::getDwarfComdatSection(const char *Name, } MCSection * +MCObjectFileInfo::getCallGraphSection(const MCSection &TextSec) const { + if (Ctx->getObjectFileType() != MCContext::IsELF) + return CallGraphSection; + + const MCSectionELF &ElfSec = static_cast<const MCSectionELF &>(TextSec); + unsigned Flags = ELF::SHF_LINK_ORDER; + StringRef GroupName; + if (const MCSymbol *Group = ElfSec.getGroup()) { + GroupName = Group->getName(); + Flags |= ELF::SHF_GROUP; + } + + return Ctx->getELFSection(".callgraph", ELF::SHT_PROGBITS, Flags, 0, + GroupName, true, ElfSec.getUniqueID(), + cast<MCSymbolELF>(TextSec.getBeginSymbol())); +} + +MCSection * MCObjectFileInfo::getStackSizesSection(const MCSection &TextSec) const { if ((Ctx->getObjectFileType() != MCContext::IsELF) || Ctx->getTargetTriple().isPS4()) diff --git a/llvm/lib/MC/MCObjectStreamer.cpp b/llvm/lib/MC/MCObjectStreamer.cpp index e82393a..200e29a 100644 --- a/llvm/lib/MC/MCObjectStreamer.cpp +++ b/llvm/lib/MC/MCObjectStreamer.cpp @@ -56,9 +56,9 @@ static_assert(FragBlockSize >= sizeof(MCFragment) + NewFragHeadroom); MCFragment *MCObjectStreamer::allocFragSpace(size_t Headroom) { auto Size = std::max(FragBlockSize, sizeof(MCFragment) + Headroom); FragSpace = Size - sizeof(MCFragment); - auto Chunk = std::unique_ptr<char[]>(new char[Size]); - auto *F = reinterpret_cast<MCFragment *>(Chunk.get()); - FragStorage.push_back(std::move(Chunk)); + auto Block = std::unique_ptr<uint8_t[]>(new uint8_t[Size]); + auto *F = reinterpret_cast<MCFragment *>(Block.get()); + FragStorage.push_back(std::move(Block)); return F; } @@ -113,9 +113,9 @@ void MCObjectStreamer::appendContents(ArrayRef<char> Contents) { FragSpace -= Contents.size(); } -void 
MCObjectStreamer::appendContents(size_t Num, char Elt) { +void MCObjectStreamer::appendContents(size_t Num, uint8_t Elt) { ensureHeadroom(Num); - MutableArrayRef<char> Data(getCurFragEnd(), Num); + MutableArrayRef<uint8_t> Data(getCurFragEnd(), Num); llvm::fill(Data, Elt); CurFrag->FixedSize += Num; FragSpace -= Num; diff --git a/llvm/lib/MC/MCSection.cpp b/llvm/lib/MC/MCSection.cpp index 4f28267..27ca131 100644 --- a/llvm/lib/MC/MCSection.cpp +++ b/llvm/lib/MC/MCSection.cpp @@ -83,12 +83,14 @@ void MCFragment::appendFixups(ArrayRef<MCFixup> Fixups) { } void MCFragment::setVarFixups(ArrayRef<MCFixup> Fixups) { + assert(Fixups.size() < 256 && + "variable-size tail cannot have more than 256 fixups"); auto &S = getParent()->FixupStorage; - if (VarFixupStart + Fixups.size() > VarFixupEnd) { + if (Fixups.size() > VarFixupSize) { VarFixupStart = S.size(); S.resize_for_overwrite(S.size() + Fixups.size()); } - VarFixupEnd = VarFixupStart + Fixups.size(); + VarFixupSize = Fixups.size(); // Source fixup offsets are relative to the variable part's start. Add the // fixed part size to make them relative to the fixed part's start. std::transform(Fixups.begin(), Fixups.end(), S.begin() + VarFixupStart, diff --git a/llvm/lib/ObjCopy/COFF/COFFReader.cpp b/llvm/lib/ObjCopy/COFF/COFFReader.cpp index 62a71d4..9b55f76 100644 --- a/llvm/lib/ObjCopy/COFF/COFFReader.cpp +++ b/llvm/lib/ObjCopy/COFF/COFFReader.cpp @@ -135,7 +135,7 @@ Error COFFReader::readSymbols(Object &Obj, bool IsBigObj) const { // it is, find the target section unique id. const coff_aux_section_definition *SD = SymRef.getSectionDefinition(); const coff_aux_weak_external *WE = SymRef.getWeakExternal(); - if (SD && SD->Selection == IMAGE_COMDAT_SELECT_ASSOCIATIVE) { + if (SD && SD->Selection == IMAGE_COMDAT_SELECT_ASSOCIATIVE && !Obj.IsPE) { int32_t Index = SD->getNumber(IsBigObj); if (Index <= 0 || static_cast<uint32_t>(Index - 1) >= Sections.size()) return createStringError(object_error::parse_failed, diff --git a/llvm/lib/Object/ELFObjectFile.cpp b/llvm/lib/Object/ELFObjectFile.cpp index 0919c6a..aff047c 100644 --- a/llvm/lib/Object/ELFObjectFile.cpp +++ b/llvm/lib/Object/ELFObjectFile.cpp @@ -688,11 +688,20 @@ StringRef ELFObjectFileBase::getNVPTXCPUName() const { case ELF::EF_CUDA_SM100: return getPlatformFlags() & ELF::EF_CUDA_ACCELERATORS ? "sm_100a" : "sm_100"; + case ELF::EF_CUDA_SM101: + return getPlatformFlags() & ELF::EF_CUDA_ACCELERATORS ? "sm_101a" + : "sm_101"; + case ELF::EF_CUDA_SM103: + return getPlatformFlags() & ELF::EF_CUDA_ACCELERATORS ? "sm_103a" + : "sm_103"; // Rubin architecture. case ELF::EF_CUDA_SM120: return getPlatformFlags() & ELF::EF_CUDA_ACCELERATORS ? "sm_120a" : "sm_120"; + case ELF::EF_CUDA_SM121: + return getPlatformFlags() & ELF::EF_CUDA_ACCELERATORS ? 
"sm_121a" + : "sm_121"; default: llvm_unreachable("Unknown EF_CUDA_SM value"); } diff --git a/llvm/lib/ObjectYAML/ELFEmitter.cpp b/llvm/lib/ObjectYAML/ELFEmitter.cpp index 6de87a8..bc5c68d 100644 --- a/llvm/lib/ObjectYAML/ELFEmitter.cpp +++ b/llvm/lib/ObjectYAML/ELFEmitter.cpp @@ -481,7 +481,11 @@ void ELFState<ELFT>::writeELFHeader(raw_ostream &OS) { Header.e_version = EV_CURRENT; Header.e_entry = Doc.Header.Entry; - Header.e_flags = Doc.Header.Flags; + if (Doc.Header.Flags) + Header.e_flags = *Doc.Header.Flags; + else + Header.e_flags = 0; + Header.e_ehsize = sizeof(Elf_Ehdr); if (Doc.Header.EPhOff) diff --git a/llvm/lib/ObjectYAML/ELFYAML.cpp b/llvm/lib/ObjectYAML/ELFYAML.cpp index 7fcabb68..c27339d 100644 --- a/llvm/lib/ObjectYAML/ELFYAML.cpp +++ b/llvm/lib/ObjectYAML/ELFYAML.cpp @@ -1160,7 +1160,7 @@ void MappingTraits<ELFYAML::FileHeader>::mapping(IO &IO, IO.mapOptional("ABIVersion", FileHdr.ABIVersion, Hex8(0)); IO.mapRequired("Type", FileHdr.Type); IO.mapOptional("Machine", FileHdr.Machine); - IO.mapOptional("Flags", FileHdr.Flags, ELFYAML::ELF_EF(0)); + IO.mapOptional("Flags", FileHdr.Flags); IO.mapOptional("Entry", FileHdr.Entry, Hex64(0)); IO.mapOptional("SectionHeaderStringTable", FileHdr.SectionHeaderStringTable); diff --git a/llvm/lib/ProfileData/MemProfReader.cpp b/llvm/lib/ProfileData/MemProfReader.cpp index 235b134..3fc0dbf 100644 --- a/llvm/lib/ProfileData/MemProfReader.cpp +++ b/llvm/lib/ProfileData/MemProfReader.cpp @@ -135,7 +135,7 @@ readMemInfoBlocksV3(const char *Ptr) { } llvm::SmallVector<std::pair<uint64_t, MemInfoBlock>> -readMemInfoBlocksV4(const char *Ptr) { +readMemInfoBlocksCommon(const char *Ptr, bool IsHistogramEncoded = false) { using namespace support; const uint64_t NumItemsToRead = @@ -145,27 +145,74 @@ readMemInfoBlocksV4(const char *Ptr) { for (uint64_t I = 0; I < NumItemsToRead; I++) { const uint64_t Id = endian::readNext<uint64_t, llvm::endianness::little, unaligned>(Ptr); - // We cheat a bit here and remove the const from cast to set the - // Histogram Pointer to newly allocated buffer. - MemInfoBlock MIB = *reinterpret_cast<const MemInfoBlock *>(Ptr); - // Only increment by size of MIB since readNext implicitly increments. - Ptr += sizeof(MemInfoBlock); + MemInfoBlock MIB; +#define READ_MIB_FIELD(FIELD) \ + MIB.FIELD = endian::readNext<decltype(MIB.FIELD), llvm::endianness::little, \ + unaligned>(Ptr) + + READ_MIB_FIELD(AllocCount); + READ_MIB_FIELD(TotalAccessCount); + READ_MIB_FIELD(MinAccessCount); + READ_MIB_FIELD(MaxAccessCount); + READ_MIB_FIELD(TotalSize); + READ_MIB_FIELD(MinSize); + READ_MIB_FIELD(MaxSize); + READ_MIB_FIELD(AllocTimestamp); + READ_MIB_FIELD(DeallocTimestamp); + READ_MIB_FIELD(TotalLifetime); + READ_MIB_FIELD(MinLifetime); + READ_MIB_FIELD(MaxLifetime); + READ_MIB_FIELD(AllocCpuId); + READ_MIB_FIELD(DeallocCpuId); + READ_MIB_FIELD(NumMigratedCpu); + READ_MIB_FIELD(NumLifetimeOverlaps); + READ_MIB_FIELD(NumSameAllocCpu); + READ_MIB_FIELD(NumSameDeallocCpu); + READ_MIB_FIELD(DataTypeId); + READ_MIB_FIELD(TotalAccessDensity); + READ_MIB_FIELD(MinAccessDensity); + READ_MIB_FIELD(MaxAccessDensity); + READ_MIB_FIELD(TotalLifetimeAccessDensity); + READ_MIB_FIELD(MinLifetimeAccessDensity); + READ_MIB_FIELD(MaxLifetimeAccessDensity); + READ_MIB_FIELD(AccessHistogramSize); + READ_MIB_FIELD(AccessHistogram); +#undef READ_MIB_FIELD if (MIB.AccessHistogramSize > 0) { + // The in-memory representation uses uint64_t for histogram entries. 
MIB.AccessHistogram = (uintptr_t)malloc(MIB.AccessHistogramSize * sizeof(uint64_t)); - } - - for (uint64_t J = 0; J < MIB.AccessHistogramSize; J++) { - ((uint64_t *)MIB.AccessHistogram)[J] = - endian::readNext<uint64_t, llvm::endianness::little, unaligned>(Ptr); + for (uint64_t J = 0; J < MIB.AccessHistogramSize; J++) { + if (!IsHistogramEncoded) { + ((uint64_t *)MIB.AccessHistogram)[J] = + endian::readNext<uint64_t, llvm::endianness::little, unaligned>( + Ptr); + } else { + // The encoded on-disk format (V5 onwards) uses uint16_t. + const uint16_t Val = + endian::readNext<uint16_t, llvm::endianness::little, unaligned>( + Ptr); + ((uint64_t *)MIB.AccessHistogram)[J] = decodeHistogramCount(Val); + } + } } Items.push_back({Id, MIB}); } return Items; } +llvm::SmallVector<std::pair<uint64_t, MemInfoBlock>> +readMemInfoBlocksV4(const char *Ptr) { + return readMemInfoBlocksCommon(Ptr); +} + +llvm::SmallVector<std::pair<uint64_t, MemInfoBlock>> +readMemInfoBlocksV5(const char *Ptr) { + return readMemInfoBlocksCommon(Ptr, /*IsHistogramEncoded=*/true); +} + CallStackMap readStackInfo(const char *Ptr) { using namespace support; @@ -658,6 +705,8 @@ RawMemProfReader::readMemInfoBlocks(const char *Ptr) { return readMemInfoBlocksV3(Ptr); if (MemprofRawVersion == 4ULL) return readMemInfoBlocksV4(Ptr); + if (MemprofRawVersion == 5ULL) + return readMemInfoBlocksV5(Ptr); llvm_unreachable( "Panic: Unsupported version number when reading MemInfoBlocks"); } diff --git a/llvm/lib/Remarks/RemarkLinker.cpp b/llvm/lib/Remarks/RemarkLinker.cpp index 0ca6217..b00419b 100644 --- a/llvm/lib/Remarks/RemarkLinker.cpp +++ b/llvm/lib/Remarks/RemarkLinker.cpp @@ -70,8 +70,8 @@ Error RemarkLinker::link(StringRef Buffer, Format RemarkFormat) { Expected<std::unique_ptr<RemarkParser>> MaybeParser = createRemarkParserFromMeta( RemarkFormat, Buffer, - PrependPath ? std::optional<StringRef>(StringRef(*PrependPath)) - : std::optional<StringRef>()); + PrependPath ? std::make_optional<StringRef>(*PrependPath) + : std::nullopt); if (!MaybeParser) return MaybeParser.takeError(); diff --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp index 5e0b29f..46084c5 100644 --- a/llvm/lib/Support/APFloat.cpp +++ b/llvm/lib/Support/APFloat.cpp @@ -900,6 +900,30 @@ writeSignedDecimal (char *dst, int value) return dst; } +// Compute the ULP of the input using a definition from: +// Jean-Michel Muller. On the definition of ulp(x). [Research Report] RR-5504, +// LIP RR-2005-09, INRIA, LIP. 2005, pp.16. inria-00070503 +static APFloat harrisonUlp(const APFloat &X) { + const fltSemantics &Sem = X.getSemantics(); + switch (X.getCategory()) { + case APFloat::fcNaN: + return APFloat::getQNaN(Sem); + case APFloat::fcInfinity: + return APFloat::getInf(Sem); + case APFloat::fcZero: + return APFloat::getSmallest(Sem); + case APFloat::fcNormal: + break; + } + if (X.isDenormal() || X.isSmallestNormalized()) + return APFloat::getSmallest(Sem); + int Exp = ilogb(X); + if (X.getExactLog2() != INT_MIN) + Exp -= 1; + return scalbn(APFloat::getOne(Sem), Exp - (Sem.precision - 1), + APFloat::rmNearestTiesToEven); +} + namespace detail { /* Constructors. */ void IEEEFloat::initialize(const fltSemantics *ourSemantics) { @@ -5306,12 +5330,110 @@ Expected<APFloat::opStatus> DoubleAPFloat::convertFromString(StringRef S, return Ret; } +// The double-double lattice of values corresponds to numbers which obey: +// - abs(lo) <= 1/2 * ulp(hi) +// - roundTiesToEven(hi + lo) == hi +// +// nextUp must choose the smallest output > input that follows these rules. 
+// nextDown must choose the largest output < input that follows these rules.
 APFloat::opStatus DoubleAPFloat::next(bool nextDown) {
   assert(Semantics == &semPPCDoubleDouble && "Unexpected Semantics");
-  APFloat Tmp(semPPCDoubleDoubleLegacy, bitcastToAPInt());
-  auto Ret = Tmp.next(nextDown);
-  *this = DoubleAPFloat(semPPCDoubleDouble, Tmp.bitcastToAPInt());
-  return Ret;
+  // nextDown(x) = -nextUp(-x)
+  if (nextDown) {
+    changeSign();
+    APFloat::opStatus Result = next(/*nextDown=*/false);
+    changeSign();
+    return Result;
+  }
+  switch (getCategory()) {
+  case fcInfinity:
+    // nextUp(+inf) = +inf
+    // nextUp(-inf) = -getLargest()
+    if (isNegative())
+      makeLargest(true);
+    return opOK;
+
+  case fcNaN:
+    // IEEE-754R 2008 6.2 Par 2: nextUp(sNaN) = qNaN. Set Invalid flag.
+    // IEEE-754R 2008 6.2: nextUp(qNaN) = qNaN. Must be identity so we do not
+    // change the payload.
+    if (getFirst().isSignaling()) {
+      // For consistency, propagate the sign of the sNaN to the qNaN.
+      makeNaN(false, isNegative(), nullptr);
+      return opInvalidOp;
+    }
+    return opOK;
+
+  case fcZero:
+    // nextUp(pm 0) = +getSmallest()
+    makeSmallest(false);
+    return opOK;
+
+  case fcNormal:
+    break;
+  }
+
+  const APFloat &HiOld = getFirst();
+  const APFloat &LoOld = getSecond();
+
+  APFloat NextLo = LoOld;
+  NextLo.next(/*nextDown=*/false);
+
+  // We want to admit values where:
+  // 1. abs(Lo) <= ulp(Hi)/2
+  // 2. Hi == RTNE(Hi + Lo)
+  auto InLattice = [](const APFloat &Hi, const APFloat &Lo) {
+    return Hi + Lo == Hi;
+  };
+
+  // Check if (HiOld, nextUp(LoOld)) is in the lattice.
+  if (InLattice(HiOld, NextLo)) {
+    // Yes, the result is (HiOld, nextUp(LoOld)).
+    Floats[1] = std::move(NextLo);
+
+    // TODO: Because we currently rely on semPPCDoubleDoubleLegacy, our maximum
+    // value is defined to have exactly 106 bits of precision. This limitation
+    // results in semPPCDoubleDouble being unable to reach its maximum canonical
+    // value.
+    DoubleAPFloat Largest{*Semantics, uninitialized};
+    Largest.makeLargest(/*Neg=*/false);
+    if (compare(Largest) == cmpGreaterThan)
+      makeInf(/*Neg=*/false);
+
+    return opOK;
+  }
+
+  // Now we need to handle the cases where (HiOld, nextUp(LoOld)) is not the
+  // correct result. We know the new hi component will be nextUp(HiOld) but our
+  // lattice rules make it a little ambiguous what the correct NextLo must be.
+  APFloat NextHi = HiOld;
+  NextHi.next(/*nextDown=*/false);
+
+  // nextUp(getLargest()) == INFINITY
+  if (NextHi.isInfinity()) {
+    makeInf(/*Neg=*/false);
+    return opOK;
+  }
+
+  // IEEE 754-2019 5.3.1:
+  // "If x is the negative number of least magnitude in x's format, nextUp(x) is
+  // -0."
+  if (NextHi.isZero()) {
+    makeZero(/*Neg=*/true);
+    return opOK;
+  }
+
+  // abs(NextLo) must be <= ulp(NextHi)/2. We want NextLo to be as close to
+  // negative infinity as possible.
+  NextLo = neg(scalbn(harrisonUlp(NextHi), -1, rmTowardZero));
+  if (!InLattice(NextHi, NextLo))
+    // RTNE may mean that Lo must be < ulp(NextHi) / 2 so we bump NextLo.
+ NextLo.next(/*nextDown=*/false); + + Floats[0] = std::move(NextHi); + Floats[1] = std::move(NextLo); + + return opOK; } APFloat::opStatus diff --git a/llvm/lib/Support/BLAKE3/CMakeLists.txt b/llvm/lib/Support/BLAKE3/CMakeLists.txt index eae2b02..90311ae 100644 --- a/llvm/lib/Support/BLAKE3/CMakeLists.txt +++ b/llvm/lib/Support/BLAKE3/CMakeLists.txt @@ -26,7 +26,8 @@ endmacro() if (CAN_USE_ASSEMBLER) if (MSVC) check_symbol_exists(_M_X64 "" IS_X64) - if (IS_X64) + check_symbol_exists(_M_ARM64EC "" IS_ARM64EC) + if (IS_X64 AND NOT IS_ARM64EC) enable_language(ASM_MASM) set(LLVM_BLAKE3_ASM_FILES blake3_sse2_x86-64_windows_msvc.asm diff --git a/llvm/lib/Support/FileCollector.cpp b/llvm/lib/Support/FileCollector.cpp index 29436f8..edb5313 100644 --- a/llvm/lib/Support/FileCollector.cpp +++ b/llvm/lib/Support/FileCollector.cpp @@ -313,5 +313,6 @@ private: IntrusiveRefCntPtr<vfs::FileSystem> FileCollector::createCollectorVFS(IntrusiveRefCntPtr<vfs::FileSystem> BaseFS, std::shared_ptr<FileCollector> Collector) { - return new FileCollectorFileSystem(std::move(BaseFS), std::move(Collector)); + return makeIntrusiveRefCnt<FileCollectorFileSystem>(std::move(BaseFS), + std::move(Collector)); } diff --git a/llvm/lib/Support/VirtualFileSystem.cpp b/llvm/lib/Support/VirtualFileSystem.cpp index e489282..5d42488 100644 --- a/llvm/lib/Support/VirtualFileSystem.cpp +++ b/llvm/lib/Support/VirtualFileSystem.cpp @@ -397,7 +397,8 @@ void RealFileSystem::printImpl(raw_ostream &OS, PrintType Type, } IntrusiveRefCntPtr<FileSystem> vfs::getRealFileSystem() { - static IntrusiveRefCntPtr<FileSystem> FS(new RealFileSystem(true)); + static IntrusiveRefCntPtr<FileSystem> FS = + makeIntrusiveRefCnt<RealFileSystem>(true); return FS; } @@ -2217,9 +2218,9 @@ RedirectingFileSystem::create(std::unique_ptr<MemoryBuffer> Buffer, std::unique_ptr<RedirectingFileSystem> RedirectingFileSystem::create( ArrayRef<std::pair<std::string, std::string>> RemappedFiles, - bool UseExternalNames, FileSystem &ExternalFS) { + bool UseExternalNames, llvm::IntrusiveRefCntPtr<FileSystem> ExternalFS) { std::unique_ptr<RedirectingFileSystem> FS( - new RedirectingFileSystem(&ExternalFS)); + new RedirectingFileSystem(ExternalFS)); FS->UseExternalNames = UseExternalNames; StringMap<RedirectingFileSystem::Entry *> Entries; @@ -2228,7 +2229,7 @@ std::unique_ptr<RedirectingFileSystem> RedirectingFileSystem::create( SmallString<128> From = StringRef(Mapping.first); SmallString<128> To = StringRef(Mapping.second); { - auto EC = ExternalFS.makeAbsolute(From); + auto EC = ExternalFS->makeAbsolute(From); (void)EC; assert(!EC && "Could not make absolute path"); } @@ -2250,7 +2251,7 @@ std::unique_ptr<RedirectingFileSystem> RedirectingFileSystem::create( } assert(Parent && "File without a directory?"); { - auto EC = ExternalFS.makeAbsolute(To); + auto EC = ExternalFS->makeAbsolute(To); (void)EC; assert(!EC && "Could not make absolute path"); } diff --git a/llvm/lib/Support/Windows/Threading.inc b/llvm/lib/Support/Windows/Threading.inc index 8dd7c88..b11f216 100644 --- a/llvm/lib/Support/Windows/Threading.inc +++ b/llvm/lib/Support/Windows/Threading.inc @@ -136,6 +136,7 @@ HMODULE loadSystemModuleSecure(LPCWSTR lpModuleName) { } // namespace llvm::sys::windows SetThreadPriorityResult llvm::set_thread_priority(ThreadPriority Priority) { +#ifdef THREAD_POWER_THROTTLING_CURRENT_VERSION HMODULE kernelM = llvm::sys::windows::loadSystemModuleSecure(L"kernel32.dll"); if (kernelM) { // SetThreadInformation is only available on Windows 8 and later. 
Since we @@ -166,6 +167,7 @@ SetThreadPriorityResult llvm::set_thread_priority(ThreadPriority Priority) { : 0); } } +#endif // https://docs.microsoft.com/en-us/windows/desktop/api/processthreadsapi/nf-processthreadsapi-setthreadpriority // Begin background processing mode. The system lowers the resource scheduling diff --git a/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp b/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp index 509cbb0..e8d3161 100644 --- a/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp @@ -813,8 +813,8 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) { } } - if (!F.hasFnAttribute(Attribute::HybridPatchable) || F.isDeclaration() || - F.hasLocalLinkage() || + if (!F.hasFnAttribute(Attribute::HybridPatchable) || + F.isDeclarationForLinker() || F.hasLocalLinkage() || F.getName().ends_with(HybridPatchableTargetSuffix)) continue; @@ -857,7 +857,7 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) { SetVector<GlobalValue *> DirectCalledFns; for (Function &F : Mod) - if (!F.isDeclaration() && + if (!F.isDeclarationForLinker() && F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native && F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64) processFunction(F, DirectCalledFns, FnsMap); @@ -869,7 +869,8 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) { }; SmallVector<ThunkInfo> ThunkMapping; for (Function &F : Mod) { - if (!F.isDeclaration() && (!F.hasLocalLinkage() || F.hasAddressTaken()) && + if (!F.isDeclarationForLinker() && + (!F.hasLocalLinkage() || F.hasAddressTaken()) && F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native && F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64) { if (!F.hasComdat()) @@ -959,7 +960,7 @@ bool AArch64Arm64ECCallLowering::processFunction( // unprototyped functions in C) if (Function *F = CB->getCalledFunction()) { if (!LowerDirectToIndirect || F->hasLocalLinkage() || - F->isIntrinsic() || !F->isDeclaration()) + F->isIntrinsic() || !F->isDeclarationForLinker()) continue; DirectCalledFns.insert(F); diff --git a/llvm/lib/Target/AArch64/AArch64Combine.td b/llvm/lib/Target/AArch64/AArch64Combine.td index ca09598..99f0af5 100644 --- a/llvm/lib/Target/AArch64/AArch64Combine.td +++ b/llvm/lib/Target/AArch64/AArch64Combine.td @@ -39,8 +39,8 @@ let Predicates = [HasDotProd] in { def ext_addv_to_udot_addv : GICombineRule< (defs root:$root, ext_addv_to_udot_addv_matchinfo:$matchinfo), (match (wip_match_opcode G_VECREDUCE_ADD):$root, - [{ return matchExtAddvToUdotAddv(*${root}, MRI, STI, ${matchinfo}); }]), - (apply [{ applyExtAddvToUdotAddv(*${root}, MRI, B, Observer, STI, ${matchinfo}); }]) + [{ return matchExtAddvToDotAddv(*${root}, MRI, STI, ${matchinfo}); }]), + (apply [{ applyExtAddvToDotAddv(*${root}, MRI, B, Observer, STI, ${matchinfo}); }]) >; } @@ -62,8 +62,10 @@ class push_opcode_through_ext<Instruction opcode, Instruction extOpcode> : GICom def push_sub_through_zext : push_opcode_through_ext<G_SUB, G_ZEXT>; def push_add_through_zext : push_opcode_through_ext<G_ADD, G_ZEXT>; +def push_mul_through_zext : push_opcode_through_ext<G_MUL, G_ZEXT>; def push_sub_through_sext : push_opcode_through_ext<G_SUB, G_SEXT>; def push_add_through_sext : push_opcode_through_ext<G_ADD, G_SEXT>; +def push_mul_through_sext : push_opcode_through_ext<G_MUL, G_SEXT>; def AArch64PreLegalizerCombiner: GICombiner< "AArch64PreLegalizerCombinerImpl", [all_combines, @@ -75,8 +77,10 @@ def AArch64PreLegalizerCombiner: GICombiner< ext_uaddv_to_uaddlv, 
push_sub_through_zext, push_add_through_zext, + push_mul_through_zext, push_sub_through_sext, - push_add_through_sext]> { + push_add_through_sext, + push_mul_through_sext]> { let CombineAllMethodName = "tryCombineAllImpl"; } diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 77c0773..2b6ea86 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -162,10 +162,10 @@ static cl::opt<bool> UseFEATCPACodegen( cl::init(false)); /// Value type used for condition codes. -static const MVT MVT_CC = MVT::i32; +constexpr MVT CondCodeVT = MVT::i32; /// Value type used for NZCV flags. -static constexpr MVT FlagsVT = MVT::i32; +constexpr MVT FlagsVT = MVT::i32; static const MCPhysReg GPRArgRegs[] = {AArch64::X0, AArch64::X1, AArch64::X2, AArch64::X3, AArch64::X4, AArch64::X5, @@ -3472,6 +3472,12 @@ static void changeVectorFPCCToAArch64CC(ISD::CondCode CC, } } +/// Like SelectionDAG::getCondCode(), but for AArch64 condition codes. +static SDValue getCondCode(SelectionDAG &DAG, AArch64CC::CondCode CC) { + // TODO: Should be TargetConstant (need to s/imm/timm in patterns). + return DAG.getConstant(CC, SDLoc(), CondCodeVT); +} + static bool isLegalArithImmed(uint64_t C) { // Matches AArch64DAGToDAGISel::SelectArithImmed(). bool IsLegal = (C >> 12 == 0) || ((C & 0xFFFULL) == 0 && C >> 24 == 0); @@ -3678,7 +3684,7 @@ static SDValue emitConditionalComparison(SDValue LHS, SDValue RHS, if (Opcode == 0) Opcode = AArch64ISD::CCMP; - SDValue Condition = DAG.getConstant(Predicate, DL, MVT_CC); + SDValue Condition = getCondCode(DAG, Predicate); AArch64CC::CondCode InvOutCC = AArch64CC::getInvertedCondCode(OutCC); unsigned NZCV = AArch64CC::getNZCVToSatisfyCondCode(InvOutCC); SDValue NZCVOp = DAG.getConstant(NZCV, DL, MVT::i32); @@ -4075,7 +4081,7 @@ static SDValue getAArch64Cmp(SDValue LHS, SDValue RHS, ISD::CondCode CC, Cmp = emitComparison(LHS, RHS, CC, DL, DAG); AArch64CC = changeIntCCToAArch64CC(CC); } - AArch64cc = DAG.getConstant(AArch64CC, DL, MVT_CC); + AArch64cc = getCondCode(DAG, AArch64CC); return Cmp; } @@ -4195,7 +4201,7 @@ SDValue AArch64TargetLowering::LowerXOR(SDValue Op, SelectionDAG &DAG) const { AArch64CC::CondCode CC; SDValue Value, Overflow; std::tie(Value, Overflow) = getAArch64XALUOOp(CC, Sel.getValue(0), DAG); - SDValue CCVal = DAG.getConstant(getInvertedCondCode(CC), DL, MVT::i32); + SDValue CCVal = getCondCode(DAG, getInvertedCondCode(CC)); return DAG.getNode(AArch64ISD::CSEL, DL, Op.getValueType(), TVal, FVal, CCVal, Overflow); } @@ -4274,8 +4280,8 @@ static SDValue carryFlagToValue(SDValue Glue, EVT VT, SelectionDAG &DAG, SDLoc DL(Glue); SDValue Zero = DAG.getConstant(0, DL, VT); SDValue One = DAG.getConstant(1, DL, VT); - unsigned Cond = Invert ? AArch64CC::LO : AArch64CC::HS; - SDValue CC = DAG.getConstant(Cond, DL, MVT::i32); + AArch64CC::CondCode Cond = Invert ? 
AArch64CC::LO : AArch64CC::HS; + SDValue CC = getCondCode(DAG, Cond); return DAG.getNode(AArch64ISD::CSEL, DL, VT, One, Zero, CC, Glue); } @@ -4285,7 +4291,7 @@ static SDValue overflowFlagToValue(SDValue Glue, EVT VT, SelectionDAG &DAG) { SDLoc DL(Glue); SDValue Zero = DAG.getConstant(0, DL, VT); SDValue One = DAG.getConstant(1, DL, VT); - SDValue CC = DAG.getConstant(AArch64CC::VS, DL, MVT::i32); + SDValue CC = getCondCode(DAG, AArch64CC::VS); return DAG.getNode(AArch64ISD::CSEL, DL, VT, One, Zero, CC, Glue); } @@ -4334,7 +4340,7 @@ static SDValue LowerXALUO(SDValue Op, SelectionDAG &DAG) { // We use an inverted condition, because the conditional select is inverted // too. This will allow it to be selected to a single instruction: // CSINC Wd, WZR, WZR, invert(cond). - SDValue CCVal = DAG.getConstant(getInvertedCondCode(CC), DL, MVT::i32); + SDValue CCVal = getCondCode(DAG, getInvertedCondCode(CC)); Overflow = DAG.getNode(AArch64ISD::CSEL, DL, MVT::i32, FVal, TVal, CCVal, Overflow); @@ -7124,8 +7130,7 @@ SDValue AArch64TargetLowering::LowerABS(SDValue Op, SelectionDAG &DAG) const { SDValue Cmp = DAG.getNode(AArch64ISD::SUBS, DL, DAG.getVTList(VT, FlagsVT), Op.getOperand(0), DAG.getConstant(0, DL, VT)); return DAG.getNode(AArch64ISD::CSEL, DL, VT, Op.getOperand(0), Neg, - DAG.getConstant(AArch64CC::PL, DL, MVT::i32), - Cmp.getValue(1)); + getCondCode(DAG, AArch64CC::PL), Cmp.getValue(1)); } static SDValue LowerBRCOND(SDValue Op, SelectionDAG &DAG) { @@ -7136,7 +7141,7 @@ static SDValue LowerBRCOND(SDValue Op, SelectionDAG &DAG) { AArch64CC::CondCode CC; if (SDValue Cmp = emitConjunction(DAG, Cond, CC)) { SDLoc DL(Op); - SDValue CCVal = DAG.getConstant(CC, DL, MVT::i32); + SDValue CCVal = getCondCode(DAG, CC); return DAG.getNode(AArch64ISD::BRCOND, DL, MVT::Other, Chain, Dest, CCVal, Cmp); } @@ -8952,6 +8957,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, bool &IsTailCall = CLI.IsTailCall; CallingConv::ID &CallConv = CLI.CallConv; bool IsVarArg = CLI.IsVarArg; + const CallBase *CB = CLI.CB; MachineFunction &MF = DAG.getMachineFunction(); MachineFunction::CallSiteInfo CSInfo; @@ -8991,6 +8997,10 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, *DAG.getContext()); RetCCInfo.AnalyzeCallResult(Ins, RetCC); + // Set type id for call site info. + if (MF.getTarget().Options.EmitCallGraphSection && CB && CB->isIndirectCall()) + CSInfo = MachineFunction::CallSiteInfo(*CB); + // Check callee args/returns for SVE registers and set calling convention // accordingly. 
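Several hunks above materialise a carry or overflow bit with the idiom the comments describe: CSINC Wd, WZR, WZR, invert(cond). CSINC returns its first operand when the condition holds and its second operand plus one otherwise, so with both operands being the zero register the inverted condition yields exactly cond ? 1 : 0. A scalar model of that semantics (pure illustration, not compiler code); the LowerCall hunk resumes below:

#include <cstdint>
#include <cstdio>

// Semantics of AArch64 CSINC: cond ? a : b + 1.
uint32_t csinc(bool Cond, uint32_t A, uint32_t B) {
  return Cond ? A : B + 1;
}

int main() {
  for (bool Cond : {false, true}) {
    // csinc Wd, WZR, WZR, invert(cond)  ==>  cond ? 1 : 0
    uint32_t Boolean = csinc(!Cond, 0, 0);
    std::printf("cond=%d -> %u\n", (int)Cond, Boolean);
  }
}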
if (CallConv == CallingConv::C || CallConv == CallingConv::Fast) { @@ -10570,7 +10580,7 @@ SDValue AArch64TargetLowering::LowerBR_CC(SDValue Op, SelectionDAG &DAG) const { if (CC == ISD::SETNE) OFCC = getInvertedCondCode(OFCC); - SDValue CCVal = DAG.getConstant(OFCC, DL, MVT::i32); + SDValue CCVal = getCondCode(DAG, OFCC); return DAG.getNode(AArch64ISD::BRCOND, DL, MVT::Other, Chain, Dest, CCVal, Overflow); @@ -10643,7 +10653,7 @@ SDValue AArch64TargetLowering::LowerBR_CC(SDValue Op, SelectionDAG &DAG) const { AArch64CC::isValidCBCond(changeIntCCToAArch64CC(CC)) && ProduceNonFlagSettingCondBr) { SDValue Cond = - DAG.getTargetConstant(changeIntCCToAArch64CC(CC), DL, MVT::i32); + DAG.getTargetConstant(changeIntCCToAArch64CC(CC), DL, CondCodeVT); return DAG.getNode(AArch64ISD::CB, DL, MVT::Other, Chain, Cond, LHS, RHS, Dest); } @@ -10662,11 +10672,11 @@ SDValue AArch64TargetLowering::LowerBR_CC(SDValue Op, SelectionDAG &DAG) const { SDValue Cmp = emitComparison(LHS, RHS, CC, DL, DAG); AArch64CC::CondCode CC1, CC2; changeFPCCToAArch64CC(CC, CC1, CC2); - SDValue CC1Val = DAG.getConstant(CC1, DL, MVT::i32); + SDValue CC1Val = getCondCode(DAG, CC1); SDValue BR1 = DAG.getNode(AArch64ISD::BRCOND, DL, MVT::Other, Chain, Dest, CC1Val, Cmp); if (CC2 != AArch64CC::AL) { - SDValue CC2Val = DAG.getConstant(CC2, DL, MVT::i32); + SDValue CC2Val = getCondCode(DAG, CC2); return DAG.getNode(AArch64ISD::BRCOND, DL, MVT::Other, BR1, Dest, CC2Val, Cmp); } @@ -11155,7 +11165,7 @@ SDValue AArch64TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const { if (CC2 == AArch64CC::AL) { changeFPCCToAArch64CC(ISD::getSetCCInverse(CC, LHS.getValueType()), CC1, CC2); - SDValue CC1Val = DAG.getConstant(CC1, DL, MVT::i32); + SDValue CC1Val = getCondCode(DAG, CC1); // Note that we inverted the condition above, so we reverse the order of // the true and false operands here. This will allow the setcc to be @@ -11168,11 +11178,11 @@ SDValue AArch64TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const { // of the first as the RHS. We're effectively OR'ing the two CC's together. // FIXME: It would be nice if we could match the two CSELs to two CSINCs. - SDValue CC1Val = DAG.getConstant(CC1, DL, MVT::i32); + SDValue CC1Val = getCondCode(DAG, CC1); SDValue CS1 = DAG.getNode(AArch64ISD::CSEL, DL, VT, TVal, FVal, CC1Val, Cmp); - SDValue CC2Val = DAG.getConstant(CC2, DL, MVT::i32); + SDValue CC2Val = getCondCode(DAG, CC2); Res = DAG.getNode(AArch64ISD::CSEL, DL, VT, TVal, CS1, CC2Val, Cmp); } return IsStrict ? DAG.getMergeValues({Res, Cmp.getValue(1)}, DL) : Res; @@ -11200,8 +11210,7 @@ SDValue AArch64TargetLowering::LowerSETCCCARRY(SDValue Op, ISD::CondCode Cond = cast<CondCodeSDNode>(Op.getOperand(3))->get(); ISD::CondCode CondInv = ISD::getSetCCInverse(Cond, VT); - SDValue CCVal = - DAG.getConstant(changeIntCCToAArch64CC(CondInv), DL, MVT::i32); + SDValue CCVal = getCondCode(DAG, changeIntCCToAArch64CC(CondInv)); // Inputs are swapped because the condition is inverted. This will allow // matching with a single CSINC instruction. return DAG.getNode(AArch64ISD::CSEL, DL, OpVT, FVal, TVal, CCVal, @@ -11355,18 +11364,6 @@ SDValue AArch64TargetLowering::LowerSELECT_CC( ConstantSDNode *CFVal = dyn_cast<ConstantSDNode>(FVal); ConstantSDNode *CTVal = dyn_cast<ConstantSDNode>(TVal); ConstantSDNode *RHSC = dyn_cast<ConstantSDNode>(RHS); - // Check for sign pattern (SELECT_CC setgt, iN lhs, -1, 1, -1) and transform - // into (OR (ASR lhs, N-1), 1), which requires less instructions for the - // supported types. 
- if (CC == ISD::SETGT && RHSC && RHSC->isAllOnes() && CTVal && CFVal && - CTVal->isOne() && CFVal->isAllOnes() && - LHS.getValueType() == TVal.getValueType()) { - EVT VT = LHS.getValueType(); - SDValue Shift = - DAG.getNode(ISD::SRA, DL, VT, LHS, - DAG.getConstant(VT.getSizeInBits() - 1, DL, VT)); - return DAG.getNode(ISD::OR, DL, VT, Shift, DAG.getConstant(1, DL, VT)); - } // Check for SMAX(lhs, 0) and SMIN(lhs, 0) patterns. // (SELECT_CC setgt, lhs, 0, lhs, 0) -> (BIC lhs, (SRA lhs, typesize-1)) @@ -11386,6 +11383,22 @@ SDValue AArch64TargetLowering::LowerSELECT_CC( return DAG.getNode(ISD::AND, DL, VT, LHS, Shift); } + // Canonicalise absolute difference patterns: + // select_cc lhs, rhs, sub(lhs, rhs), sub(rhs, lhs), cc -> + // select_cc lhs, rhs, sub(lhs, rhs), neg(sub(lhs, rhs)), cc + // + // select_cc lhs, rhs, sub(rhs, lhs), sub(lhs, rhs), cc -> + // select_cc lhs, rhs, neg(sub(lhs, rhs)), sub(lhs, rhs), cc + // The second forms can be matched into subs+cneg. + if (TVal.getOpcode() == ISD::SUB && FVal.getOpcode() == ISD::SUB) { + if (TVal.getOperand(0) == LHS && TVal.getOperand(1) == RHS && + FVal.getOperand(0) == RHS && FVal.getOperand(1) == LHS) + FVal = DAG.getNegative(TVal, DL, TVal.getValueType()); + else if (TVal.getOperand(0) == RHS && TVal.getOperand(1) == LHS && + FVal.getOperand(0) == LHS && FVal.getOperand(1) == RHS) + TVal = DAG.getNegative(FVal, DL, FVal.getValueType()); + } + unsigned Opcode = AArch64ISD::CSEL; // If both the TVal and the FVal are constants, see if we can swap them in @@ -11556,13 +11569,13 @@ SDValue AArch64TargetLowering::LowerSELECT_CC( } // Emit first, and possibly only, CSEL. - SDValue CC1Val = DAG.getConstant(CC1, DL, MVT::i32); + SDValue CC1Val = getCondCode(DAG, CC1); SDValue CS1 = DAG.getNode(AArch64ISD::CSEL, DL, VT, TVal, FVal, CC1Val, Cmp); // If we need a second CSEL, emit it, using the output of the first as the // RHS. We're effectively OR'ing the two CC's together. if (CC2 != AArch64CC::AL) { - SDValue CC2Val = DAG.getConstant(CC2, DL, MVT::i32); + SDValue CC2Val = getCondCode(DAG, CC2); return DAG.getNode(AArch64ISD::CSEL, DL, VT, TVal, CS1, CC2Val, Cmp); } @@ -11664,7 +11677,7 @@ SDValue AArch64TargetLowering::LowerSELECT(SDValue Op, AArch64CC::CondCode OFCC; SDValue Value, Overflow; std::tie(Value, Overflow) = getAArch64XALUOOp(OFCC, CCVal.getValue(0), DAG); - SDValue CCVal = DAG.getConstant(OFCC, DL, MVT::i32); + SDValue CCVal = getCondCode(DAG, OFCC); return DAG.getNode(AArch64ISD::CSEL, DL, Op.getValueType(), TVal, FVal, CCVal, Overflow); @@ -12504,10 +12517,10 @@ static AArch64CC::CondCode parseConstraintCode(llvm::StringRef Constraint) { /// WZR, invert(<cond>)'. static SDValue getSETCC(AArch64CC::CondCode CC, SDValue NZCV, const SDLoc &DL, SelectionDAG &DAG) { - return DAG.getNode( - AArch64ISD::CSINC, DL, MVT::i32, DAG.getConstant(0, DL, MVT::i32), - DAG.getConstant(0, DL, MVT::i32), - DAG.getConstant(getInvertedCondCode(CC), DL, MVT::i32), NZCV); + return DAG.getNode(AArch64ISD::CSINC, DL, MVT::i32, + DAG.getConstant(0, DL, MVT::i32), + DAG.getConstant(0, DL, MVT::i32), + getCondCode(DAG, getInvertedCondCode(CC)), NZCV); } // Lower @cc flag output via getSETCC. 
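The LowerSELECT_CC hunk above canonicalises the two absolute-difference forms so that one select arm is literally the negation of the other; the backend can then match a single subtraction plus a conditional negate (subs + cneg). A scalar model of the shape being produced, with absDiff as an illustrative name:

#include <cstdint>
#include <cstdio>

// Canonical absolute-difference form: compute one difference, then select
// between it and its negation, mirroring the select_cc rewrite above.
int32_t absDiff(int32_t Lhs, int32_t Rhs) {
  int32_t D = Lhs - Rhs;     // subs
  return Lhs > Rhs ? D : -D; // cneg on the inverted condition
}

int main() {
  std::printf("%d %d\n", absDiff(7, 3), absDiff(3, 7)); // 4 4
}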
@@ -18678,7 +18691,7 @@ AArch64TargetLowering::BuildSREMPow2(SDNode *N, const APInt &Divisor, Created.push_back(Cmp.getNode()); Created.push_back(And.getNode()); } else { - SDValue CCVal = DAG.getConstant(AArch64CC::MI, DL, MVT_CC); + SDValue CCVal = getCondCode(DAG, AArch64CC::MI); SDVTList VTs = DAG.getVTList(VT, FlagsVT); SDValue Negs = DAG.getNode(AArch64ISD::SUBS, DL, VTs, Zero, N0); @@ -19550,11 +19563,11 @@ static SDValue performANDORCSELCombine(SDNode *N, SelectionDAG &DAG) { if (N->getOpcode() == ISD::AND) { AArch64CC::CondCode InvCC0 = AArch64CC::getInvertedCondCode(CC0); - Condition = DAG.getConstant(InvCC0, DL, MVT_CC); + Condition = getCondCode(DAG, InvCC0); NZCV = AArch64CC::getNZCVToSatisfyCondCode(CC1); } else { AArch64CC::CondCode InvCC1 = AArch64CC::getInvertedCondCode(CC1); - Condition = DAG.getConstant(CC0, DL, MVT_CC); + Condition = getCondCode(DAG, CC0); NZCV = AArch64CC::getNZCVToSatisfyCondCode(InvCC1); } @@ -19575,8 +19588,7 @@ static SDValue performANDORCSELCombine(SDNode *N, SelectionDAG &DAG) { Cmp1.getOperand(1), NZCVOp, Condition, Cmp0); } return DAG.getNode(AArch64ISD::CSEL, DL, VT, CSel0.getOperand(0), - CSel0.getOperand(1), DAG.getConstant(CC1, DL, MVT::i32), - CCmp); + CSel0.getOperand(1), getCondCode(DAG, CC1), CCmp); } static SDValue performORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, @@ -19781,7 +19793,7 @@ static SDValue performANDSETCCCombine(SDNode *N, SDLoc DL(N); return DAG.getNode(AArch64ISD::CSINC, DL, VT, DAG.getConstant(0, DL, VT), DAG.getConstant(0, DL, VT), - DAG.getConstant(InvertedCC, DL, MVT::i32), Cmp); + getCondCode(DAG, InvertedCC), Cmp); } } return SDValue(); @@ -20772,7 +20784,7 @@ static SDValue performAddCSelIntoCSinc(SDNode *N, SelectionDAG &DAG) { "Unexpected constant value"); SDValue NewNode = DAG.getNode(ISD::ADD, DL, VT, RHS, SDValue(CTVal, 0)); - SDValue CCVal = DAG.getConstant(AArch64CC, DL, MVT::i32); + SDValue CCVal = getCondCode(DAG, AArch64CC); SDValue Cmp = LHS.getOperand(3); return DAG.getNode(AArch64ISD::CSINC, DL, VT, NewNode, RHS, CCVal, Cmp); @@ -20958,7 +20970,7 @@ static SDValue foldADCToCINC(SDNode *N, SelectionDAG &DAG) { SDLoc DL(N); // (CINC x cc cond) <=> (CSINC x x !cc cond) - SDValue CC = DAG.getConstant(AArch64CC::LO, DL, MVT::i32); + SDValue CC = getCondCode(DAG, AArch64CC::LO); return DAG.getNode(AArch64ISD::CSINC, DL, VT, LHS, LHS, CC, Cond); } @@ -22031,7 +22043,7 @@ static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op, // Convert CC to integer based on requested condition. // NOTE: Cond is inverted to promote CSEL's removal when it feeds a compare. - SDValue CC = DAG.getConstant(getInvertedCondCode(Cond), DL, MVT::i32); + SDValue CC = getCondCode(DAG, getInvertedCondCode(Cond)); SDValue Res = DAG.getNode(AArch64ISD::CSEL, DL, OutVT, FVal, TVal, CC, Test); return DAG.getZExtOrTrunc(Res, DL, VT); } @@ -24112,6 +24124,60 @@ static SDValue combineBoolVectorAndTruncateStore(SelectionDAG &DAG, Store->getMemOperand()); } +// Combine store (fp_to_int X) to use vector semantics around the conversion +// when NEON is available. This allows us to store the in-vector result directly +// without transferring the result into a GPR in the process. +static SDValue combineStoreValueFPToInt(StoreSDNode *ST, + TargetLowering::DAGCombinerInfo &DCI, + SelectionDAG &DAG, + const AArch64Subtarget *Subtarget) { + // Limit to post-legalization in order to avoid peeling truncating stores. 
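The combineStoreValueFPToInt body that follows keeps a scalar float-to-int conversion on the SIMD side when its only user is a store, avoiding the transfer of the converted value into a general-purpose register. Roughly the intent, sketched with ACLE NEON intrinsics (AArch64 only; illustrative of the idea, since the combine itself rewrites SelectionDAG nodes rather than source code):

#include <arm_neon.h>
#include <cstdint>

// Scalar path: fcvtzs into a GPR, then a GPR store.
void storeViaGPR(int32_t *P, float X) { *P = (int32_t)X; }

// Vector path: convert in lane 0 of a SIMD register and store that lane,
// skipping the fmov from the FP/SIMD file to a GPR.
void storeViaSIMD(int32_t *P, float X) {
  float32x4_t V = vdupq_n_f32(X); // scalar_to_vector
  int32x4_t C = vcvtq_s32_f32(V); // fp_to_sint on the vector unit
  vst1q_lane_s32(P, C, 0);        // store lane 0 directly
}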
+  if (DCI.isBeforeLegalize())
+    return SDValue();
+  if (!Subtarget->isNeonAvailable())
+    return SDValue();
+  // Source operand is already a vector.
+  SDValue Value = ST->getValue();
+  if (Value.getValueType().isVector())
+    return SDValue();
+
+  // Look through potential assertions.
+  while (Value->isAssert())
+    Value = Value.getOperand(0);
+
+  if (Value.getOpcode() != ISD::FP_TO_SINT &&
+      Value.getOpcode() != ISD::FP_TO_UINT)
+    return SDValue();
+  if (!Value->hasOneUse())
+    return SDValue();
+
+  SDValue FPSrc = Value.getOperand(0);
+  EVT SrcVT = FPSrc.getValueType();
+  if (SrcVT != MVT::f32 && SrcVT != MVT::f64)
+    return SDValue();
+
+  // No support for assignments such as i64 = fp_to_sint f32
+  EVT VT = Value.getSimpleValueType();
+  if (VT != SrcVT.changeTypeToInteger())
+    return SDValue();
+
+  // Create a 128-bit element vector to avoid widening. The floating point
+  // conversion is transformed into a single element conversion via a pattern.
+  unsigned NumElements = 128 / SrcVT.getFixedSizeInBits();
+  EVT VecSrcVT = EVT::getVectorVT(*DAG.getContext(), SrcVT, NumElements);
+  EVT VecDstVT = VecSrcVT.changeTypeToInteger();
+  SDLoc DL(ST);
+  SDValue VecFP = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VecSrcVT, FPSrc);
+  SDValue VecConv = DAG.getNode(Value.getOpcode(), DL, VecDstVT, VecFP);
+
+  SDValue Zero = DAG.getVectorIdxConstant(0, DL);
+  SDValue Extracted =
+      DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, VecConv, Zero);
+
+  DCI.CombineTo(ST->getValue().getNode(), Extracted);
+  return SDValue(ST, 0);
+}
+
 bool isHalvingTruncateOfLegalScalableType(EVT SrcVT, EVT DstVT) {
   return (SrcVT == MVT::nxv8i16 && DstVT == MVT::nxv8i8) ||
          (SrcVT == MVT::nxv4i32 && DstVT == MVT::nxv4i16) ||
@@ -24194,6 +24260,9 @@ static SDValue performSTORECombine(SDNode *N,
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
   SDLoc DL(ST);
 
+  if (SDValue Res = combineStoreValueFPToInt(ST, DCI, DAG, Subtarget))
+    return Res;
+
   auto hasValidElementTypeForFPTruncStore = [](EVT VT) {
     EVT EltVT = VT.getVectorElementType();
     return EltVT == MVT::f32 || EltVT == MVT::f64;
@@ -25015,10 +25084,9 @@ static SDValue performBRCONDCombine(SDNode *N,
   auto CSelCC = getCSETCondCode(CSel);
   if (CSelCC) {
     SDLoc DL(N);
-    return DAG.getNode(
-        N->getOpcode(), DL, N->getVTList(), Chain, Dest,
-        DAG.getConstant(getInvertedCondCode(*CSelCC), DL, MVT::i32),
-        CSel.getOperand(3));
+    return DAG.getNode(N->getOpcode(), DL, N->getVTList(), Chain, Dest,
+                       getCondCode(DAG, getInvertedCondCode(*CSelCC)),
+                       CSel.getOperand(3));
   }
 }
 
@@ -25159,7 +25227,7 @@ static SDValue foldCSELOfCSEL(SDNode *Op, SelectionDAG &DAG) {
 
   SDLoc DL(Op);
   EVT VT = Op->getValueType(0);
-  SDValue CCValue = DAG.getConstant(CC, DL, MVT::i32);
+  SDValue CCValue = getCondCode(DAG, CC);
   return DAG.getNode(AArch64ISD::CSEL, DL, VT, L, R, CCValue, Cond);
 }
 
@@ -25236,8 +25304,7 @@ static SDValue reassociateCSELOperandsForCSE(SDNode *N, SelectionDAG &DAG) {
     SDValue TValReassoc = Reassociate(TReassocOp, 0);
     SDValue FValReassoc = Reassociate(FReassocOp, 1);
     return DAG.getNode(AArch64ISD::CSEL, SDLoc(N), VT, TValReassoc, FValReassoc,
-                       DAG.getConstant(NewCC, SDLoc(N->getOperand(2)), MVT_CC),
-                       NewCmp.getValue(1));
+                       getCondCode(DAG, NewCC), NewCmp.getValue(1));
   };
 
   auto CC = static_cast<AArch64CC::CondCode>(N->getConstantOperandVal(2));
@@ -25378,8 +25445,7 @@ static SDValue performCSELCombine(SDNode *N,
       SDValue Sub = DAG.getNode(AArch64ISD::SUBS, DL, Cond->getVTList(),
                                 Cond.getOperand(1), Cond.getOperand(0));
       return DAG.getNode(AArch64ISD::CSEL, DL, N->getVTList(), N->getOperand(0),
-
N->getOperand(1), - DAG.getConstant(NewCond, DL, MVT::i32), + N->getOperand(1), getCondCode(DAG, NewCond), Sub.getValue(1)); } } @@ -25479,10 +25545,9 @@ static SDValue performSETCCCombine(SDNode *N, auto NewCond = getInvertedCondCode(OldCond); // csel 0, 1, !cond, X - SDValue CSEL = - DAG.getNode(AArch64ISD::CSEL, DL, LHS.getValueType(), LHS.getOperand(0), - LHS.getOperand(1), DAG.getConstant(NewCond, DL, MVT::i32), - LHS.getOperand(3)); + SDValue CSEL = DAG.getNode(AArch64ISD::CSEL, DL, LHS.getValueType(), + LHS.getOperand(0), LHS.getOperand(1), + getCondCode(DAG, NewCond), LHS.getOperand(3)); return DAG.getZExtOrTrunc(CSEL, DL, VT); } @@ -25552,8 +25617,7 @@ static SDValue performFlagSettingCombine(SDNode *N, // If the flag result isn't used, convert back to a generic opcode. if (!N->hasAnyUseOfValue(1)) { SDValue Res = DCI.DAG.getNode(GenericOpcode, DL, VT, N->ops()); - return DCI.DAG.getMergeValues({Res, DCI.DAG.getConstant(0, DL, MVT::i32)}, - DL); + return DCI.CombineTo(N, Res, SDValue(N, 1)); } // Combine identical generic nodes into this node, re-using the result. @@ -26935,10 +26999,10 @@ static SDValue performRNDRCombine(SDNode *N, SelectionDAG &DAG) { SDValue A = DAG.getNode( AArch64ISD::MRS, DL, DAG.getVTList(MVT::i64, FlagsVT, MVT::Other), N->getOperand(0), DAG.getConstant(Register, DL, MVT::i32)); - SDValue B = DAG.getNode( - AArch64ISD::CSINC, DL, MVT::i32, DAG.getConstant(0, DL, MVT::i32), - DAG.getConstant(0, DL, MVT::i32), - DAG.getConstant(AArch64CC::NE, DL, MVT::i32), A.getValue(1)); + SDValue B = DAG.getNode(AArch64ISD::CSINC, DL, MVT::i32, + DAG.getConstant(0, DL, MVT::i32), + DAG.getConstant(0, DL, MVT::i32), + getCondCode(DAG, AArch64CC::NE), A.getValue(1)); return DAG.getMergeValues( {A, DAG.getZExtOrTrunc(B, DL, MVT::i1), A.getValue(2)}, DL); } diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td index 07cacfa..ac31236 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td @@ -448,8 +448,13 @@ def SDTBinaryArithWithFlagsInOut : SDTypeProfile<2, 3, SDTCisVT<1, FlagsVT>, SDTCisVT<4, FlagsVT>]>; +// Value type used for condition codes. +// Should be kept in sync with its C++ counterpart. 
+defvar CondCodeVT = i32; + def SDT_AArch64Brcond : SDTypeProfile<0, 3, - [SDTCisVT<0, OtherVT>, SDTCisVT<1, i32>, + [SDTCisVT<0, OtherVT>, + SDTCisVT<1, CondCodeVT>, SDTCisVT<2, FlagsVT>]>; def SDT_AArch64cbz : SDTypeProfile<0, 2, [SDTCisInt<0>, SDTCisVT<1, OtherVT>]>; def SDT_AArch64tbz : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisInt<1>, @@ -458,22 +463,22 @@ def SDT_AArch64tbz : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisInt<1>, def SDT_AArch64CSel : SDTypeProfile<1, 4, [SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, - SDTCisInt<3>, + SDTCisVT<3, CondCodeVT>, SDTCisVT<4, FlagsVT>]>; def SDT_AArch64CCMP : SDTypeProfile<1, 5, [SDTCisVT<0, FlagsVT>, SDTCisInt<1>, SDTCisSameAs<1, 2>, SDTCisInt<3>, - SDTCisInt<4>, - SDTCisVT<5, i32>]>; + SDTCisVT<4, CondCodeVT>, + SDTCisVT<5, FlagsVT>]>; def SDT_AArch64FCCMP : SDTypeProfile<1, 5, [SDTCisVT<0, FlagsVT>, SDTCisFP<1>, SDTCisSameAs<1, 2>, SDTCisInt<3>, - SDTCisInt<4>, - SDTCisVT<5, i32>]>; + SDTCisVT<4, CondCodeVT>, + SDTCisVT<5, FlagsVT>]>; def SDT_AArch64FCmp : SDTypeProfile<1, 2, [SDTCisVT<0, FlagsVT>, SDTCisFP<1>, SDTCisSameAs<2, 1>]>; @@ -546,7 +551,8 @@ def SDT_AArch64TBL : SDTypeProfile<1, 2, [ ]>; def SDT_AArch64cb : SDTypeProfile<0, 4, - [SDTCisVT<0, i32>, SDTCisInt<1>, SDTCisInt<2>, + [SDTCisVT<0, CondCodeVT>, + SDTCisInt<1>, SDTCisInt<2>, SDTCisVT<3, OtherVT>]>; // non-extending masked load fragment. @@ -6668,6 +6674,15 @@ def : Pat<(f16 (any_uint_to_fp (i32 (any_fp_to_uint f16:$Rn)))), (UCVTFv1i16 (f16 (FCVTZUv1f16 f16:$Rn)))>; } +def : Pat<(v4i32 (any_fp_to_sint (v4f32 (scalar_to_vector (f32 FPR32:$src))))), + (v4i32 (INSERT_SUBREG (IMPLICIT_DEF), (i32 (FCVTZSv1i32 (f32 FPR32:$src))), ssub))>; +def : Pat<(v4i32 (any_fp_to_uint (v4f32 (scalar_to_vector (f32 FPR32:$src))))), + (v4i32 (INSERT_SUBREG (IMPLICIT_DEF), (i32 (FCVTZUv1i32 (f32 FPR32:$src))), ssub))>; +def : Pat<(v2i64 (any_fp_to_sint (v2f64 (scalar_to_vector (f64 FPR64:$src))))), + (v2i64 (INSERT_SUBREG (IMPLICIT_DEF), (i64 (FCVTZSv1i64 (f64 FPR64:$src))), dsub))>; +def : Pat<(v2i64 (any_fp_to_uint (v2f64 (scalar_to_vector (f64 FPR64:$src))))), + (v2i64 (INSERT_SUBREG (IMPLICIT_DEF), (i64 (FCVTZUv1i64 (f64 FPR64:$src))), dsub))>; + // int -> float conversion of value in lane 0 of simd vector should use // correct cvtf variant to avoid costly fpr <-> gpr register transfers. def : Pat<(f32 (sint_to_fp (i32 (vector_extract (v4i32 FPR128:$Rn), (i64 0))))), diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index 40f49da..e1adc0b 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -270,6 +270,13 @@ bool AArch64TTIImpl::areInlineCompatible(const Function *Caller, const Function *Callee) const { SMECallAttrs CallAttrs(*Caller, *Callee); + // Never inline a function explicitly marked as being streaming, + // into a non-streaming function. Assume it was marked as streaming + // for a reason. + if (CallAttrs.caller().hasNonStreamingInterfaceAndBody() && + CallAttrs.callee().hasStreamingInterfaceOrBody()) + return false; + // When inlining, we should consider the body of the function, not the // interface. if (CallAttrs.callee().hasStreamingBody()) { @@ -4905,14 +4912,17 @@ void AArch64TTIImpl::getUnrollingPreferences( // Disable partial & runtime unrolling on -Os. 
UP.PartialOptSizeThreshold = 0; - // No need to unroll auto-vectorized loops - if (findStringMetadataForLoop(L, "llvm.loop.isvectorized")) - return; - // Scan the loop: don't unroll loops with calls as this could prevent - // inlining. + // inlining. Don't unroll auto-vectorized loops either, though do allow + // unrolling of the scalar remainder. + bool IsVectorized = getBooleanLoopAttribute(L, "llvm.loop.isvectorized"); for (auto *BB : L->getBlocks()) { for (auto &I : *BB) { + // Both auto-vectorized loops and the scalar remainder have the + // isvectorized attribute, so differentiate between them by the presence + // of vector instructions. + if (IsVectorized && I.getType()->isVectorTy()) + return; if (isa<CallBase>(I)) { if (isa<CallInst>(I) || isa<InvokeInst>(I)) if (const Function *F = cast<CallBase>(I).getCalledFunction()) diff --git a/llvm/lib/Target/AArch64/GISel/AArch64GlobalISelUtils.cpp b/llvm/lib/Target/AArch64/GISel/AArch64GlobalISelUtils.cpp index 0b79850..1a15075 100644 --- a/llvm/lib/Target/AArch64/GISel/AArch64GlobalISelUtils.cpp +++ b/llvm/lib/Target/AArch64/GISel/AArch64GlobalISelUtils.cpp @@ -50,8 +50,10 @@ bool AArch64GISelUtils::isCMN(const MachineInstr *MaybeSub, // // %sub = G_SUB 0, %y // %cmp = G_ICMP eq/ne, %z, %sub + // or with signed comparisons with the no-signed-wrap flag set if (!MaybeSub || MaybeSub->getOpcode() != TargetOpcode::G_SUB || - !CmpInst::isEquality(Pred)) + (!CmpInst::isEquality(Pred) && + !(CmpInst::isSigned(Pred) && MaybeSub->getFlag(MachineInstr::NoSWrap)))) return false; auto MaybeZero = getIConstantVRegValWithLookThrough(MaybeSub->getOperand(1).getReg(), MRI); diff --git a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp index 1381a9b..d905692 100644 --- a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp +++ b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp @@ -1810,7 +1810,7 @@ bool AArch64InstructionSelector::selectCompareBranchFedByICmp( // Couldn't optimize. Emit a compare + a Bcc. 
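The AArch64GISelUtils hunk above extends the cmp(z, sub(0, y)) to cmn(z, y) fold to signed predicates, but only when the G_SUB carries the no-signed-wrap flag. The guard matters because 0 - INT32_MIN wraps back to INT32_MIN, and the wrapped and unwrapped comparisons then disagree; with nsw that case is poison and may be assumed away. A worked check in plain C++ (the 64-bit expression models the comparison CMN answers); the selector hunk resumes below:

#include <cstdint>
#include <cstdio>

int main() {
  int32_t Z = 0, Y = INT32_MIN;
  // cmp form: compare Z against the wrapped value of 0 - Y.
  int32_t X = (int32_t)(0u - (uint32_t)Y); // wraps back to INT32_MIN
  bool CmpForm = Z > X;                    // 0 > INT32_MIN: true
  // cmn form: behaves like comparing Z against the mathematical -Y.
  bool CmnForm = (int64_t)Z > -(int64_t)Y; // 0 > 2147483648: false
  // The two forms disagree exactly when the sub wrapped, hence the nsw
  // requirement on the fold.
  std::printf("cmp: %d, cmn: %d\n", (int)CmpForm, (int)CmnForm);
}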
MachineBasicBlock *DestMBB = I.getOperand(1).getMBB(); - auto PredOp = ICmp.getOperand(1); + auto &PredOp = ICmp.getOperand(1); emitIntegerCompare(ICmp.getOperand(2), ICmp.getOperand(3), PredOp, MIB); const AArch64CC::CondCode CC = changeICMPPredToAArch64CC( static_cast<CmpInst::Predicate>(PredOp.getPredicate())); @@ -2506,12 +2506,12 @@ bool AArch64InstructionSelector::earlySelect(MachineInstr &I) { return false; } auto &PredOp = Cmp->getOperand(1); - auto Pred = static_cast<CmpInst::Predicate>(PredOp.getPredicate()); - const AArch64CC::CondCode InvCC = - changeICMPPredToAArch64CC(CmpInst::getInversePredicate(Pred)); MIB.setInstrAndDebugLoc(I); emitIntegerCompare(/*LHS=*/Cmp->getOperand(2), /*RHS=*/Cmp->getOperand(3), PredOp, MIB); + auto Pred = static_cast<CmpInst::Predicate>(PredOp.getPredicate()); + const AArch64CC::CondCode InvCC = + changeICMPPredToAArch64CC(CmpInst::getInversePredicate(Pred)); emitCSINC(/*Dst=*/AddDst, /*Src =*/AddLHS, /*Src2=*/AddLHS, InvCC, MIB); I.eraseFromParent(); return true; @@ -3574,10 +3574,11 @@ bool AArch64InstructionSelector::select(MachineInstr &I) { return false; } - auto Pred = static_cast<CmpInst::Predicate>(I.getOperand(1).getPredicate()); + auto &PredOp = I.getOperand(1); + emitIntegerCompare(I.getOperand(2), I.getOperand(3), PredOp, MIB); + auto Pred = static_cast<CmpInst::Predicate>(PredOp.getPredicate()); const AArch64CC::CondCode InvCC = changeICMPPredToAArch64CC(CmpInst::getInversePredicate(Pred)); - emitIntegerCompare(I.getOperand(2), I.getOperand(3), I.getOperand(1), MIB); emitCSINC(/*Dst=*/I.getOperand(0).getReg(), /*Src1=*/AArch64::WZR, /*Src2=*/AArch64::WZR, InvCC, MIB); I.eraseFromParent(); @@ -5096,11 +5097,11 @@ bool AArch64InstructionSelector::tryOptSelect(GSelect &I) { AArch64CC::CondCode CondCode; if (CondOpc == TargetOpcode::G_ICMP) { - auto Pred = - static_cast<CmpInst::Predicate>(CondDef->getOperand(1).getPredicate()); + auto &PredOp = CondDef->getOperand(1); + emitIntegerCompare(CondDef->getOperand(2), CondDef->getOperand(3), PredOp, + MIB); + auto Pred = static_cast<CmpInst::Predicate>(PredOp.getPredicate()); CondCode = changeICMPPredToAArch64CC(Pred); - emitIntegerCompare(CondDef->getOperand(2), CondDef->getOperand(3), - CondDef->getOperand(1), MIB); } else { // Get the condition code for the select. auto Pred = @@ -5148,29 +5149,37 @@ MachineInstr *AArch64InstructionSelector::tryFoldIntegerCompare( MachineInstr *LHSDef = getDefIgnoringCopies(LHS.getReg(), MRI); MachineInstr *RHSDef = getDefIgnoringCopies(RHS.getReg(), MRI); auto P = static_cast<CmpInst::Predicate>(Predicate.getPredicate()); + // Given this: // // x = G_SUB 0, y - // G_ICMP x, z + // G_ICMP z, x // // Produce this: // - // cmn y, z - if (isCMN(LHSDef, P, MRI)) - return emitCMN(LHSDef->getOperand(2), RHS, MIRBuilder); + // cmn z, y + if (isCMN(RHSDef, P, MRI)) + return emitCMN(LHS, RHSDef->getOperand(2), MIRBuilder); - // Same idea here, but with the RHS of the compare instead: + // Same idea here, but with the LHS of the compare instead: // // Given this: // // x = G_SUB 0, y - // G_ICMP z, x + // G_ICMP x, z // // Produce this: // - // cmn z, y - if (isCMN(RHSDef, P, MRI)) - return emitCMN(LHS, RHSDef->getOperand(2), MIRBuilder); + // cmn y, z + // + // But be careful! We need to swap the predicate! 
+ if (isCMN(LHSDef, P, MRI)) { + if (!CmpInst::isEquality(P)) { + P = CmpInst::getSwappedPredicate(P); + Predicate = MachineOperand::CreatePredicate(P); + } + return emitCMN(LHSDef->getOperand(2), RHS, MIRBuilder); + } // Given this: // diff --git a/llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp b/llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp index 1cd9453..8c10673 100644 --- a/llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp +++ b/llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp @@ -228,12 +228,13 @@ void applyFoldGlobalOffset(MachineInstr &MI, MachineRegisterInfo &MRI, B.buildConstant(LLT::scalar(64), -static_cast<int64_t>(MinOffset))); } -// Combines vecreduce_add(mul(ext(x), ext(y))) -> vecreduce_add(udot(x, y)) -// Or vecreduce_add(ext(x)) -> vecreduce_add(udot(x, 1)) +// Combines vecreduce_add(mul(ext(x), ext(y))) -> vecreduce_add([us]dot(x, y)) +// Or vecreduce_add(ext(mul(ext(x), ext(y)))) -> vecreduce_add([us]dot(x, y)) +// Or vecreduce_add(ext(x)) -> vecreduce_add([us]dot(x, 1)) // Similar to performVecReduceAddCombine in SelectionDAG -bool matchExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI, - const AArch64Subtarget &STI, - std::tuple<Register, Register, bool> &MatchInfo) { +bool matchExtAddvToDotAddv(MachineInstr &MI, MachineRegisterInfo &MRI, + const AArch64Subtarget &STI, + std::tuple<Register, Register, bool> &MatchInfo) { assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD && "Expected a G_VECREDUCE_ADD instruction"); assert(STI.hasDotProd() && "Target should have Dot Product feature"); @@ -246,31 +247,57 @@ bool matchExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI, if (DstTy.getScalarSizeInBits() != 32 || MidTy.getScalarSizeInBits() != 32) return false; - LLT SrcTy; - auto I1Opc = I1->getOpcode(); - if (I1Opc == TargetOpcode::G_MUL) { + // Detect mul(ext, ext) with symmetric ext's. If I1Opc is G_ZEXT or G_SEXT + // then the ext's must match the same opcode. It is set to the ext opcode on + // output. 
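The matchExtAddvToDotAddv changes above teach the combine to also look through an outer extend, matching vecreduce_add(ext(mul(ext(x), ext(y)))); the tryMatchingMulOfExt helper that the comment above introduces follows right after this aside. For reference, a scalar model of what a udot-style reduction computes, with udotReduce as an illustrative name:

#include <cstdint>
#include <cstdio>

// Scalar model of a udot-based vecreduce_add: multiply 8-bit lanes, widen to
// 32 bits, and accumulate, i.e. sum(zext(x[i]) * zext(y[i])).
uint32_t udotReduce(const uint8_t *X, const uint8_t *Y, unsigned N) {
  uint32_t Acc = 0;
  for (unsigned I = 0; I != N; ++I)
    Acc += (uint32_t)X[I] * (uint32_t)Y[I];
  return Acc;
}

int main() {
  uint8_t X[4] = {200, 3, 5, 7}, Y[4] = {250, 1, 1, 1};
  // The ext(mul(ext, ext)) form matched by the new code reduces to this same
  // sum, which is why it can reuse the dot instruction.
  std::printf("%u\n", udotReduce(X, Y, 4)); // 50015
}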
+  auto tryMatchingMulOfExt = [&MRI](MachineInstr *MI, Register &Out1,
+                                    Register &Out2, unsigned &I1Opc) {
     // If result of this has more than 1 use, then there is no point in creating
-    // udot instruction
-    if (!MRI.hasOneNonDBGUse(MidReg))
+    // a dot instruction
+    if (!MRI.hasOneNonDBGUse(MI->getOperand(0).getReg()))
       return false;
     MachineInstr *ExtMI1 =
-        getDefIgnoringCopies(I1->getOperand(1).getReg(), MRI);
+        getDefIgnoringCopies(MI->getOperand(1).getReg(), MRI);
     MachineInstr *ExtMI2 =
-        getDefIgnoringCopies(I1->getOperand(2).getReg(), MRI);
+        getDefIgnoringCopies(MI->getOperand(2).getReg(), MRI);
     LLT Ext1DstTy = MRI.getType(ExtMI1->getOperand(0).getReg());
     LLT Ext2DstTy = MRI.getType(ExtMI2->getOperand(0).getReg());
     if (ExtMI1->getOpcode() != ExtMI2->getOpcode() || Ext1DstTy != Ext2DstTy)
       return false;
+    if ((I1Opc == TargetOpcode::G_ZEXT || I1Opc == TargetOpcode::G_SEXT) &&
+        I1Opc != ExtMI1->getOpcode())
+      return false;
+    Out1 = ExtMI1->getOperand(1).getReg();
+    Out2 = ExtMI2->getOperand(1).getReg();
     I1Opc = ExtMI1->getOpcode();
-    SrcTy = MRI.getType(ExtMI1->getOperand(1).getReg());
-    std::get<0>(MatchInfo) = ExtMI1->getOperand(1).getReg();
-    std::get<1>(MatchInfo) = ExtMI2->getOperand(1).getReg();
+    return true;
+  };
+
+  LLT SrcTy;
+  unsigned I1Opc = I1->getOpcode();
+  if (I1Opc == TargetOpcode::G_MUL) {
+    Register Out1, Out2;
+    if (!tryMatchingMulOfExt(I1, Out1, Out2, I1Opc))
+      return false;
+    SrcTy = MRI.getType(Out1);
+    std::get<0>(MatchInfo) = Out1;
+    std::get<1>(MatchInfo) = Out2;
   } else if (I1Opc == TargetOpcode::G_ZEXT || I1Opc == TargetOpcode::G_SEXT) {
-    SrcTy = MRI.getType(I1->getOperand(1).getReg());
-    std::get<0>(MatchInfo) = I1->getOperand(1).getReg();
-    std::get<1>(MatchInfo) = 0;
+    Register I1Op = I1->getOperand(1).getReg();
+    MachineInstr *M = getDefIgnoringCopies(I1Op, MRI);
+    Register Out1, Out2;
+    if (M->getOpcode() == TargetOpcode::G_MUL &&
+        tryMatchingMulOfExt(M, Out1, Out2, I1Opc)) {
+      SrcTy = MRI.getType(Out1);
+      std::get<0>(MatchInfo) = Out1;
+      std::get<1>(MatchInfo) = Out2;
+    } else {
+      SrcTy = MRI.getType(I1Op);
+      std::get<0>(MatchInfo) = I1Op;
+      std::get<1>(MatchInfo) = 0;
+    }
   } else {
     return false;
   }
@@ -288,11 +315,11 @@ bool matchExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
   return true;
 }
 
-void applyExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
-                            MachineIRBuilder &Builder,
-                            GISelChangeObserver &Observer,
-                            const AArch64Subtarget &STI,
-                            std::tuple<Register, Register, bool> &MatchInfo) {
+void applyExtAddvToDotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
+                           MachineIRBuilder &Builder,
+                           GISelChangeObserver &Observer,
+                           const AArch64Subtarget &STI,
+                           std::tuple<Register, Register, bool> &MatchInfo) {
   assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
          "Expected a G_VECREDUCE_ADD instruction");
   assert(STI.hasDotProd() && "Target should have Dot Product feature");
@@ -553,15 +580,15 @@ void applyExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
   MI.eraseFromParent();
 }
 
-// Pushes ADD/SUB through extend instructions to decrease the number of extend
-// instruction at the end by allowing selection of {s|u}addl sooner
-
+// Pushes ADD/SUB/MUL through extend instructions to decrease the number of
+// extend instructions at the end by allowing selection of {s|u}addl sooner
 // i32 add(i32 ext i8, i32 ext i8) => i32 ext(i16 add(i16 ext i8, i16 ext i8))
 bool matchPushAddSubExt(MachineInstr &MI, MachineRegisterInfo &MRI,
                         Register DstReg, Register SrcReg1, Register SrcReg2) {
   assert((MI.getOpcode() == 
TargetOpcode::G_ADD || - MI.getOpcode() == TargetOpcode::G_SUB) && - "Expected a G_ADD or G_SUB instruction\n"); + MI.getOpcode() == TargetOpcode::G_SUB || + MI.getOpcode() == TargetOpcode::G_MUL) && + "Expected a G_ADD, G_SUB or G_MUL instruction\n"); // Deal with vector types only LLT DstTy = MRI.getType(DstReg); @@ -594,9 +621,10 @@ void applyPushAddSubExt(MachineInstr &MI, MachineRegisterInfo &MRI, B.buildInstr(MI.getOpcode(), {MidTy}, {Ext1Reg, Ext2Reg}).getReg(0); // G_SUB has to sign-extend the result. - // G_ADD needs to sext from sext and can sext or zext from zext, so the - // original opcode is used. - if (MI.getOpcode() == TargetOpcode::G_ADD) + // G_ADD needs to sext from sext and can sext or zext from zext, and G_MUL + // needs to use the original opcode so the original opcode is used for both. + if (MI.getOpcode() == TargetOpcode::G_ADD || + MI.getOpcode() == TargetOpcode::G_MUL) B.buildInstr(Opc, {DstReg}, {AddReg}); else B.buildSExt(DstReg, AddReg); diff --git a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64AsmBackend.cpp b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64AsmBackend.cpp index b9d3e1b..7a2b679 100644 --- a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64AsmBackend.cpp +++ b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64AsmBackend.cpp @@ -79,8 +79,7 @@ public: } void applyFixup(const MCFragment &, const MCFixup &, const MCValue &Target, - MutableArrayRef<char> Data, uint64_t Value, - bool IsResolved) override; + uint8_t *Data, uint64_t Value, bool IsResolved) override; bool fixupNeedsRelaxation(const MCFixup &Fixup, uint64_t Value) const override; @@ -421,9 +420,8 @@ static bool shouldForceRelocation(const MCFixup &Fixup) { } void AArch64AsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, - const MCValue &Target, - MutableArrayRef<char> Data, uint64_t Value, - bool IsResolved) { + const MCValue &Target, uint8_t *Data, + uint64_t Value, bool IsResolved) { if (shouldForceRelocation(Fixup)) IsResolved = false; maybeAddReloc(F, Fixup, Target, Value, IsResolved); @@ -460,8 +458,8 @@ void AArch64AsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, // Shift the value into position. Value <<= Info.TargetOffset; - unsigned Offset = Fixup.getOffset(); - assert(Offset + NumBytes <= Data.size() && "Invalid fixup offset!"); + assert(Fixup.getOffset() + NumBytes <= F.getSize() && + "Invalid fixup offset!"); // Used to point to big endian bytes. unsigned FulleSizeInBytes = getFixupKindContainereSizeInBytes(Fixup.getKind()); @@ -471,15 +469,16 @@ void AArch64AsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, if (FulleSizeInBytes == 0) { // Handle as little-endian for (unsigned i = 0; i != NumBytes; ++i) { - Data[Offset + i] |= uint8_t((Value >> (i * 8)) & 0xff); + Data[i] |= uint8_t((Value >> (i * 8)) & 0xff); } } else { // Handle as big-endian - assert((Offset + FulleSizeInBytes) <= Data.size() && "Invalid fixup size!"); + assert(Fixup.getOffset() + FulleSizeInBytes <= F.getSize() && + "Invalid fixup size!"); assert(NumBytes <= FulleSizeInBytes && "Invalid fixup size!"); for (unsigned i = 0; i != NumBytes; ++i) { unsigned Idx = FulleSizeInBytes - 1 - i; - Data[Offset + Idx] |= uint8_t((Value >> (i * 8)) & 0xff); + Data[Idx] |= uint8_t((Value >> (i * 8)) & 0xff); } } @@ -492,9 +491,9 @@ void AArch64AsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, // If the immediate is negative, generate MOVN else MOVZ. // (Bit 30 = 0) ==> MOVN, (Bit 30 = 1) ==> MOVZ. 
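// A minimal illustration (not the backend code itself) of the bit-30 flip
// described above: in a little-endian 32-bit instruction word, byte 3 holds
// bits 31:24, so bit 30 of the instruction is bit 6 of that byte.
#include <cstdint>
#include <cstdio>

static void selectMovzOrMovn(uint8_t Inst[4], int64_t SignedValue) {
  if (SignedValue < 0)
    Inst[3] &= ~(1u << 6); // clear bit 30 -> MOVN
  else
    Inst[3] |= (1u << 6);  // set bit 30 -> MOVZ
}

int main() {
  uint8_t Inst[4] = {0, 0, 0, 0x52};
  selectMovzOrMovn(Inst, -5);
  std::printf("%02x\n", Inst[3]); // 0x12: bit 6 cleared, MOVN form
}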
if (SignedValue < 0) - Data[Offset + 3] &= ~(1 << 6); + Data[3] &= ~(1 << 6); else - Data[Offset + 3] |= (1 << 6); + Data[3] |= (1 << 6); } } diff --git a/llvm/lib/Target/AMDGPU/AMDGPU.td b/llvm/lib/Target/AMDGPU/AMDGPU.td index 071c940..18f3c47 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPU.td +++ b/llvm/lib/Target/AMDGPU/AMDGPU.td @@ -1160,6 +1160,12 @@ def FeatureTanhInsts : SubtargetFeature<"tanh-insts", "Has v_tanh_f32/f16 instructions" >; +def FeatureTensorCvtLutInsts : SubtargetFeature<"tensor-cvt-lut-insts", + "HasTensorCvtLutInsts", + "true", + "Has v_perm_pk16* instructions" +>; + def FeatureTransposeLoadF4F6Insts : SubtargetFeature<"transpose-load-f4f6-insts", "HasTransposeLoadF4F6Insts", "true", @@ -2030,6 +2036,7 @@ def FeatureISAVersion12_50 : FeatureSet< FeatureDPPSrc1SGPR, FeatureBitOp3Insts, FeatureTanhInsts, + FeatureTensorCvtLutInsts, FeatureTransposeLoadF4F6Insts, FeatureBF16TransInsts, FeatureBF16ConversionInsts, @@ -2576,6 +2583,10 @@ def HasFmaakFmamkF64Insts : Predicate<"Subtarget->hasFmaakFmamkF64Insts()">, AssemblerPredicate<(any_of FeatureGFX1250Insts)>; +def HasAddMinMaxInsts : + Predicate<"Subtarget->hasAddMinMaxInsts()">, + AssemblerPredicate<(any_of FeatureGFX1250Insts)>; + def HasPkAddMinMaxInsts : Predicate<"Subtarget->hasPkAddMinMaxInsts()">, AssemblerPredicate<(any_of FeatureGFX1250Insts)>; @@ -2781,6 +2792,9 @@ def HasBitOp3Insts : Predicate<"Subtarget->hasBitOp3Insts()">, def HasTanhInsts : Predicate<"Subtarget->hasTanhInsts()">, AssemblerPredicate<(all_of FeatureTanhInsts)>; +def HasTensorCvtLutInsts : Predicate<"Subtarget->hasTensorCvtLutInsts()">, + AssemblerPredicate<(all_of FeatureTensorCvtLutInsts)>; + def HasTransposeLoadF4F6Insts : Predicate<"Subtarget->hasTransposeLoadF4F6Insts()">, AssemblerPredicate<(all_of FeatureTransposeLoadF4F6Insts)>; diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp b/llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp index 5f19837..a9278c1 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp @@ -89,10 +89,6 @@ static cl::opt<bool> DisableFDivExpand( cl::ReallyHidden, cl::init(false)); -static bool hasUnsafeFPMath(const Function &F) { - return F.getFnAttribute("unsafe-fp-math").getValueAsBool(); -} - class AMDGPUCodeGenPrepareImpl : public InstVisitor<AMDGPUCodeGenPrepareImpl, bool> { public: @@ -104,7 +100,6 @@ public: const DominatorTree *DT; const UniformityInfo &UA; const DataLayout &DL; - const bool HasUnsafeFPMath; const bool HasFP32DenormalFlush; bool FlowChanged = false; mutable Function *SqrtF32 = nullptr; @@ -117,7 +112,6 @@ public: const DominatorTree *DT, const UniformityInfo &UA) : F(F), ST(TM.getSubtarget<GCNSubtarget>(F)), TM(TM), TLI(TLI), AC(AC), DT(DT), UA(UA), DL(F.getDataLayout()), - HasUnsafeFPMath(hasUnsafeFPMath(F)), HasFP32DenormalFlush(SIModeRegisterDefaults(F, ST).FP32Denormals == DenormalMode::getPreserveSign()) {} @@ -637,8 +631,7 @@ bool AMDGPUCodeGenPrepareImpl::canOptimizeWithRsq(const FPMathOperator *SqrtOp, return false; // v_rsq_f32 gives 1ulp - return SqrtFMF.approxFunc() || HasUnsafeFPMath || - SqrtOp->getFPAccuracy() >= 1.0f; + return SqrtFMF.approxFunc() || SqrtOp->getFPAccuracy() >= 1.0f; } Value *AMDGPUCodeGenPrepareImpl::optimizeWithRsq( @@ -664,7 +657,7 @@ Value *AMDGPUCodeGenPrepareImpl::optimizeWithRsq( IRBuilder<>::FastMathFlagGuard Guard(Builder); Builder.setFastMathFlags(DivFMF | SqrtFMF); - if ((DivFMF.approxFunc() && SqrtFMF.approxFunc()) || HasUnsafeFPMath || + if ((DivFMF.approxFunc() && 
SqrtFMF.approxFunc()) || canIgnoreDenormalInput(Den, CtxI)) { Value *Result = Builder.CreateUnaryIntrinsic(Intrinsic::amdgcn_rsq, Den); // -1.0 / sqrt(x) -> fneg(rsq(x)) @@ -680,7 +673,7 @@ Value *AMDGPUCodeGenPrepareImpl::optimizeWithRsq( // Optimize fdiv with rcp: // // 1/x -> rcp(x) when rcp is sufficiently accurate or inaccurate rcp is -// allowed with unsafe-fp-math or afn. +// allowed with afn. // // a/b -> a*rcp(b) when arcp is allowed, and we only need provide ULP 1.0 Value * @@ -803,9 +796,9 @@ Value *AMDGPUCodeGenPrepareImpl::visitFDivElement( // // With rcp: // 1/x -> rcp(x) when rcp is sufficiently accurate or inaccurate rcp is -// allowed with unsafe-fp-math or afn. +// allowed with afn. // -// a/b -> a*rcp(b) when inaccurate rcp is allowed with unsafe-fp-math or afn. +// a/b -> a*rcp(b) when inaccurate rcp is allowed with afn. // // With fdiv.fast: // a/b -> fdiv.fast(a, b) when !fpmath >= 2.5ulp with denormals flushed. @@ -843,7 +836,7 @@ bool AMDGPUCodeGenPrepareImpl::visitFDiv(BinaryOperator &FDiv) { RsqOp = SqrtOp->getOperand(0); } - // Inaccurate rcp is allowed with unsafe-fp-math or afn. + // Inaccurate rcp is allowed with afn. // // Defer to codegen to handle this. // @@ -852,7 +845,7 @@ bool AMDGPUCodeGenPrepareImpl::visitFDiv(BinaryOperator &FDiv) { // expansion of afn to codegen. The current interpretation is so aggressive we // don't need any pre-consideration here when we have better information. A // more conservative interpretation could use handling here. - const bool AllowInaccurateRcp = HasUnsafeFPMath || DivFMF.approxFunc(); + const bool AllowInaccurateRcp = DivFMF.approxFunc(); if (!RsqOp && AllowInaccurateRcp) return false; @@ -2026,7 +2019,7 @@ bool AMDGPUCodeGenPrepareImpl::visitSqrt(IntrinsicInst &Sqrt) { // We're trying to handle the fast-but-not-that-fast case only. The lowering // of fast llvm.sqrt will give the raw instruction anyway. 
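// Toy restatement of the policy change running through these hunks: the
// function-wide "unsafe-fp-math" escape hatch is removed, so approximate
// expansions are gated purely on the instruction's own afn fast-math flag.
// A minimal sketch, not the pass code:
#include <cstdio>

struct FMF { bool ApproxFunc = false; };

static bool allowApproxOld(FMF F, bool FnUnsafeFPMath) {
  return F.ApproxFunc || FnUnsafeFPMath; // old: function attribute could override
}
static bool allowApproxNew(FMF F) {
  return F.ApproxFunc;                   // new: per-instruction flag only
}

int main() {
  FMF F; // instruction carries no afn flag
  std::printf("old: %d new: %d\n", allowApproxOld(F, true), allowApproxNew(F));
}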
- if (SqrtFMF.approxFunc() || HasUnsafeFPMath) + if (SqrtFMF.approxFunc()) return false; const float ReqdAccuracy = FPOp->getFPAccuracy(); diff --git a/llvm/lib/Target/AMDGPU/AMDGPUHSAMetadataStreamer.cpp b/llvm/lib/Target/AMDGPU/AMDGPUHSAMetadataStreamer.cpp index 2991778..19b8757 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUHSAMetadataStreamer.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUHSAMetadataStreamer.cpp @@ -204,7 +204,7 @@ MetadataStreamerMsgPackV4::getWorkGroupDimensions(MDNode *Node) const { for (auto &Op : Node->operands()) Dims.push_back(Dims.getDocument()->getNode( - uint64_t(mdconst::extract<ConstantInt>(Op)->getZExtValue()))); + mdconst::extract<ConstantInt>(Op)->getZExtValue())); return Dims; } diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp index 6118933..d059480 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp @@ -367,6 +367,18 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM, setTruncStoreAction(MVT::v4f64, MVT::v4bf16, Expand); setTruncStoreAction(MVT::v4f64, MVT::v4f16, Expand); + setTruncStoreAction(MVT::v5i32, MVT::v5i1, Expand); + setTruncStoreAction(MVT::v5i32, MVT::v5i8, Expand); + setTruncStoreAction(MVT::v5i32, MVT::v5i16, Expand); + + setTruncStoreAction(MVT::v6i32, MVT::v6i1, Expand); + setTruncStoreAction(MVT::v6i32, MVT::v6i8, Expand); + setTruncStoreAction(MVT::v6i32, MVT::v6i16, Expand); + + setTruncStoreAction(MVT::v7i32, MVT::v7i1, Expand); + setTruncStoreAction(MVT::v7i32, MVT::v7i8, Expand); + setTruncStoreAction(MVT::v7i32, MVT::v7i16, Expand); + setTruncStoreAction(MVT::v8f64, MVT::v8f32, Expand); setTruncStoreAction(MVT::v8f64, MVT::v8bf16, Expand); setTruncStoreAction(MVT::v8f64, MVT::v8f16, Expand); @@ -2634,7 +2646,7 @@ bool AMDGPUTargetLowering::allowApproxFunc(const SelectionDAG &DAG, if (Flags.hasApproximateFuncs()) return true; auto &Options = DAG.getTarget().Options; - return Options.UnsafeFPMath || Options.ApproxFuncFPMath; + return Options.ApproxFuncFPMath; } bool AMDGPUTargetLowering::needsDenormHandlingF32(const SelectionDAG &DAG, @@ -2757,7 +2769,7 @@ SDValue AMDGPUTargetLowering::LowerFLOGCommon(SDValue Op, const auto &Options = getTargetMachine().Options; if (VT == MVT::f16 || Flags.hasApproximateFuncs() || - Options.ApproxFuncFPMath || Options.UnsafeFPMath) { + Options.ApproxFuncFPMath) { if (VT == MVT::f16 && !Subtarget->has16BitInsts()) { // Log and multiply in f32 is good enough for f16. @@ -3585,7 +3597,7 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_FP16(SDValue Op, SelectionDAG &DAG) con if (N0.getValueType() == MVT::f32) return DAG.getNode(AMDGPUISD::FP_TO_FP16, DL, Op.getValueType(), N0); - if (getTargetMachine().Options.UnsafeFPMath) { + if (Op->getFlags().hasApproximateFuncs()) { // There is a generic expand for FP_TO_FP16 with unsafe fast math. return SDValue(); } diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp index cd9c2ec..b0d3b12 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp @@ -6994,13 +6994,13 @@ void AMDGPUInstructionSelector::renderSrcAndDstSelToOpSelXForm_0_0( MachineInstrBuilder &MIB, const MachineInstr &MI, int OpIdx) const { assert(OpIdx >= 0 && "expected to match an immediate operand"); MIB.addImm( - (MI.getOperand(OpIdx).getImm() & 0x2) ? (int64_t)SISrcMods::OP_SEL_0 : 0); + (MI.getOperand(OpIdx).getImm() & 0x1) ? 
(int64_t)SISrcMods::OP_SEL_0 : 0); } void AMDGPUInstructionSelector::renderSrcAndDstSelToOpSelXForm_0_1( MachineInstrBuilder &MIB, const MachineInstr &MI, int OpIdx) const { assert(OpIdx >= 0 && "expected to match an immediate operand"); - MIB.addImm((MI.getOperand(OpIdx).getImm() & 0x2) + MIB.addImm((MI.getOperand(OpIdx).getImm() & 0x1) ? (int64_t)(SISrcMods::OP_SEL_0 | SISrcMods::DST_OP_SEL) : (int64_t)SISrcMods::DST_OP_SEL); } @@ -7009,13 +7009,13 @@ void AMDGPUInstructionSelector::renderSrcAndDstSelToOpSelXForm_1_0( MachineInstrBuilder &MIB, const MachineInstr &MI, int OpIdx) const { assert(OpIdx >= 0 && "expected to match an immediate operand"); MIB.addImm( - (MI.getOperand(OpIdx).getImm() & 0x1) ? (int64_t)SISrcMods::OP_SEL_0 : 0); + (MI.getOperand(OpIdx).getImm() & 0x2) ? (int64_t)SISrcMods::OP_SEL_0 : 0); } void AMDGPUInstructionSelector::renderSrcAndDstSelToOpSelXForm_1_1( MachineInstrBuilder &MIB, const MachineInstr &MI, int OpIdx) const { assert(OpIdx >= 0 && "expected to match an immediate operand"); - MIB.addImm((MI.getOperand(OpIdx).getImm() & 0x1) + MIB.addImm((MI.getOperand(OpIdx).getImm() & 0x2) ? (int64_t)(SISrcMods::OP_SEL_0) : 0); } @@ -7044,8 +7044,9 @@ void AMDGPUInstructionSelector::renderSrcAndDstSelToOpSelXForm_2_0( void AMDGPUInstructionSelector::renderDstSelToOpSel3XFormXForm( MachineInstrBuilder &MIB, const MachineInstr &MI, int OpIdx) const { assert(OpIdx >= 0 && "expected to match an immediate operand"); - MIB.addImm( - (MI.getOperand(OpIdx).getImm() & 0x2) ? (int64_t)SISrcMods::DST_OP_SEL : 0); + MIB.addImm((MI.getOperand(OpIdx).getImm() & 0x2) + ? (int64_t)SISrcMods::DST_OP_SEL + : 0); } void AMDGPUInstructionSelector::renderExtractCPol(MachineInstrBuilder &MIB, diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstructions.td b/llvm/lib/Target/AMDGPU/AMDGPUInstructions.td index 7a50923..511fc69 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUInstructions.td +++ b/llvm/lib/Target/AMDGPU/AMDGPUInstructions.td @@ -94,7 +94,6 @@ def NoFP32Denormals : Predicate<"MF->getInfo<SIMachineFunctionInfo>()->getMode() def NoFP64Denormals : Predicate<"MF->getInfo<SIMachineFunctionInfo>()->getMode().FP64FP16Denormals == DenormalMode::getPreserveSign()">; def IEEEModeEnabled : Predicate<"MF->getInfo<SIMachineFunctionInfo>()->getMode().IEEE">; def IEEEModeDisabled : Predicate<"!MF->getInfo<SIMachineFunctionInfo>()->getMode().IEEE">; -def UnsafeFPMath : Predicate<"TM.Options.UnsafeFPMath">; } def FMA : Predicate<"Subtarget->hasFMA()">; diff --git a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp index 50da8fd..1fdf272 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp @@ -3344,7 +3344,7 @@ static bool allowApproxFunc(const MachineFunction &MF, unsigned Flags) { if (Flags & MachineInstr::FmAfn) return true; const auto &Options = MF.getTarget().Options; - return Options.UnsafeFPMath || Options.ApproxFuncFPMath; + return Options.ApproxFuncFPMath; } static bool needsDenormHandlingF32(const MachineFunction &MF, Register Src, @@ -3450,7 +3450,7 @@ bool AMDGPULegalizerInfo::legalizeFlogCommon(MachineInstr &MI, static_cast<const AMDGPUTargetMachine &>(MF.getTarget()); if (Ty == F16 || MI.getFlag(MachineInstr::FmAfn) || - TM.Options.ApproxFuncFPMath || TM.Options.UnsafeFPMath) { + TM.Options.ApproxFuncFPMath) { if (Ty == F16 && !ST.has16BitInsts()) { Register LogVal = MRI.createGenericVirtualRegister(F32); auto PromoteSrc = B.buildFPExt(F32, X); @@ -4877,9 +4877,7 @@ bool 
AMDGPULegalizerInfo::legalizeFastUnsafeFDIV(MachineInstr &MI, uint16_t Flags = MI.getFlags(); LLT ResTy = MRI.getType(Res); - const MachineFunction &MF = B.getMF(); - bool AllowInaccurateRcp = MI.getFlag(MachineInstr::FmAfn) || - MF.getTarget().Options.UnsafeFPMath; + bool AllowInaccurateRcp = MI.getFlag(MachineInstr::FmAfn); if (const auto *CLHS = getConstantFPVRegVal(LHS, MRI)) { if (!AllowInaccurateRcp && ResTy != LLT::scalar(16)) @@ -4939,9 +4937,7 @@ bool AMDGPULegalizerInfo::legalizeFastUnsafeFDIV64(MachineInstr &MI, uint16_t Flags = MI.getFlags(); LLT ResTy = MRI.getType(Res); - const MachineFunction &MF = B.getMF(); - bool AllowInaccurateRcp = MF.getTarget().Options.UnsafeFPMath || - MI.getFlag(MachineInstr::FmAfn); + bool AllowInaccurateRcp = MI.getFlag(MachineInstr::FmAfn); if (!AllowInaccurateRcp) return false; diff --git a/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp b/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp index 8767208..aa75534 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp @@ -53,8 +53,6 @@ private: using FuncInfo = llvm::AMDGPULibFunc; - bool UnsafeFPMath = false; - // -fuse-native. bool AllNative = false; @@ -117,7 +115,6 @@ private: bool AllowStrictFP = false); protected: - bool isUnsafeMath(const FPMathOperator *FPOp) const; bool isUnsafeFiniteOnlyMath(const FPMathOperator *FPOp) const; bool canIncreasePrecisionOfConstantFold(const FPMathOperator *FPOp) const; @@ -415,23 +412,17 @@ bool AMDGPULibCalls::parseFunctionName(const StringRef &FMangledName, return AMDGPULibFunc::parse(FMangledName, FInfo); } -bool AMDGPULibCalls::isUnsafeMath(const FPMathOperator *FPOp) const { - return UnsafeFPMath || FPOp->isFast(); -} - bool AMDGPULibCalls::isUnsafeFiniteOnlyMath(const FPMathOperator *FPOp) const { - return UnsafeFPMath || - (FPOp->hasApproxFunc() && FPOp->hasNoNaNs() && FPOp->hasNoInfs()); + return FPOp->hasApproxFunc() && FPOp->hasNoNaNs() && FPOp->hasNoInfs(); } bool AMDGPULibCalls::canIncreasePrecisionOfConstantFold( const FPMathOperator *FPOp) const { // TODO: Refine to approxFunc or contract - return isUnsafeMath(FPOp); + return FPOp->isFast(); } void AMDGPULibCalls::initFunction(Function &F, FunctionAnalysisManager &FAM) { - UnsafeFPMath = F.getFnAttribute("unsafe-fp-math").getValueAsBool(); AC = &FAM.getResult<AssumptionAnalysis>(F); TLInfo = &FAM.getResult<TargetLibraryAnalysis>(F); DT = FAM.getCachedResult<DominatorTreeAnalysis>(F); diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp index 6c40eb5..d11e5a3 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp @@ -3204,6 +3204,18 @@ void AMDGPURegisterBankInfo::applyMappingImpl( constrainOpWithReadfirstlane(B, MI, 5); return; } + case Intrinsic::amdgcn_permlane_bcast: + case Intrinsic::amdgcn_permlane_up: + case Intrinsic::amdgcn_permlane_down: + case Intrinsic::amdgcn_permlane_xor: + // Doing a waterfall loop over these wouldn't make any sense. 
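// Toy model of why readfirstlane suffices here: the operand is required to be
// uniform (an SGPR), so broadcasting the value from the first active lane is
// enough, and a per-value waterfall loop would buy nothing. Sketch only; the
// real v_readfirstlane_b32 operates on hardware lanes.
#include <cstdio>

static int readFirstLane(const int LaneVals[32], unsigned ExecMask) {
  for (unsigned L = 0; L < 32; ++L)
    if (ExecMask & (1u << L))
      return LaneVals[L];
  return 0; // no active lanes
}

int main() {
  int Vals[32] = {7, 9, 11};
  std::printf("%d\n", readFirstLane(Vals, 0b110)); // lane 1 is first active: 9
}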
+ constrainOpWithReadfirstlane(B, MI, 3); + constrainOpWithReadfirstlane(B, MI, 4); + return; + case Intrinsic::amdgcn_permlane_idx_gen: { + constrainOpWithReadfirstlane(B, MI, 3); + return; + } case Intrinsic::amdgcn_sbfe: applyMappingBFE(B, OpdMapper, true); return; @@ -4574,8 +4586,59 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { case Intrinsic::amdgcn_cvt_pknorm_u16: case Intrinsic::amdgcn_cvt_pk_i16: case Intrinsic::amdgcn_cvt_pk_u16: + case Intrinsic::amdgcn_cvt_sr_pk_f16_f32: + case Intrinsic::amdgcn_cvt_sr_pk_bf16_f32: case Intrinsic::amdgcn_cvt_pk_f16_fp8: case Intrinsic::amdgcn_cvt_pk_f16_bf8: + case Intrinsic::amdgcn_cvt_pk_fp8_f16: + case Intrinsic::amdgcn_cvt_pk_bf8_f16: + case Intrinsic::amdgcn_cvt_sr_fp8_f16: + case Intrinsic::amdgcn_cvt_sr_bf8_f16: + case Intrinsic::amdgcn_cvt_scale_pk8_f16_fp8: + case Intrinsic::amdgcn_cvt_scale_pk8_bf16_fp8: + case Intrinsic::amdgcn_cvt_scale_pk8_f16_bf8: + case Intrinsic::amdgcn_cvt_scale_pk8_bf16_bf8: + case Intrinsic::amdgcn_cvt_scale_pk8_f16_fp4: + case Intrinsic::amdgcn_cvt_scale_pk8_bf16_fp4: + case Intrinsic::amdgcn_cvt_scale_pk8_f32_fp8: + case Intrinsic::amdgcn_cvt_scale_pk8_f32_bf8: + case Intrinsic::amdgcn_cvt_scale_pk8_f32_fp4: + case Intrinsic::amdgcn_cvt_scale_pk16_f16_fp6: + case Intrinsic::amdgcn_cvt_scale_pk16_bf16_fp6: + case Intrinsic::amdgcn_cvt_scale_pk16_f16_bf6: + case Intrinsic::amdgcn_cvt_scale_pk16_bf16_bf6: + case Intrinsic::amdgcn_cvt_scale_pk16_f32_fp6: + case Intrinsic::amdgcn_cvt_scale_pk16_f32_bf6: + case Intrinsic::amdgcn_cvt_scalef32_pk8_fp8_bf16: + case Intrinsic::amdgcn_cvt_scalef32_pk8_bf8_bf16: + case Intrinsic::amdgcn_cvt_scalef32_pk8_fp8_f16: + case Intrinsic::amdgcn_cvt_scalef32_pk8_bf8_f16: + case Intrinsic::amdgcn_cvt_scalef32_pk8_fp8_f32: + case Intrinsic::amdgcn_cvt_scalef32_pk8_bf8_f32: + case Intrinsic::amdgcn_cvt_scalef32_pk8_fp4_f32: + case Intrinsic::amdgcn_cvt_scalef32_pk8_fp4_f16: + case Intrinsic::amdgcn_cvt_scalef32_pk8_fp4_bf16: + case Intrinsic::amdgcn_cvt_scalef32_pk16_fp6_f32: + case Intrinsic::amdgcn_cvt_scalef32_pk16_bf6_f32: + case Intrinsic::amdgcn_cvt_scalef32_pk16_fp6_f16: + case Intrinsic::amdgcn_cvt_scalef32_pk16_bf6_f16: + case Intrinsic::amdgcn_cvt_scalef32_pk16_fp6_bf16: + case Intrinsic::amdgcn_cvt_scalef32_pk16_bf6_bf16: + case Intrinsic::amdgcn_cvt_scalef32_sr_pk8_fp8_bf16: + case Intrinsic::amdgcn_cvt_scalef32_sr_pk8_bf8_bf16: + case Intrinsic::amdgcn_cvt_scalef32_sr_pk8_fp8_f16: + case Intrinsic::amdgcn_cvt_scalef32_sr_pk8_bf8_f16: + case Intrinsic::amdgcn_cvt_scalef32_sr_pk8_fp8_f32: + case Intrinsic::amdgcn_cvt_scalef32_sr_pk8_bf8_f32: + case Intrinsic::amdgcn_cvt_scalef32_sr_pk8_fp4_f32: + case Intrinsic::amdgcn_cvt_scalef32_sr_pk8_fp4_f16: + case Intrinsic::amdgcn_cvt_scalef32_sr_pk8_fp4_bf16: + case Intrinsic::amdgcn_cvt_scalef32_sr_pk16_fp6_f32: + case Intrinsic::amdgcn_cvt_scalef32_sr_pk16_bf6_f32: + case Intrinsic::amdgcn_cvt_scalef32_sr_pk16_fp6_f16: + case Intrinsic::amdgcn_cvt_scalef32_sr_pk16_bf6_f16: + case Intrinsic::amdgcn_cvt_scalef32_sr_pk16_fp6_bf16: + case Intrinsic::amdgcn_cvt_scalef32_sr_pk16_bf6_bf16: case Intrinsic::amdgcn_sat_pk4_i4_i8: case Intrinsic::amdgcn_sat_pk4_u4_u8: case Intrinsic::amdgcn_fmed3: @@ -4627,8 +4690,10 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { case Intrinsic::amdgcn_cvt_pk_f32_fp8: case Intrinsic::amdgcn_cvt_pk_f32_bf8: case Intrinsic::amdgcn_cvt_pk_fp8_f32: + case Intrinsic::amdgcn_cvt_pk_fp8_f32_e5m3: case Intrinsic::amdgcn_cvt_pk_bf8_f32: case 
Intrinsic::amdgcn_cvt_sr_fp8_f32: + case Intrinsic::amdgcn_cvt_sr_fp8_f32_e5m3: case Intrinsic::amdgcn_cvt_sr_bf8_f32: case Intrinsic::amdgcn_cvt_sr_bf16_f32: case Intrinsic::amdgcn_cvt_sr_f16_f32: @@ -4748,6 +4813,9 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { case Intrinsic::amdgcn_swmmac_f16_16x16x128_bf8_fp8: case Intrinsic::amdgcn_swmmac_f16_16x16x128_bf8_bf8: case Intrinsic::amdgcn_swmmac_i32_16x16x128_iu8: + case Intrinsic::amdgcn_perm_pk16_b4_u4: + case Intrinsic::amdgcn_perm_pk16_b6_u4: + case Intrinsic::amdgcn_perm_pk16_b8_u4: return getDefaultMappingVOP(MI); case Intrinsic::amdgcn_log: case Intrinsic::amdgcn_exp2: @@ -4885,6 +4953,24 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { OpdsMapping[5] = getSGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI); break; } + case Intrinsic::amdgcn_permlane_bcast: + case Intrinsic::amdgcn_permlane_up: + case Intrinsic::amdgcn_permlane_down: + case Intrinsic::amdgcn_permlane_xor: { + unsigned Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, *TRI); + OpdsMapping[0] = AMDGPU::getValueMapping(AMDGPU::VGPRRegBankID, Size); + OpdsMapping[2] = AMDGPU::getValueMapping(AMDGPU::VGPRRegBankID, Size); + OpdsMapping[3] = getSGPROpMapping(MI.getOperand(2).getReg(), MRI, *TRI); + OpdsMapping[4] = getSGPROpMapping(MI.getOperand(3).getReg(), MRI, *TRI); + break; + } + case Intrinsic::amdgcn_permlane_idx_gen: { + unsigned Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, *TRI); + OpdsMapping[0] = AMDGPU::getValueMapping(AMDGPU::VGPRRegBankID, Size); + OpdsMapping[2] = AMDGPU::getValueMapping(AMDGPU::VGPRRegBankID, Size); + OpdsMapping[3] = getSGPROpMapping(MI.getOperand(2).getReg(), MRI, *TRI); + break; + } case Intrinsic::amdgcn_permlane16_var: case Intrinsic::amdgcn_permlanex16_var: { unsigned Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, *TRI); diff --git a/llvm/lib/Target/AMDGPU/AMDGPUSearchableTables.td b/llvm/lib/Target/AMDGPU/AMDGPUSearchableTables.td index dfe0cbf..10b8606 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUSearchableTables.td +++ b/llvm/lib/Target/AMDGPU/AMDGPUSearchableTables.td @@ -321,6 +321,11 @@ def : SourceOfDivergence<int_amdgcn_permlane16>; def : SourceOfDivergence<int_amdgcn_permlanex16>; def : SourceOfDivergence<int_amdgcn_permlane16_var>; def : SourceOfDivergence<int_amdgcn_permlanex16_var>; +def : SourceOfDivergence<int_amdgcn_permlane_bcast>; +def : SourceOfDivergence<int_amdgcn_permlane_up>; +def : SourceOfDivergence<int_amdgcn_permlane_down>; +def : SourceOfDivergence<int_amdgcn_permlane_xor>; +def : SourceOfDivergence<int_amdgcn_permlane_idx_gen>; def : SourceOfDivergence<int_amdgcn_mov_dpp>; def : SourceOfDivergence<int_amdgcn_mov_dpp8>; def : SourceOfDivergence<int_amdgcn_update_dpp>; diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp index 24f4df2..a0c99b0 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp @@ -597,7 +597,6 @@ InstructionCost GCNTTIImpl::getArithmeticInstrCost( // Estimate all types may be fused with contract/unsafe flags const TargetOptions &Options = TLI->getTargetMachine().Options; if (Options.AllowFPOpFusion == FPOpFusion::Fast || - Options.UnsafeFPMath || (FAdd->hasAllowContract() && CxtI->hasAllowContract())) return TargetTransformInfo::TCC_Free; } @@ -650,8 +649,7 @@ InstructionCost GCNTTIImpl::getArithmeticInstrCost( return LT.first * Cost * NElts; } - if (SLT == MVT::f32 && ((CxtI && 
CxtI->hasApproxFunc()) || - TLI->getTargetMachine().Options.UnsafeFPMath)) { + if (SLT == MVT::f32 && (CxtI && CxtI->hasApproxFunc())) { // Fast unsafe fdiv lowering: // f32 rcp // f32 fmul diff --git a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp index 44e65b3..a83caa0 100644 --- a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp +++ b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp @@ -180,6 +180,7 @@ public: ImmTyMatrixBFMT, ImmTyMatrixAReuse, ImmTyMatrixBReuse, + ImmTyScaleSel, ImmTyByteSel, }; @@ -689,6 +690,8 @@ public: bool isVSrc_v2f16() const { return isVSrc_f16() || isLiteralImm(MVT::v2f16); } + bool isVSrc_NoInline_v2f16() const { return isVSrc_v2f16(); } + bool isVISrcB32() const { return isRegOrInlineNoMods(AMDGPU::VGPR_32RegClassID, MVT::i32); } @@ -1182,6 +1185,7 @@ public: case ImmTyMatrixBFMT: OS << "ImmTyMatrixBFMT"; break; case ImmTyMatrixAReuse: OS << "ImmTyMatrixAReuse"; break; case ImmTyMatrixBReuse: OS << "ImmTyMatrixBReuse"; break; + case ImmTyScaleSel: OS << "ScaleSel" ; break; case ImmTyByteSel: OS << "ByteSel" ; break; } // clang-format on @@ -2036,6 +2040,7 @@ static const fltSemantics *getOpFltSemantics(uint8_t OperandType) { case AMDGPU::OPERAND_REG_INLINE_C_FP16: case AMDGPU::OPERAND_REG_INLINE_C_V2FP16: case AMDGPU::OPERAND_REG_IMM_V2FP16: + case AMDGPU::OPERAND_REG_IMM_NOINLINE_V2FP16: case AMDGPU::OPERAND_KIMM16: return &APFloat::IEEEhalf(); case AMDGPU::OPERAND_REG_IMM_BF16: @@ -2405,6 +2410,7 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo case AMDGPU::OPERAND_REG_INLINE_C_V2FP16: case AMDGPU::OPERAND_REG_IMM_V2INT16: case AMDGPU::OPERAND_REG_IMM_V2FP16: + case AMDGPU::OPERAND_REG_IMM_NOINLINE_V2FP16: case AMDGPU::OPERAND_REG_IMM_V2FP32: case AMDGPU::OPERAND_REG_IMM_V2INT32: case AMDGPU::OPERAND_KIMM32: @@ -2456,6 +2462,9 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo setImmKindConst(); return; } + [[fallthrough]]; + + case AMDGPU::OPERAND_REG_IMM_NOINLINE_V2FP16: Inst.addOperand(MCOperand::createImm(Lo_32(Val))); setImmKindLiteral(); @@ -3761,6 +3770,9 @@ bool AMDGPUAsmParser::isInlineConstant(const MCInst &Inst, OperandType == AMDGPU::OPERAND_REG_INLINE_C_BF16) return AMDGPU::isInlinableLiteralBF16(Val, hasInv2PiInlineImm()); + if (OperandType == AMDGPU::OPERAND_REG_IMM_NOINLINE_V2FP16) + return false; + llvm_unreachable("invalid operand type"); } default: @@ -9356,6 +9368,14 @@ void AMDGPUAsmParser::cvtVOP3(MCInst &Inst, const OperandVector &Operands, } } + if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::scale_sel)) + addOptionalImmOperand(Inst, Operands, OptionalIdx, + AMDGPUOperand::ImmTyScaleSel); + + if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::clamp)) + addOptionalImmOperand(Inst, Operands, OptionalIdx, + AMDGPUOperand::ImmTyClamp); + if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::byte_sel)) { if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::vdst_in)) Inst.addOperand(Inst.getOperand(0)); @@ -9363,10 +9383,6 @@ void AMDGPUAsmParser::cvtVOP3(MCInst &Inst, const OperandVector &Operands, AMDGPUOperand::ImmTyByteSel); } - if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::clamp)) - addOptionalImmOperand(Inst, Operands, OptionalIdx, - AMDGPUOperand::ImmTyClamp); - if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::omod)) addOptionalImmOperand(Inst, Operands, OptionalIdx, AMDGPUOperand::ImmTyOModSI); @@ -9420,8 +9436,22 @@ void AMDGPUAsmParser::cvtVOP3P(MCInst &Inst, const OperandVector &Operands, Opc == 
AMDGPU::V_CVT_PK_FP8_F32_fake16_e64_dpp8_gfx12 || Opc == AMDGPU::V_CVT_SR_FP8_F32_gfx12_e64_dpp_gfx12 || Opc == AMDGPU::V_CVT_SR_FP8_F32_gfx12_e64_dpp8_gfx12 || + Opc == AMDGPU::V_CVT_SR_FP8_F32_gfx1250_e64_dpp_gfx1250 || + Opc == AMDGPU::V_CVT_SR_FP8_F32_gfx1250_e64_dpp8_gfx1250 || Opc == AMDGPU::V_CVT_SR_BF8_F32_gfx12_e64_dpp_gfx12 || - Opc == AMDGPU::V_CVT_SR_BF8_F32_gfx12_e64_dpp8_gfx12)) { + Opc == AMDGPU::V_CVT_SR_BF8_F32_gfx12_e64_dpp8_gfx12 || + Opc == AMDGPU::V_CVT_SR_FP8_F16_t16_e64_dpp_gfx1250 || + Opc == AMDGPU::V_CVT_SR_FP8_F16_fake16_e64_dpp_gfx1250 || + Opc == AMDGPU::V_CVT_SR_FP8_F16_t16_e64_dpp8_gfx1250 || + Opc == AMDGPU::V_CVT_SR_FP8_F16_fake16_e64_dpp8_gfx1250 || + Opc == AMDGPU::V_CVT_SR_FP8_F16_t16_e64_gfx1250 || + Opc == AMDGPU::V_CVT_SR_FP8_F16_fake16_e64_gfx1250 || + Opc == AMDGPU::V_CVT_SR_BF8_F16_t16_e64_dpp_gfx1250 || + Opc == AMDGPU::V_CVT_SR_BF8_F16_fake16_e64_dpp_gfx1250 || + Opc == AMDGPU::V_CVT_SR_BF8_F16_t16_e64_dpp8_gfx1250 || + Opc == AMDGPU::V_CVT_SR_BF8_F16_fake16_e64_dpp8_gfx1250 || + Opc == AMDGPU::V_CVT_SR_BF8_F16_t16_e64_gfx1250 || + Opc == AMDGPU::V_CVT_SR_BF8_F16_fake16_e64_gfx1250)) { Inst.addOperand(Inst.getOperand(0)); } @@ -10016,9 +10046,12 @@ void AMDGPUAsmParser::cvtVOP3DPP(MCInst &Inst, const OperandVector &Operands, addOptionalImmOperand(Inst, Operands, OptionalIdx, AMDGPUOperand::ImmTyClamp); - if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::byte_sel)) + if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::byte_sel)) { + if (VdstInIdx == static_cast<int>(Inst.getNumOperands())) + Inst.addOperand(Inst.getOperand(0)); addOptionalImmOperand(Inst, Operands, OptionalIdx, AMDGPUOperand::ImmTyByteSel); + } if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::omod)) addOptionalImmOperand(Inst, Operands, OptionalIdx, AMDGPUOperand::ImmTyOModSI); diff --git a/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp b/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp index 94886b0..96cb5ae 100644 --- a/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp +++ b/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp @@ -152,7 +152,12 @@ static bool isPermlane(const MachineInstr &MI) { Opcode == AMDGPU::V_PERMLANE16_SWAP_B32_e32 || Opcode == AMDGPU::V_PERMLANE16_SWAP_B32_e64 || Opcode == AMDGPU::V_PERMLANE32_SWAP_B32_e32 || - Opcode == AMDGPU::V_PERMLANE32_SWAP_B32_e64; + Opcode == AMDGPU::V_PERMLANE32_SWAP_B32_e64 || + Opcode == AMDGPU::V_PERMLANE_BCAST_B32_e64 || + Opcode == AMDGPU::V_PERMLANE_UP_B32_e64 || + Opcode == AMDGPU::V_PERMLANE_DOWN_B32_e64 || + Opcode == AMDGPU::V_PERMLANE_XOR_B32_e64 || + Opcode == AMDGPU::V_PERMLANE_IDX_GEN_B32_e64; } static bool isLdsDma(const MachineInstr &MI) { diff --git a/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp b/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp index ce1ce68..96d5668 100644 --- a/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp +++ b/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp @@ -592,10 +592,13 @@ bool GCNMaxILPSchedStrategy::tryCandidate(SchedCandidate &Cand, // This is a best effort to set things up for a post-RA pass. Optimizations // like generating loads of multiple registers should ideally be done within // the scheduler pass by combining the loads during DAG postprocessing. - const ClusterInfo *CandCluster = Cand.AtTop ? TopCluster : BotCluster; - const ClusterInfo *TryCandCluster = TryCand.AtTop ? TopCluster : BotCluster; - if (tryGreater(TryCandCluster && TryCandCluster->contains(TryCand.SU), - CandCluster && CandCluster->contains(Cand.SU), TryCand, Cand, + unsigned CandZoneCluster = Cand.AtTop ? 
TopClusterID : BotClusterID; + unsigned TryCandZoneCluster = TryCand.AtTop ? TopClusterID : BotClusterID; + bool CandIsClusterSucc = + isTheSameCluster(CandZoneCluster, Cand.SU->ParentClusterIdx); + bool TryCandIsClusterSucc = + isTheSameCluster(TryCandZoneCluster, TryCand.SU->ParentClusterIdx); + if (tryGreater(TryCandIsClusterSucc, CandIsClusterSucc, TryCand, Cand, Cluster)) return TryCand.Reason != NoCand; @@ -666,10 +669,13 @@ bool GCNMaxMemoryClauseSchedStrategy::tryCandidate(SchedCandidate &Cand, // MaxMemoryClause-specific: We prioritize clustered instructions as we would // get more benefit from clausing these memory instructions. - const ClusterInfo *CandCluster = Cand.AtTop ? TopCluster : BotCluster; - const ClusterInfo *TryCandCluster = TryCand.AtTop ? TopCluster : BotCluster; - if (tryGreater(TryCandCluster && TryCandCluster->contains(TryCand.SU), - CandCluster && CandCluster->contains(Cand.SU), TryCand, Cand, + unsigned CandZoneCluster = Cand.AtTop ? TopClusterID : BotClusterID; + unsigned TryCandZoneCluster = TryCand.AtTop ? TopClusterID : BotClusterID; + bool CandIsClusterSucc = + isTheSameCluster(CandZoneCluster, Cand.SU->ParentClusterIdx); + bool TryCandIsClusterSucc = + isTheSameCluster(TryCandZoneCluster, TryCand.SU->ParentClusterIdx); + if (tryGreater(TryCandIsClusterSucc, CandIsClusterSucc, TryCand, Cand, Cluster)) return TryCand.Reason != NoCand; @@ -896,15 +902,10 @@ GCNScheduleDAGMILive::getRegionLiveInMap() const { assert(!Regions.empty()); std::vector<MachineInstr *> RegionFirstMIs; RegionFirstMIs.reserve(Regions.size()); - auto I = Regions.rbegin(), E = Regions.rend(); - do { - const MachineBasicBlock *MBB = I->first->getParent(); - auto *MI = &*skipDebugInstructionsForward(I->first, I->second); - RegionFirstMIs.push_back(MI); - do { - ++I; - } while (I != E && I->first->getParent() == MBB); - } while (I != E); + for (auto &[RegionBegin, RegionEnd] : reverse(Regions)) + RegionFirstMIs.push_back( + &*skipDebugInstructionsForward(RegionBegin, RegionEnd)); + return getLiveRegMap(RegionFirstMIs, /*After=*/false, *LIS); } @@ -941,11 +942,9 @@ void GCNScheduleDAGMILive::finalizeSchedule() { Pressure.resize(Regions.size()); RegionsWithHighRP.resize(Regions.size()); RegionsWithExcessRP.resize(Regions.size()); - RegionsWithMinOcc.resize(Regions.size()); RegionsWithIGLPInstrs.resize(Regions.size()); RegionsWithHighRP.reset(); RegionsWithExcessRP.reset(); - RegionsWithMinOcc.reset(); RegionsWithIGLPInstrs.reset(); runSchedStages(); @@ -1095,8 +1094,7 @@ bool PreRARematStage::initGCNSchedStage() { // fixed if there is another pass after this pass. 
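// For reference, a standalone sketch of the iteration pattern adopted in
// getRegionLiveInMap above: walk the (begin, end) region pairs in reverse
// with structured bindings instead of a hand-rolled nested do/while. Here
// std::views::reverse (C++20) stands in for llvm::reverse.
#include <cstdio>
#include <ranges>
#include <utility>
#include <vector>

int main() {
  std::vector<std::pair<int, int>> Regions = {{0, 3}, {4, 7}, {8, 9}};
  for (auto &[Begin, End] : Regions | std::views::reverse)
    std::printf("[%d,%d) ", Begin, End);
  std::printf("\n"); // prints: [8,9) [4,7) [0,3)
}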
assert(!S.hasNextStage()); - if (!GCNSchedStage::initGCNSchedStage() || DAG.RegionsWithMinOcc.none() || - DAG.Regions.size() == 1) + if (!GCNSchedStage::initGCNSchedStage() || DAG.Regions.size() == 1) return false; // Before performing any IR modification record the parent region of each MI @@ -1138,11 +1136,6 @@ void UnclusteredHighRPStage::finalizeGCNSchedStage() { SavedMutations.swap(DAG.Mutations); S.SGPRLimitBias = S.VGPRLimitBias = 0; if (DAG.MinOccupancy > InitialOccupancy) { - for (unsigned IDX = 0; IDX < DAG.Pressure.size(); ++IDX) - DAG.RegionsWithMinOcc[IDX] = - DAG.Pressure[IDX].getOccupancy( - DAG.ST, DAG.MFI.getDynamicVGPRBlockSize()) == DAG.MinOccupancy; - LLVM_DEBUG(dbgs() << StageID << " stage successfully increased occupancy to " << DAG.MinOccupancy << '\n'); @@ -1214,11 +1207,15 @@ bool GCNSchedStage::initGCNRegion() { } bool UnclusteredHighRPStage::initGCNRegion() { - // Only reschedule regions with the minimum occupancy or regions that may have - // spilling (excess register pressure). - if ((!DAG.RegionsWithMinOcc[RegionIdx] || - DAG.MinOccupancy <= InitialOccupancy) && - !DAG.RegionsWithExcessRP[RegionIdx]) + // Only reschedule regions that have excess register pressure (i.e. spilling) + // or had minimum occupancy at the beginning of the stage (as long as + // rescheduling of previous regions did not make occupancy drop back down to + // the initial minimum). + unsigned DynamicVGPRBlockSize = DAG.MFI.getDynamicVGPRBlockSize(); + if (!DAG.RegionsWithExcessRP[RegionIdx] && + (DAG.MinOccupancy <= InitialOccupancy || + DAG.Pressure[RegionIdx].getOccupancy(ST, DynamicVGPRBlockSize) != + InitialOccupancy)) return false; return GCNSchedStage::initGCNRegion(); @@ -1283,9 +1280,6 @@ void GCNSchedStage::checkScheduling() { if (PressureAfter.getSGPRNum() <= S.SGPRCriticalLimit && PressureAfter.getVGPRNum(ST.hasGFX90AInsts()) <= S.VGPRCriticalLimit) { DAG.Pressure[RegionIdx] = PressureAfter; - DAG.RegionsWithMinOcc[RegionIdx] = - PressureAfter.getOccupancy(ST, DynamicVGPRBlockSize) == - DAG.MinOccupancy; // Early out if we have achieved the occupancy target. LLVM_DEBUG(dbgs() << "Pressure in desired limits, done.\n"); @@ -1319,7 +1313,6 @@ void GCNSchedStage::checkScheduling() { if (NewOccupancy < DAG.MinOccupancy) { DAG.MinOccupancy = NewOccupancy; MFI.limitOccupancy(DAG.MinOccupancy); - DAG.RegionsWithMinOcc.reset(); LLVM_DEBUG(dbgs() << "Occupancy lowered for the function to " << DAG.MinOccupancy << ".\n"); } @@ -1341,14 +1334,10 @@ void GCNSchedStage::checkScheduling() { // Revert if this region's schedule would cause a drop in occupancy or // spilling. 
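// Toy restatement (not the scheduler code) of the rescheduling gate in
// UnclusteredHighRPStage::initGCNRegion above: a region is rescheduled if it
// has excess pressure, or if the stage has raised the minimum occupancy and
// this region still sits at the initial occupancy.
#include <cstdio>

static bool shouldReschedule(bool ExcessRP, unsigned MinOcc,
                             unsigned InitialOcc, unsigned RegionOcc) {
  if (ExcessRP)
    return true;
  return MinOcc > InitialOcc && RegionOcc == InitialOcc;
}

int main() {
  std::printf("%d %d\n",
              shouldReschedule(false, 5, 4, 4),  // occupancy raised: 1
              shouldReschedule(false, 4, 4, 4)); // nothing gained: 0
}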
- if (shouldRevertScheduling(WavesAfter)) { + if (shouldRevertScheduling(WavesAfter)) revertScheduling(); - } else { + else DAG.Pressure[RegionIdx] = PressureAfter; - DAG.RegionsWithMinOcc[RegionIdx] = - PressureAfter.getOccupancy(ST, DynamicVGPRBlockSize) == - DAG.MinOccupancy; - } } unsigned @@ -1578,9 +1567,6 @@ bool GCNSchedStage::mayCauseSpilling(unsigned WavesAfter) { } void GCNSchedStage::revertScheduling() { - DAG.RegionsWithMinOcc[RegionIdx] = - PressureBefore.getOccupancy(ST, DAG.MFI.getDynamicVGPRBlockSize()) == - DAG.MinOccupancy; LLVM_DEBUG(dbgs() << "Attempting to revert scheduling.\n"); DAG.RegionEnd = DAG.RegionBegin; int SkippedDebugInstr = 0; diff --git a/llvm/lib/Target/AMDGPU/GCNSchedStrategy.h b/llvm/lib/Target/AMDGPU/GCNSchedStrategy.h index 94cd795..32139a9 100644 --- a/llvm/lib/Target/AMDGPU/GCNSchedStrategy.h +++ b/llvm/lib/Target/AMDGPU/GCNSchedStrategy.h @@ -250,9 +250,6 @@ class GCNScheduleDAGMILive final : public ScheduleDAGMILive { // limit. Register pressure in these regions usually will result in spilling. BitVector RegionsWithExcessRP; - // Regions that has the same occupancy as the latest MinOccupancy - BitVector RegionsWithMinOcc; - // Regions that have IGLP instructions (SCHED_GROUP_BARRIER or IGLP_OPT). BitVector RegionsWithIGLPInstrs; diff --git a/llvm/lib/Target/AMDGPU/GCNSubtarget.cpp b/llvm/lib/Target/AMDGPU/GCNSubtarget.cpp index 0a0a107..0237a60 100644 --- a/llvm/lib/Target/AMDGPU/GCNSubtarget.cpp +++ b/llvm/lib/Target/AMDGPU/GCNSubtarget.cpp @@ -340,6 +340,43 @@ void GCNSubtarget::overrideSchedPolicy(MachineSchedPolicy &Policy, Policy.ShouldTrackLaneMasks = true; } +void GCNSubtarget::overridePostRASchedPolicy(MachineSchedPolicy &Policy, + const SchedRegion &Region) const { + const Function &F = Region.RegionBegin->getMF()->getFunction(); + Attribute PostRADirectionAttr = F.getFnAttribute("amdgpu-post-ra-direction"); + if (!PostRADirectionAttr.isValid()) + return; + + StringRef PostRADirectionStr = PostRADirectionAttr.getValueAsString(); + if (PostRADirectionStr == "topdown") { + Policy.OnlyTopDown = true; + Policy.OnlyBottomUp = false; + } else if (PostRADirectionStr == "bottomup") { + Policy.OnlyTopDown = false; + Policy.OnlyBottomUp = true; + } else if (PostRADirectionStr == "bidirectional") { + Policy.OnlyTopDown = false; + Policy.OnlyBottomUp = false; + } else { + DiagnosticInfoOptimizationFailure Diag( + F, F.getSubprogram(), "invalid value for postRA direction attribute"); + F.getContext().diagnose(Diag); + } + + LLVM_DEBUG({ + const char *DirStr = "default"; + if (Policy.OnlyTopDown && !Policy.OnlyBottomUp) + DirStr = "topdown"; + else if (!Policy.OnlyTopDown && Policy.OnlyBottomUp) + DirStr = "bottomup"; + else if (!Policy.OnlyTopDown && !Policy.OnlyBottomUp) + DirStr = "bidirectional"; + + dbgs() << "Post-MI-sched direction (" << F.getName() << "): " << DirStr + << '\n'; + }); +} + void GCNSubtarget::mirFileLoaded(MachineFunction &MF) const { if (isWave32()) { // Fix implicit $vcc operands after MIParser has verified that they match diff --git a/llvm/lib/Target/AMDGPU/GCNSubtarget.h b/llvm/lib/Target/AMDGPU/GCNSubtarget.h index fba1c5a..c84ba1a 100644 --- a/llvm/lib/Target/AMDGPU/GCNSubtarget.h +++ b/llvm/lib/Target/AMDGPU/GCNSubtarget.h @@ -236,6 +236,7 @@ protected: bool Has64BitLiterals = false; bool HasBitOp3Insts = false; bool HasTanhInsts = false; + bool HasTensorCvtLutInsts = false; bool HasTransposeLoadF4F6Insts = false; bool HasPrngInst = false; bool HasBVHDualAndBVH8Insts = false; @@ -1041,6 +1042,9 @@ public: void 
overrideSchedPolicy(MachineSchedPolicy &Policy, const SchedRegion &Region) const override; + void overridePostRASchedPolicy(MachineSchedPolicy &Policy, + const SchedRegion &Region) const override; + void mirFileLoaded(MachineFunction &MF) const override; unsigned getMaxNumUserSGPRs() const { @@ -1408,6 +1412,8 @@ public: bool hasTanhInsts() const { return HasTanhInsts; } + bool hasTensorCvtLutInsts() const { return HasTensorCvtLutInsts; } + bool hasAddPC64Inst() const { return GFX1250Insts; } bool hasMinimum3Maximum3PKF16() const { @@ -1535,6 +1541,9 @@ public: // \returns true if the target has V_{MIN|MAX}_{I|U}64 instructions. bool hasIntMinMax64() const { return GFX1250Insts; } + // \returns true if the target has V_ADD_{MIN|MAX}_{I|U}32 instructions. + bool hasAddMinMaxInsts() const { return GFX1250Insts; } + // \returns true if the target has V_PK_ADD_{MIN|MAX}_{I|U}16 instructions. bool hasPkAddMinMaxInsts() const { return GFX1250Insts; } diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUAsmBackend.cpp b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUAsmBackend.cpp index 2a920f6..4e4660c 100644 --- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUAsmBackend.cpp +++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUAsmBackend.cpp @@ -33,8 +33,7 @@ public: AMDGPUAsmBackend(const Target &T) : MCAsmBackend(llvm::endianness::little) {} void applyFixup(const MCFragment &, const MCFixup &, const MCValue &Target, - MutableArrayRef<char> Data, uint64_t Value, - bool IsResolved) override; + uint8_t *Data, uint64_t Value, bool IsResolved) override; bool fixupNeedsRelaxation(const MCFixup &Fixup, uint64_t Value) const override; @@ -129,9 +128,8 @@ static uint64_t adjustFixupValue(const MCFixup &Fixup, uint64_t Value, } void AMDGPUAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, - const MCValue &Target, - MutableArrayRef<char> Data, uint64_t Value, - bool IsResolved) { + const MCValue &Target, uint8_t *Data, + uint64_t Value, bool IsResolved) { if (Target.getSpecifier()) IsResolved = false; maybeAddReloc(F, Fixup, Target, Value, IsResolved); @@ -148,13 +146,13 @@ void AMDGPUAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, Value <<= Info.TargetOffset; unsigned NumBytes = getFixupKindNumBytes(Fixup.getKind()); - uint32_t Offset = Fixup.getOffset(); - assert(Offset + NumBytes <= Data.size() && "Invalid fixup offset!"); + assert(Fixup.getOffset() + NumBytes <= F.getSize() && + "Invalid fixup offset!"); // For each byte of the fragment that the fixup touches, mask in the bits from // the fixup value. 
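// Standalone sketch of the little-endian case that follows: OR each byte of
// the already-shifted fixup value into the instruction bytes. Data is assumed
// to point at the fixup offset within the fragment, matching the new
// uint8_t *Data interface; the example opcode byte is made up.
#include <cstdint>
#include <cstdio>

static void applyFixupLE(uint8_t *Data, unsigned NumBytes, uint64_t Value) {
  for (unsigned i = 0; i != NumBytes; ++i)
    Data[i] |= uint8_t((Value >> (i * 8)) & 0xff);
}

int main() {
  uint8_t Inst[4] = {0x00, 0x00, 0x00, 0x14};
  applyFixupLE(Inst, 3, 0x000456); // low three bytes carry the value
  std::printf("%02x %02x %02x %02x\n", Inst[0], Inst[1], Inst[2], Inst[3]);
} // prints: 56 04 00 14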
for (unsigned i = 0; i != NumBytes; ++i) - Data[Offset + i] |= static_cast<uint8_t>((Value >> (i * 8)) & 0xff); + Data[i] |= static_cast<uint8_t>((Value >> (i * 8)) & 0xff); } std::optional<MCFixupKind> diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp index 11b072e..42c4d8b 100644 --- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp +++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp @@ -540,6 +540,8 @@ void AMDGPUInstPrinter::printImmediateV216(uint32_t Imm, uint8_t OpType, printImmediateBFloat16(static_cast<uint16_t>(Imm), STI, O)) return; break; + case AMDGPU::OPERAND_REG_IMM_NOINLINE_V2FP16: + break; default: llvm_unreachable("bad operand type"); } @@ -770,6 +772,7 @@ void AMDGPUInstPrinter::printRegularOperand(const MCInst *MI, unsigned OpNo, case AMDGPU::OPERAND_REG_IMM_V2INT16: case AMDGPU::OPERAND_REG_IMM_V2BF16: case AMDGPU::OPERAND_REG_IMM_V2FP16: + case AMDGPU::OPERAND_REG_IMM_NOINLINE_V2FP16: case AMDGPU::OPERAND_REG_INLINE_C_V2INT16: case AMDGPU::OPERAND_REG_INLINE_C_V2BF16: case AMDGPU::OPERAND_REG_INLINE_C_V2FP16: @@ -1790,4 +1793,14 @@ void AMDGPUInstPrinter::printBitOp3(const MCInst *MI, unsigned OpNo, O << formatHex(static_cast<uint64_t>(Imm)); } +void AMDGPUInstPrinter::printScaleSel(const MCInst *MI, unsigned OpNo, + const MCSubtargetInfo &STI, + raw_ostream &O) { + uint8_t Imm = MI->getOperand(OpNo).getImm(); + if (!Imm) + return; + + O << " scale_sel:" << formatDec(Imm); +} + #include "AMDGPUGenAsmWriter.inc" diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h index e0b7aa5..f6739b14 100644 --- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h +++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h @@ -173,6 +173,8 @@ private: const MCSubtargetInfo &STI, raw_ostream &O, StringRef Prefix, bool PrintInHex, bool AlwaysPrint); + void printScaleSel(const MCInst *MI, unsigned OpNo, + const MCSubtargetInfo &STI, raw_ostream &O); void printBitOp3(const MCInst *MI, unsigned OpNo, const MCSubtargetInfo &STI, raw_ostream &O); diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp index c49ad79..f358084 100644 --- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp +++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp @@ -341,6 +341,9 @@ std::optional<uint64_t> AMDGPUMCCodeEmitter::getLitEncoding( return AMDGPU::getInlineEncodingV2BF16(static_cast<uint32_t>(Imm)) .value_or(255); + case AMDGPU::OPERAND_REG_IMM_NOINLINE_V2FP16: + return 255; + case AMDGPU::OPERAND_KIMM32: case AMDGPU::OPERAND_KIMM16: case AMDGPU::OPERAND_KIMM64: diff --git a/llvm/lib/Target/AMDGPU/SIDefines.h b/llvm/lib/Target/AMDGPU/SIDefines.h index 40b8bcd..c564145 100644 --- a/llvm/lib/Target/AMDGPU/SIDefines.h +++ b/llvm/lib/Target/AMDGPU/SIDefines.h @@ -208,6 +208,7 @@ enum OperandType : unsigned { OPERAND_REG_IMM_V2BF16, OPERAND_REG_IMM_V2FP16, OPERAND_REG_IMM_V2INT16, + OPERAND_REG_IMM_NOINLINE_V2FP16, OPERAND_REG_IMM_V2INT32, OPERAND_REG_IMM_V2FP32, diff --git a/llvm/lib/Target/AMDGPU/SIFoldOperands.cpp b/llvm/lib/Target/AMDGPU/SIFoldOperands.cpp index b77da4d..e934152 100644 --- a/llvm/lib/Target/AMDGPU/SIFoldOperands.cpp +++ b/llvm/lib/Target/AMDGPU/SIFoldOperands.cpp @@ -468,6 +468,7 @@ bool SIFoldOperandsImpl::canUseImmWithOpSel(const MachineInstr *MI, case AMDGPU::OPERAND_REG_IMM_V2FP16: case 
AMDGPU::OPERAND_REG_IMM_V2BF16: case AMDGPU::OPERAND_REG_IMM_V2INT16: + case AMDGPU::OPERAND_REG_IMM_NOINLINE_V2FP16: case AMDGPU::OPERAND_REG_INLINE_C_V2FP16: case AMDGPU::OPERAND_REG_INLINE_C_V2BF16: case AMDGPU::OPERAND_REG_INLINE_C_V2INT16: diff --git a/llvm/lib/Target/AMDGPU/SIFrameLowering.cpp b/llvm/lib/Target/AMDGPU/SIFrameLowering.cpp index 11552b3..9b348d4 100644 --- a/llvm/lib/Target/AMDGPU/SIFrameLowering.cpp +++ b/llvm/lib/Target/AMDGPU/SIFrameLowering.cpp @@ -983,6 +983,7 @@ void SIFrameLowering::emitCSRSpillStores( const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>(); const SIInstrInfo *TII = ST.getInstrInfo(); const SIRegisterInfo &TRI = TII->getRegisterInfo(); + MachineRegisterInfo &MRI = MF.getRegInfo(); // Spill Whole-Wave Mode VGPRs. Save only the inactive lanes of the scratch // registers. However, save all lanes of callee-saved VGPRs. Due to this, we @@ -1005,6 +1006,12 @@ void SIFrameLowering::emitCSRSpillStores( } }; + for (const Register Reg : make_first_range(WWMScratchRegs)) { + if (!MRI.isReserved(Reg)) { + MRI.addLiveIn(Reg); + MBB.addLiveIn(Reg); + } + } StoreWWMRegisters(WWMScratchRegs); auto EnableAllLanes = [&]() { diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp index 579ca96..4d67e4a 100644 --- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp +++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp @@ -4292,7 +4292,7 @@ SDValue SITargetLowering::LowerDYNAMIC_STACKALLOC(SDValue Op, Chain = BaseAddr.getValue(1); Align StackAlign = TFL->getStackAlign(); if (Alignment > StackAlign) { - uint64_t ScaledAlignment = (uint64_t)Alignment.value() + uint64_t ScaledAlignment = Alignment.value() << Subtarget->getWavefrontSizeLog2(); uint64_t StackAlignMask = ScaledAlignment - 1; SDValue TmpAddr = DAG.getNode(ISD::ADD, dl, VT, BaseAddr, @@ -7199,7 +7199,7 @@ SDValue SITargetLowering::lowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const { SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FpToFp16); return DAG.getNode(ISD::BITCAST, DL, MVT::f16, Trunc); } - if (getTargetMachine().Options.UnsafeFPMath) { + if (Op->getFlags().hasApproximateFuncs()) { SDValue Flags = Op.getOperand(1); SDValue Src32 = DAG.getNode(ISD::FP_ROUND, DL, MVT::f32, Src, Flags); return DAG.getNode(ISD::FP_ROUND, DL, MVT::f16, Src32, Flags); @@ -11294,8 +11294,7 @@ SDValue SITargetLowering::lowerFastUnsafeFDIV(SDValue Op, EVT VT = Op.getValueType(); const SDNodeFlags Flags = Op->getFlags(); - bool AllowInaccurateRcp = - Flags.hasApproximateFuncs() || DAG.getTarget().Options.UnsafeFPMath; + bool AllowInaccurateRcp = Flags.hasApproximateFuncs(); if (const ConstantFPSDNode *CLHS = dyn_cast<ConstantFPSDNode>(LHS)) { // Without !fpmath accuracy information, we can't do more because we don't @@ -11314,7 +11313,7 @@ SDValue SITargetLowering::lowerFastUnsafeFDIV(SDValue Op, // 1.0 / sqrt(x) -> rsq(x) - // XXX - Is UnsafeFPMath sufficient to do this for f64? The maximum ULP + // XXX - Is afn sufficient to do this for f64? The maximum ULP // error seems really high at 2^29 ULP. 
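// Toy model of the afn-gated expansion this lowering performs: a / b becomes
// a * rcp(b), with rcp() standing in for the hardware reciprocal
// approximation (v_rcp_f32). Sketch only; accuracy caveats are the whole
// point of the gate.
#include <cstdio>

static float rcp(float X) { return 1.0f / X; }
static float fastFDiv(float A, float B) { return A * rcp(B); }

int main() { std::printf("%f\n", fastFDiv(6.0f, 3.0f)); } // ~2.000000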
// 1.0 / x -> rcp(x) return DAG.getNode(AMDGPUISD::RCP, SL, VT, RHS); @@ -11348,8 +11347,7 @@ SDValue SITargetLowering::lowerFastUnsafeFDIV64(SDValue Op, EVT VT = Op.getValueType(); const SDNodeFlags Flags = Op->getFlags(); - bool AllowInaccurateDiv = - Flags.hasApproximateFuncs() || DAG.getTarget().Options.UnsafeFPMath; + bool AllowInaccurateDiv = Flags.hasApproximateFuncs(); if (!AllowInaccurateDiv) return SDValue(); @@ -14601,7 +14599,7 @@ unsigned SITargetLowering::getFusedOpcode(const SelectionDAG &DAG, return ISD::FMAD; const TargetOptions &Options = DAG.getTarget().Options; - if ((Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath || + if ((Options.AllowFPOpFusion == FPOpFusion::Fast || (N0->getFlags().hasAllowContract() && N1->getFlags().hasAllowContract())) && isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT)) { @@ -15724,9 +15722,9 @@ SDValue SITargetLowering::performFMACombine(SDNode *N, // fdot2_f32_f16 always flushes fp32 denormal operand and output to zero, // regardless of the denorm mode setting. Therefore, - // unsafe-fp-math/fp-contract is sufficient to allow generating fdot2. + // fp-contract is sufficient to allow generating fdot2. const TargetOptions &Options = DAG.getTarget().Options; - if (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath || + if (Options.AllowFPOpFusion == FPOpFusion::Fast || (N->getFlags().hasAllowContract() && FMA->getFlags().hasAllowContract())) { Op1 = Op1.getOperand(0); @@ -16827,56 +16825,51 @@ SITargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI_, return std::pair(0U, RC); } - if (Constraint.starts_with("{") && Constraint.ends_with("}")) { - StringRef RegName(Constraint.data() + 1, Constraint.size() - 2); - if (RegName.consume_front("v")) { + auto [Kind, Idx, NumRegs] = AMDGPU::parseAsmConstraintPhysReg(Constraint); + if (Kind != '\0') { + if (Kind == 'v') { RC = &AMDGPU::VGPR_32RegClass; - } else if (RegName.consume_front("s")) { + } else if (Kind == 's') { RC = &AMDGPU::SGPR_32RegClass; - } else if (RegName.consume_front("a")) { + } else if (Kind == 'a') { RC = &AMDGPU::AGPR_32RegClass; } if (RC) { - uint32_t Idx; - if (RegName.consume_front("[")) { - uint32_t End; - bool Failed = RegName.consumeInteger(10, Idx); - Failed |= !RegName.consume_front(":"); - Failed |= RegName.consumeInteger(10, End); - Failed |= !RegName.consume_back("]"); - if (!Failed) { - uint32_t Width = (End - Idx + 1) * 32; - // Prohibit constraints for register ranges with a width that does not - // match the required type. - if (VT.SimpleTy != MVT::Other && Width != VT.getSizeInBits()) + if (NumRegs > 1) { + if (Idx >= RC->getNumRegs() || Idx + NumRegs - 1 > RC->getNumRegs()) + return std::pair(0U, nullptr); + + uint32_t Width = NumRegs * 32; + // Prohibit constraints for register ranges with a width that does not + // match the required type. + if (VT.SimpleTy != MVT::Other && Width != VT.getSizeInBits()) + return std::pair(0U, nullptr); + + MCRegister Reg = RC->getRegister(Idx); + if (SIRegisterInfo::isVGPRClass(RC)) + RC = TRI->getVGPRClassForBitWidth(Width); + else if (SIRegisterInfo::isSGPRClass(RC)) + RC = TRI->getSGPRClassForBitWidth(Width); + else if (SIRegisterInfo::isAGPRClass(RC)) + RC = TRI->getAGPRClassForBitWidth(Width); + if (RC) { + Reg = TRI->getMatchingSuperReg(Reg, AMDGPU::sub0, RC); + if (!Reg) { + // The register class does not contain the requested register, + // e.g., because it is an SGPR pair that would violate alignment + // requirements. 
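// For illustration, a standalone sketch of parsing a physical-register asm
// constraint such as "{v[0:3]}" or "{s5}" into (kind, first index, count).
// It mirrors the shape of the new parseAsmConstraintPhysReg helper, not its
// exact code, and does not fully validate malformed input.
#include <cstdio>
#include <string>

struct PhysReg { char Kind = 0; unsigned Idx = 0; unsigned NumRegs = 0; };

static bool parseConstraint(const std::string &S, PhysReg &Out) {
  if (S.size() < 4 || S.front() != '{' || S.back() != '}')
    return false;
  char Kind = S[1];
  if (Kind != 'v' && Kind != 's' && Kind != 'a')
    return false;
  std::string Body = S.substr(2, S.size() - 3);
  unsigned Idx, End;
  if (std::sscanf(Body.c_str(), "[%u:%u]", &Idx, &End) == 2 && End >= Idx)
    Out = {Kind, Idx, End - Idx + 1}; // register range, e.g. v[0:3]
  else if (std::sscanf(Body.c_str(), "%u", &Idx) == 1)
    Out = {Kind, Idx, 1};             // single register, e.g. s5
  else
    return false;
  return true;
}

int main() {
  PhysReg R;
  if (parseConstraint("{v[0:3]}", R))
    std::printf("%c %u x%u\n", R.Kind, R.Idx, R.NumRegs); // v 0 x4
}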
return std::pair(0U, nullptr); - MCRegister Reg = RC->getRegister(Idx); - if (SIRegisterInfo::isVGPRClass(RC)) - RC = TRI->getVGPRClassForBitWidth(Width); - else if (SIRegisterInfo::isSGPRClass(RC)) - RC = TRI->getSGPRClassForBitWidth(Width); - else if (SIRegisterInfo::isAGPRClass(RC)) - RC = TRI->getAGPRClassForBitWidth(Width); - if (RC) { - Reg = TRI->getMatchingSuperReg(Reg, AMDGPU::sub0, RC); - if (!Reg) { - // The register class does not contain the requested register, - // e.g., because it is an SGPR pair that would violate alignment - // requirements. - return std::pair(0U, nullptr); - } - return std::pair(Reg, RC); } + return std::pair(Reg, RC); } - } else { - // Check for lossy scalar/vector conversions. - if (VT.isVector() && VT.getSizeInBits() != 32) - return std::pair(0U, nullptr); - bool Failed = RegName.getAsInteger(10, Idx); - if (!Failed && Idx < RC->getNumRegs()) - return std::pair(RC->getRegister(Idx), RC); } + + // Check for lossy scalar/vector conversions. + if (VT.isVector() && VT.getSizeInBits() != 32) + return std::pair(0U, nullptr); + if (Idx < RC->getNumRegs()) + return std::pair(RC->getRegister(Idx), RC); } } diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp b/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp index c2da937..3f61bbd 100644 --- a/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp +++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp @@ -4438,6 +4438,8 @@ bool SIInstrInfo::isInlineConstant(int64_t Imm, uint8_t OperandType) const { case AMDGPU::OPERAND_REG_IMM_V2BF16: case AMDGPU::OPERAND_REG_INLINE_C_V2BF16: return AMDGPU::isInlinableLiteralV2BF16(Imm); + case AMDGPU::OPERAND_REG_IMM_NOINLINE_V2FP16: + return false; case AMDGPU::OPERAND_REG_IMM_FP16: case AMDGPU::OPERAND_REG_INLINE_C_FP16: { if (isInt<16>(Imm) || isUInt<16>(Imm)) { @@ -6302,10 +6304,14 @@ void SIInstrInfo::legalizeOperandsVOP3(MachineRegisterInfo &MRI, }; if (Opc == AMDGPU::V_PERMLANE16_B32_e64 || - Opc == AMDGPU::V_PERMLANEX16_B32_e64) { + Opc == AMDGPU::V_PERMLANEX16_B32_e64 || + Opc == AMDGPU::V_PERMLANE_BCAST_B32_e64 || + Opc == AMDGPU::V_PERMLANE_UP_B32_e64 || + Opc == AMDGPU::V_PERMLANE_DOWN_B32_e64 || + Opc == AMDGPU::V_PERMLANE_XOR_B32_e64 || + Opc == AMDGPU::V_PERMLANE_IDX_GEN_B32_e64) { // src1 and src2 must be scalar MachineOperand &Src1 = MI.getOperand(VOP3Idx[1]); - MachineOperand &Src2 = MI.getOperand(VOP3Idx[2]); const DebugLoc &DL = MI.getDebugLoc(); if (Src1.isReg() && !RI.isSGPRClass(MRI.getRegClass(Src1.getReg()))) { Register Reg = MRI.createVirtualRegister(&AMDGPU::SReg_32_XM0RegClass); @@ -6313,11 +6319,14 @@ void SIInstrInfo::legalizeOperandsVOP3(MachineRegisterInfo &MRI, .add(Src1); Src1.ChangeToRegister(Reg, false); } - if (Src2.isReg() && !RI.isSGPRClass(MRI.getRegClass(Src2.getReg()))) { - Register Reg = MRI.createVirtualRegister(&AMDGPU::SReg_32_XM0RegClass); - BuildMI(*MI.getParent(), MI, DL, get(AMDGPU::V_READFIRSTLANE_B32), Reg) - .add(Src2); - Src2.ChangeToRegister(Reg, false); + if (VOP3Idx[2] != -1) { + MachineOperand &Src2 = MI.getOperand(VOP3Idx[2]); + if (Src2.isReg() && !RI.isSGPRClass(MRI.getRegClass(Src2.getReg()))) { + Register Reg = MRI.createVirtualRegister(&AMDGPU::SReg_32_XM0RegClass); + BuildMI(*MI.getParent(), MI, DL, get(AMDGPU::V_READFIRSTLANE_B32), Reg) + .add(Src2); + Src2.ChangeToRegister(Reg, false); + } } } diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.td b/llvm/lib/Target/AMDGPU/SIInstrInfo.td index 83b0490..4698a58 100644 --- a/llvm/lib/Target/AMDGPU/SIInstrInfo.td +++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.td @@ -1313,6 +1313,10 @@ def MatrixBFMT : 
CustomOperand<i32, 1, "MatrixBFMT">; def MatrixAReuse : NamedBitOperand<"matrix_a_reuse">; def MatrixBReuse : NamedBitOperand<"matrix_b_reuse">; +def ScaleSel : NamedIntOperand<"scale_sel"> { + let Validator = "isUInt<3>"; +} + class KImmFPOperand<ValueType vt> : ImmOperand<vt> { let OperandNamespace = "AMDGPU"; let OperandType = "OPERAND_KIMM"#vt.Size; @@ -1770,6 +1774,7 @@ class getVALUDstForVT<ValueType VT, bit IsTrue16 = 0, bit IsVOP3Encoding = 0> { !eq(VT.Size, 256) : VOPDstOperand<VReg_256>, !eq(VT.Size, 192) : VOPDstOperand<VReg_192>, !eq(VT.Size, 128) : VOPDstOperand<VReg_128>, + !eq(VT.Size, 96) : VOPDstOperand<VReg_96>, !eq(VT.Size, 64) : VOPDstOperand<VReg_64>, !eq(VT.Size, 32) : VOPDstOperand<VGPR_32>, !eq(VT.Size, 16) : op16, @@ -1920,6 +1925,7 @@ class getVOP3DPPSrcForVT<ValueType VT, bit IsFake16 = 1> { !eq(VT, v2f16) : VCSrc_v2f16, !eq(VT, v2bf16) : VCSrc_v2bf16, !eq(VT, f32) : VCSrc_f32, + !eq(VT, v2i32) : VCSrc_v2b32, 1 : VCSrc_b32); } @@ -2859,6 +2865,7 @@ def VOP_I16_F16 : VOPProfile <[i16, f16, untyped, untyped]>; def VOP_I16_I16 : VOPProfile <[i16, i16, untyped, untyped]>; def VOP_BF16_BF16 : VOPProfile<[bf16, bf16, untyped, untyped]>; def VOP1_I16_I32 : VOPProfile<[i16, i32, untyped, untyped]>; +def VOP_I16_V2F16 : VOPProfile<[i16, v2f16, untyped, untyped]>; def VOP_F16_F16_F16 : VOPProfile <[f16, f16, f16, untyped]>; def VOP_F16_F16_I16 : VOPProfile <[f16, f16, i16, untyped]>; @@ -2926,8 +2933,13 @@ def VOP_V2BF16_F32_F32 : VOPProfile <[v2bf16, f32, f32, untyped]>; def VOP_V32F32_V6I32_F32 : VOPProfile <[v32f32, v6i32, f32, untyped]>; def VOP_V32F16_V6I32_F32 : VOPProfile <[v32f16, v6i32, f32, untyped]>; def VOP_V32BF16_V6I32_F32 : VOPProfile <[v32bf16, v6i32, f32, untyped]>; +def VOP_V2BF16_F32_F32_I32 : VOPProfile <[v2bf16, f32, f32, i32]>; +def VOP_V2F16_F32_F32_I32 : VOPProfile <[v2f16, f32, f32, i32]>; def VOP_V6I32_V32F16_F32 : VOPProfile<[v6i32, v32f16, f32, untyped]>; def VOP_V6I32_V32BF16_F32 : VOPProfile<[v6i32, v32bf16, f32, untyped]>; +def VOP_V3I32_V16F16_F32 : VOPProfile<[v3i32, v16f16, f32, untyped]>; +def VOP_V3I32_V16BF16_F32 : VOPProfile<[v3i32, v16bf16, f32, untyped]>; +def VOP_V3I32_V16F32_F32 : VOPProfile<[v3i32, v16f32, f32, untyped]>; def VOP_V6I32_V16F32_V16F32_F32 : VOPProfile<[v6i32, v16f32, v16f32, f32]>; def VOP_V2F16_I32_F32 : VOPProfile<[v2f16, i32, f32, untyped]>; def VOP_V2I16_F32_F32_F32 : VOPProfile<[v2i16, f32, f32, f32]>; @@ -2941,11 +2953,35 @@ def VOP_BF16_F32_I32 : VOPProfile<[bf16, f32, i32, untyped]>; def VOP_F16_F32_I32 : VOPProfile<[f16, f32, i32, untyped]>; def VOP_I32_BF16_I32_F32 : VOPProfile<[i32, bf16, i32, f32]>; def VOP_I32_F16_I32_F32 : VOPProfile<[i32, f16, i32, f32]>; +def VOP_V16F16_V3I32_I32 : VOPProfile<[v16f16, v3i32, i32, untyped]>; +def VOP_V16BF16_V3I32_I32 : VOPProfile<[v16bf16, v3i32, i32, untyped]>; +def VOP_V8F16_V2I32_I32 : VOPProfile<[v8f16, v2i32, i32, untyped]>; +def VOP_V8BF16_V2I32_I32 : VOPProfile<[v8bf16, v2i32, i32, untyped]>; +def VOP_V8F16_I32_I32 : VOPProfile<[v8f16, i32, i32, untyped]>; +def VOP_V8BF16_I32_I32 : VOPProfile<[v8bf16, i32, i32, untyped]>; +def VOP_V16F32_V3I32_I32 : VOPProfile<[v16f32, v3i32, i32, untyped]>; +def VOP_V8F32_V2I32_I32 : VOPProfile<[v8f32, v2i32, i32, untyped]>; +def VOP_V8F32_I32_I32 : VOPProfile<[v8f32, i32, i32, untyped]>; +def VOP_V2I32_V8BF16_F32 : VOPProfile<[v2i32, v8bf16, f32, untyped]>; +def VOP_V2I32_V8F16_F32 : VOPProfile<[v2i32, v8f16, f32, untyped]>; +def VOP_V2I32_V8F32_F32 : VOPProfile<[v2i32, v8f32, f32, untyped]>; +def VOP_I32_V8F32_F32 : 
VOPProfile<[i32, v8f32, f32, untyped]>; +def VOP_I32_V8F16_F32 : VOPProfile<[i32, v8f16, f32, untyped]>; +def VOP_I32_V8BF16_F32 : VOPProfile<[i32, v8bf16, f32, untyped]>; def VOP_I32_F32_I32_F32 : VOPProfile<[i32, f32, i32, f32]>; def VOP_V6I32_V32BF16_I32_F32 : VOPProfile<[v6i32, v32bf16, i32, f32]>; def VOP_V6I32_V32F16_I32_F32 : VOPProfile<[v6i32, v32f16, i32, f32]>; def VOP_V6I32_V32F32_I32_F32 : VOPProfile<[v6i32, v32f32, i32, f32]>; +def VOP_V3I32_V16F16_I32_F32 : VOPProfile<[v3i32, v16f16, i32, f32]>; +def VOP_V3I32_V16BF16_I32_F32 : VOPProfile<[v3i32, v16bf16, i32, f32]>; +def VOP_V3I32_V16F32_I32_F32 : VOPProfile<[v3i32, v16f32, i32, f32]>; +def VOP_V2I32_V8BF16_I32_F32 : VOPProfile<[v2i32, v8bf16, i32, f32]>; +def VOP_V2I32_V8F16_I32_F32 : VOPProfile<[v2i32, v8f16, i32, f32]>; +def VOP_V2I32_V8F32_I32_F32 : VOPProfile<[v2i32, v8f32, i32, f32]>; +def VOP_I32_V8F32_I32_F32 : VOPProfile<[i32, v8f32, i32, f32]>; +def VOP_I32_V8F16_I32_F32 : VOPProfile<[i32, v8f16, i32, f32]>; +def VOP_I32_V8BF16_I32_F32 : VOPProfile<[i32, v8bf16, i32, f32]>; def VOP_I64_I64_I32 : VOPProfile <[i64, i64, i32, untyped]>; def VOP_I64_I32_I64 : VOPProfile <[i64, i32, i64, untyped]>; diff --git a/llvm/lib/Target/AMDGPU/SIRegisterInfo.td b/llvm/lib/Target/AMDGPU/SIRegisterInfo.td index 218841d..08d07c9 100644 --- a/llvm/lib/Target/AMDGPU/SIRegisterInfo.td +++ b/llvm/lib/Target/AMDGPU/SIRegisterInfo.td @@ -1218,6 +1218,8 @@ def VSrc_f64 : SrcRegOrImm9 <VS_64, "OPERAND_REG_IMM_FP64"> { def VSrc_v2b32 : SrcRegOrImm9 <VS_64, "OPERAND_REG_IMM_V2INT32">; def VSrc_v2f32 : SrcRegOrImm9 <VS_64, "OPERAND_REG_IMM_V2FP32">; +def VSrc_NoInline_v2f16 : SrcRegOrImm9 <VS_32, "OPERAND_REG_IMM_NOINLINE_V2FP16">; + //===----------------------------------------------------------------------===// // VRegSrc_* Operands with a VGPR //===----------------------------------------------------------------------===// @@ -1300,6 +1302,7 @@ def VCSrc_f64 : SrcRegOrImm9 <VS_64, "OPERAND_REG_INLINE_C_FP64">; def VCSrc_v2b16 : SrcRegOrImm9 <VS_32, "OPERAND_REG_INLINE_C_V2INT16">; def VCSrc_v2bf16: SrcRegOrImm9 <VS_32, "OPERAND_REG_INLINE_C_V2BF16">; def VCSrc_v2f16 : SrcRegOrImm9 <VS_32, "OPERAND_REG_INLINE_C_V2FP16">; +def VCSrc_v2b32 : SrcRegOrImm9 <VS_64, "OPERAND_REG_INLINE_C_V2INT32">; // True 16 Operands def VCSrcT_b16 : SrcRegOrImm9_t16 <"OPERAND_REG_INLINE_C_INT16">; diff --git a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp index 83e63ac..65fa088 100644 --- a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp +++ b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp @@ -1548,6 +1548,42 @@ bool shouldEmitConstantsToTextSection(const Triple &TT) { return TT.getArch() == Triple::r600; } +static bool isValidRegPrefix(char C) { + return C == 'v' || C == 's' || C == 'a'; +} + +std::tuple<char, unsigned, unsigned> +parseAsmConstraintPhysReg(StringRef Constraint) { + StringRef RegName = Constraint; + if (!RegName.consume_front("{") || !RegName.consume_back("}")) + return {}; + + char Kind = RegName.front(); + if (!isValidRegPrefix(Kind)) + return {}; + + RegName = RegName.drop_front(); + if (RegName.consume_front("[")) { + unsigned Idx, End; + bool Failed = RegName.consumeInteger(10, Idx); + Failed |= !RegName.consume_front(":"); + Failed |= RegName.consumeInteger(10, End); + Failed |= !RegName.consume_back("]"); + if (!Failed) { + unsigned NumRegs = End - Idx + 1; + if (NumRegs > 1) + return {Kind, Idx, NumRegs}; + } + } else { + unsigned Idx; + bool Failed = RegName.getAsInteger(10, Idx); + 
if (!Failed) + return {Kind, Idx, 1}; + } + + return {}; +} + std::pair<unsigned, unsigned> getIntegerPairAttribute(const Function &F, StringRef Name, std::pair<unsigned, unsigned> Default, @@ -2659,6 +2695,7 @@ bool isSISrcFPOperand(const MCInstrDesc &Desc, unsigned OpNo) { case AMDGPU::OPERAND_REG_IMM_FP64: case AMDGPU::OPERAND_REG_IMM_FP16: case AMDGPU::OPERAND_REG_IMM_V2FP16: + case AMDGPU::OPERAND_REG_IMM_NOINLINE_V2FP16: case AMDGPU::OPERAND_REG_INLINE_C_FP32: case AMDGPU::OPERAND_REG_INLINE_C_FP64: case AMDGPU::OPERAND_REG_INLINE_C_FP16: @@ -3023,6 +3060,8 @@ bool isInlinableLiteralV216(uint32_t Literal, uint8_t OpType) { case AMDGPU::OPERAND_REG_IMM_V2BF16: case AMDGPU::OPERAND_REG_INLINE_C_V2BF16: return isInlinableLiteralV2BF16(Literal); + case AMDGPU::OPERAND_REG_IMM_NOINLINE_V2FP16: + return false; default: llvm_unreachable("bad packed operand type"); } diff --git a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h index c09a9d6..1252e35 100644 --- a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h +++ b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h @@ -1012,6 +1012,12 @@ bool isReadOnlySegment(const GlobalValue *GV); /// target triple \p TT, false otherwise. bool shouldEmitConstantsToTextSection(const Triple &TT); +/// Returns the register kind character ('v', 's', or 'a') in the first entry +/// if this is a valid physical register constraint, or 0 otherwise, followed by +/// the starting register number and the number of registers. Does not validate +/// that the registers exist in the register class. +std::tuple<char, unsigned, unsigned> +parseAsmConstraintPhysReg(StringRef Constraint); + /// \returns Integer value requested using \p F's \p Name attribute. /// /// \returns \p Default if attribute is not present. @@ -1636,6 +1642,7 @@ inline unsigned getOperandSize(const MCOperandInfo &OpInfo) { case AMDGPU::OPERAND_REG_IMM_V2INT16: case AMDGPU::OPERAND_REG_IMM_V2BF16: case AMDGPU::OPERAND_REG_IMM_V2FP16: + case AMDGPU::OPERAND_REG_IMM_NOINLINE_V2FP16: return 2; default: diff --git a/llvm/lib/Target/AMDGPU/VOP2Instructions.td b/llvm/lib/Target/AMDGPU/VOP2Instructions.td index 550ec9d..9de7d6d 100644 --- a/llvm/lib/Target/AMDGPU/VOP2Instructions.td +++ b/llvm/lib/Target/AMDGPU/VOP2Instructions.td @@ -1344,6 +1344,8 @@ def V_FMAAK_F64 : VOP2_Pseudo<"v_fmaak_f64", VOP_MADAK_F64, [], "">; } // End SubtargetPredicate = HasFmaakFmamkF64Insts, isReMaterializable = 1, FixedSize = 1, Size = 12, SchedRW = [Write64Bit] let SubtargetPredicate = HasPkFmacF16Inst in { +// FIXME: V_PK_FMAC_F16 is currently not used in instruction selection. +// If this changes, ensure the DPP variant is not used for GFX11+.
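Stepping back to the AMDGPUBaseInfo.cpp hunk above: the new AMDGPU::parseAsmConstraintPhysReg helper centralizes the "{v5}" / "{s[4:7]}" / "{a0}" grammar that SITargetLowering previously parsed inline. Below is a minimal standalone sketch of the same grammar; it uses std::string_view in place of llvm::StringRef and an illustrative function name, so only the accepted syntax and the {kind, first index, count} result shape follow the patch.

```cpp
#include <charconv>
#include <string_view>
#include <tuple>

// Accepts "{v5}", "{s[4:7]}", "{a0}", ... and returns {kind, first, count},
// or {0, 0, 0} for malformed input. Like the real helper, a bracketed range
// must cover more than one register, and no register-class validation is done.
std::tuple<char, unsigned, unsigned> parsePhysRegConstraint(std::string_view C) {
  if (C.size() < 4 || C.front() != '{' || C.back() != '}')
    return {};
  C.remove_prefix(1);
  C.remove_suffix(1);
  char Kind = C.front();
  if (Kind != 'v' && Kind != 's' && Kind != 'a')
    return {};
  C.remove_prefix(1);
  auto TakeUInt = [](std::string_view &S, unsigned &Out) {
    auto [Ptr, Ec] = std::from_chars(S.data(), S.data() + S.size(), Out);
    if (Ec != std::errc())
      return false;
    S.remove_prefix(static_cast<size_t>(Ptr - S.data()));
    return true;
  };
  unsigned First = 0, Last = 0;
  if (!C.empty() && C.front() == '[') {        // range form "[first:last]"
    C.remove_prefix(1);
    if (!TakeUInt(C, First) || C.empty() || C.front() != ':')
      return {};
    C.remove_prefix(1);
    if (!TakeUInt(C, Last) || C != "]" || Last <= First)
      return {};
    return {Kind, First, Last - First + 1};
  }
  if (!TakeUInt(C, First) || !C.empty())       // single register form
    return {};
  return {Kind, First, 1};
}
```

In the SITargetLowering caller shown earlier, the kind character then selects VGPR_32, SGPR_32, or AGPR_32 as the base class, and a multi-register result is additionally checked so that count * 32 matches the bit width of the requested operand type.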
defm V_PK_FMAC_F16 : VOP2Inst<"v_pk_fmac_f16", VOP_V2F16_V2F16_V2F16>; } // End SubtargetPredicate = HasPkFmacF16Inst @@ -1904,7 +1906,7 @@ multiclass VOP2_Real_FULL_with_name_gfx11_gfx12<bits<6> op, string opName, VOP2_Real_FULL_with_name<GFX12Gen, op, opName, asmName>; multiclass VOP2_Real_e32_gfx11_gfx12<bits<6> op> : - VOP2Only_Real<GFX11Gen, op>, VOP2Only_Real<GFX12Gen, op>; + VOP2Only_Real_e32<GFX11Gen, op>, VOP2Only_Real_e32<GFX12Gen, op>; multiclass VOP3Only_Realtriple_gfx11_gfx12<bits<10> op> : VOP3Only_Realtriple<GFX11Gen, op>, VOP3Only_Realtriple<GFX12Gen, op>; diff --git a/llvm/lib/Target/AMDGPU/VOP3Instructions.td b/llvm/lib/Target/AMDGPU/VOP3Instructions.td index 5586dd8..f4b6af6 100644 --- a/llvm/lib/Target/AMDGPU/VOP3Instructions.td +++ b/llvm/lib/Target/AMDGPU/VOP3Instructions.td @@ -625,8 +625,9 @@ def shl_0_to_4 : PatFrag< }]; } -def VOP3_CVT_PK_F8_F32_Profile : VOP3_Profile<VOP_I32_F32_F32, VOP3_OPSEL> { - defvar Tail = (ins VGPR_32:$vdst_in, op_sel0:$op_sel); +class VOP3_CVT_PK_F8_F32_Profile<bit _HasClamp = 0> : VOP3_Profile<VOP_I32_F32_F32, VOP3_OPSEL> { + defvar Tail = !con(!if(_HasClamp, (ins Clamp:$clamp), (ins)), + (ins VGPR_32:$vdst_in, op_sel0:$op_sel)); let InsVOP3OpSel = !con(getIns64<Src0RC64, Src1RC64, Src2RC64, NumSrcArgs, 0, HasModifiers, HasSrc2Mods, HasOMod, Src0Mod, Src1Mod, Src2Mod>.ret, @@ -636,12 +637,13 @@ def VOP3_CVT_PK_F8_F32_Profile : VOP3_Profile<VOP_I32_F32_F32, VOP3_OPSEL> { HasSrc2Mods, HasOMod, Src0ModVOP3DPP, Src1ModVOP3DPP, Src2ModVOP3DPP, false>.ret, Tail); - let HasClamp = 0; + let HasClamp = _HasClamp; let HasExtVOP3DPP = 1; } -def VOP3_CVT_PK_F8_F32_Profile_fake16 : VOP3_Profile_Fake16<VOP_I16_F32_F32, VOP3_OPSEL> { - defvar Tail = (ins VGPR_32:$vdst_in, op_sel0:$op_sel); +class VOP3_CVT_PK_F8_F32_Profile_fake16<bit _HasClamp = 0> : VOP3_Profile_Fake16<VOP_I16_F32_F32, VOP3_OPSEL> { + defvar Tail = !con(!if(_HasClamp, (ins Clamp:$clamp), (ins)), + (ins VGPR_32:$vdst_in, op_sel0:$op_sel)); let InsVOP3OpSel = !con(getIns64<Src0RC64, Src1RC64, Src2RC64, NumSrcArgs, 0, HasModifiers, HasSrc2Mods, HasOMod, Src0Mod, Src1Mod, Src2Mod>.ret, @@ -651,14 +653,15 @@ def VOP3_CVT_PK_F8_F32_Profile_fake16 : VOP3_Profile_Fake16<VOP_I16_F32_F32, VOP HasSrc2Mods, HasOMod, Src0ModVOP3DPP, Src1ModVOP3DPP, Src2ModVOP3DPP, false>.ret, Tail); - let HasClamp = 0; + let HasClamp = _HasClamp; let HasExtVOP3DPP = 1; } // This t16 profile with vdst_in operand is for backward compatibility and is used // for user controlled packing -def VOP3_CVT_PK_F8_F32_Profile_t16 : VOP3_Profile_True16<VOP_I16_F32_F32, VOP3_OPSEL> { - defvar Tail = (ins VGPR_16:$vdst_in, op_sel0:$op_sel); +class VOP3_CVT_PK_F8_F32_Profile_t16<bit _HasClamp = 0> : VOP3_Profile_True16<VOP_I16_F32_F32, VOP3_OPSEL> { + defvar Tail = !con(!if(_HasClamp, (ins Clamp:$clamp), (ins)), + (ins VGPR_16:$vdst_in, op_sel0:$op_sel)); let InsVOP3OpSel = !con(getIns64<Src0RC64, Src1RC64, Src2RC64, NumSrcArgs, 0, HasModifiers, HasSrc2Mods, HasOMod, Src0Mod, Src1Mod, Src2Mod>.ret, @@ -668,7 +671,7 @@ def VOP3_CVT_PK_F8_F32_Profile_t16 : VOP3_Profile_True16<VOP_I16_F32_F32, VOP3_O HasSrc2Mods, HasOMod, Src0ModVOP3DPP, Src1ModVOP3DPP, Src2ModVOP3DPP, false>.ret, Tail); - let HasClamp = 0; + let HasClamp = _HasClamp; let HasExtVOP3DPP = 1; } @@ -702,10 +705,10 @@ def VOP3_CVT_SR_F8_F32_Profile : VOP3_Profile<VOPProfile<[i32, f32, i32, f32]>, HasModifiers, DstVT>.ret); } -class VOP3_CVT_SR_F8_ByteSel_Profile<ValueType SrcVT> : +class VOP3_CVT_SR_F8_ByteSel_Profile<ValueType SrcVT, bit _HasClamp = 0> : 
VOP3_Profile<VOPProfile<[i32, SrcVT, i32, untyped]>> { let HasFP8DstByteSel = 1; - let HasClamp = 0; + let HasClamp = _HasClamp; } def IsPow2Plus1: PatLeaf<(i32 imm), [{ @@ -746,6 +749,13 @@ let SubtargetPredicate = HasMinimum3Maximum3F16, ReadsModeReg = 0 in { defm V_MAXIMUM3_F16 : VOP3Inst_t16 <"v_maximum3_f16", VOP_F16_F16_F16_F16, AMDGPUfmaximum3>; } // End SubtargetPredicate = isGFX12Plus, ReadsModeReg = 0 +let SubtargetPredicate = HasAddMinMaxInsts, isCommutable = 1, isReMaterializable = 1 in { + defm V_ADD_MAX_I32 : VOP3Inst <"v_add_max_i32", VOP_I32_I32_I32_I32>; + defm V_ADD_MAX_U32 : VOP3Inst <"v_add_max_u32", VOP_I32_I32_I32_I32>; + defm V_ADD_MIN_I32 : VOP3Inst <"v_add_min_i32", VOP_I32_I32_I32_I32>; + defm V_ADD_MIN_U32 : VOP3Inst <"v_add_min_u32", VOP_I32_I32_I32_I32>; +} + defm V_ADD_I16 : VOP3Inst_t16 <"v_add_i16", VOP_I16_I16_I16>; defm V_SUB_I16 : VOP3Inst_t16 <"v_sub_i16", VOP_I16_I16_I16>; @@ -773,15 +783,23 @@ defm V_LSHL_ADD_U64 : VOP3Inst <"v_lshl_add_u64", V_LSHL_ADD_U64_PROF>; let OtherPredicates = [HasFP8ConversionInsts], mayRaiseFPException = 0, SchedRW = [WriteFloatCvt] in { let Constraints = "$vdst = $vdst_in", DisableEncoding = "$vdst_in" in { - defm V_CVT_PK_FP8_F32 : VOP3Inst_t16_with_profiles<"v_cvt_pk_fp8_f32", VOP3_CVT_PK_F8_F32_Profile, - VOP3_CVT_PK_F8_F32_Profile_t16, - VOP3_CVT_PK_F8_F32_Profile_fake16>; - defm V_CVT_PK_BF8_F32 : VOP3Inst_t16_with_profiles<"v_cvt_pk_bf8_f32", VOP3_CVT_PK_F8_F32_Profile, - VOP3_CVT_PK_F8_F32_Profile_t16, - VOP3_CVT_PK_F8_F32_Profile_fake16>; + let OtherPredicates = [HasFP8ConversionInsts, NotHasFP8E5M3Insts] in + defm V_CVT_PK_FP8_F32 : VOP3Inst_t16_with_profiles<"v_cvt_pk_fp8_f32", VOP3_CVT_PK_F8_F32_Profile<>, + VOP3_CVT_PK_F8_F32_Profile_t16<>, + VOP3_CVT_PK_F8_F32_Profile_fake16<>>; + let OtherPredicates = [HasFP8ConversionInsts, HasFP8E5M3Insts] in + defm V_CVT_PK_FP8_F32_gfx1250 : VOP3Inst_t16_with_profiles<"v_cvt_pk_fp8_f32_gfx1250", VOP3_CVT_PK_F8_F32_Profile<true>, + VOP3_CVT_PK_F8_F32_Profile_t16<true>, + VOP3_CVT_PK_F8_F32_Profile_fake16<true>>; + defm V_CVT_PK_BF8_F32 : VOP3Inst_t16_with_profiles<"v_cvt_pk_bf8_f32", VOP3_CVT_PK_F8_F32_Profile<>, + VOP3_CVT_PK_F8_F32_Profile_t16<>, + VOP3_CVT_PK_F8_F32_Profile_fake16<>>; let SubtargetPredicate = isGFX12Plus in { - defm V_CVT_SR_FP8_F32_gfx12 : VOP3Inst<"v_cvt_sr_fp8_f32_gfx12", VOP3_CVT_SR_F8_ByteSel_Profile<f32>>; + let OtherPredicates = [HasFP8ConversionInsts, NotHasFP8E5M3Insts] in + defm V_CVT_SR_FP8_F32_gfx12 : VOP3Inst<"v_cvt_sr_fp8_f32_gfx12", VOP3_CVT_SR_F8_ByteSel_Profile<f32>>; + let OtherPredicates = [HasFP8ConversionInsts, HasFP8E5M3Insts] in + defm V_CVT_SR_FP8_F32_gfx1250 : VOP3Inst<"v_cvt_sr_fp8_f32_gfx1250", VOP3_CVT_SR_F8_ByteSel_Profile<f32, true>>; defm V_CVT_SR_BF8_F32_gfx12 : VOP3Inst<"v_cvt_sr_bf8_f32_gfx12", VOP3_CVT_SR_F8_ByteSel_Profile<f32>>; } } @@ -800,6 +818,11 @@ class Cvt_PK_F8_F32_Pat<SDPatternOperator node, int index, VOP3_Pseudo inst> : G (inst !if(index, SRCMODS.DST_OP_SEL, 0), $src0, 0, $src1, $old, 0) >; +class Cvt_PK_F8_F32_E5M3_Pat<SDPatternOperator node, int index, VOP3_Pseudo inst, int Clamp> : GCNPat< + (i32 (node f32:$src0, f32:$src1, i32:$old, index)), + (inst !if(index, SRCMODS.DST_OP_SEL, 0), $src0, 0, $src1, Clamp, $old, 0) +>; + multiclass Cvt_PK_F8_F32_t16_Pat<SDPatternOperator node, VOP3_Pseudo inst> { def : GCNPat< (i32 (node f32:$src0, f32:$src1, i32:$old, -1)), @@ -815,6 +838,21 @@ def : GCNPat< >; } +multiclass Cvt_PK_F8_F32_E5M3_t16_Pat<SDPatternOperator node, VOP3_Pseudo inst, int Clamp> { +def : 
GCNPat< + (i32 (node f32:$src0, f32:$src1, i32:$old, -1)), + (REG_SEQUENCE VGPR_32, + (i16 (EXTRACT_SUBREG $old, lo16)), lo16, + (i16 (inst SRCMODS.DST_OP_SEL, $src0, 0, $src1, Clamp, (i16 (EXTRACT_SUBREG $old, hi16)), 0)), hi16) +>; +def : GCNPat< + (i32 (node f32:$src0, f32:$src1, i32:$old, 0)), + (REG_SEQUENCE VGPR_32, + (i16 (inst 0, $src0, 0, $src1, Clamp, (i16 (EXTRACT_SUBREG $old, lo16)), 0)), lo16, + (i16 (EXTRACT_SUBREG $old, hi16)), hi16) +>; +} + class Cvt_SR_F8_F32_Pat<SDPatternOperator node, bits<2> index, VOP3_Pseudo inst> : GCNPat< (i32 (node f32:$src0, i32:$src1, i32:$old, index)), (inst !if(index{1}, SRCMODS.DST_OP_SEL, 0), $src0, 0, $src1, @@ -827,21 +865,37 @@ class Cvt_SR_F8_ByteSel_Pat<SDPatternOperator node, VOP3_Pseudo inst, ValueType (inst $src0_modifiers, $src0, $src1_modifiers, $src1, $old, (as_i32timm $byte_sel)) >; +class Cvt_SR_F8_ByteSel_E5M3_Pat<SDPatternOperator node, VOP3_Pseudo inst, + ValueType SrcVT, int Clamp> : GCNPat< + (i32 (node (VOP3Mods SrcVT:$src0, i32:$src0_modifiers), (VOP3Mods i32:$src1, i32:$src1_modifiers), + i32:$old, timm:$byte_sel)), + (inst $src0_modifiers, $src0, $src1_modifiers, $src1, Clamp, $old, (as_i32timm $byte_sel)) +>; + let OtherPredicates = [HasFP8ConversionInsts] in { foreach Index = [0, -1] in { let True16Predicate = NotHasTrue16BitInsts in { - def : Cvt_PK_F8_F32_Pat<int_amdgcn_cvt_pk_fp8_f32, Index, V_CVT_PK_FP8_F32_e64>; + let OtherPredicates = [HasFP8ConversionInsts, NotHasFP8E5M3Insts] in + def : Cvt_PK_F8_F32_Pat<int_amdgcn_cvt_pk_fp8_f32, Index, V_CVT_PK_FP8_F32_e64>; def : Cvt_PK_F8_F32_Pat<int_amdgcn_cvt_pk_bf8_f32, Index, V_CVT_PK_BF8_F32_e64>; } let True16Predicate = UseFakeTrue16Insts in { def : Cvt_PK_F8_F32_Pat<int_amdgcn_cvt_pk_fp8_f32, Index, V_CVT_PK_FP8_F32_fake16_e64>; def : Cvt_PK_F8_F32_Pat<int_amdgcn_cvt_pk_bf8_f32, Index, V_CVT_PK_BF8_F32_fake16_e64>; + let OtherPredicates = [HasFP8ConversionInsts, HasFP8E5M3Insts] in { + def : Cvt_PK_F8_F32_E5M3_Pat<int_amdgcn_cvt_pk_fp8_f32, Index, V_CVT_PK_FP8_F32_gfx1250_fake16_e64, DSTCLAMP.NONE>; + def : Cvt_PK_F8_F32_E5M3_Pat<int_amdgcn_cvt_pk_fp8_f32_e5m3, Index, V_CVT_PK_FP8_F32_gfx1250_fake16_e64, DSTCLAMP.ENABLE>; + } } } let True16Predicate = UseRealTrue16Insts in { defm : Cvt_PK_F8_F32_t16_Pat<int_amdgcn_cvt_pk_fp8_f32, V_CVT_PK_FP8_F32_t16_e64>; defm : Cvt_PK_F8_F32_t16_Pat<int_amdgcn_cvt_pk_bf8_f32, V_CVT_PK_BF8_F32_t16_e64>; + let OtherPredicates = [HasFP8ConversionInsts, HasFP8E5M3Insts] in { + defm : Cvt_PK_F8_F32_E5M3_t16_Pat<int_amdgcn_cvt_pk_fp8_f32, V_CVT_PK_FP8_F32_gfx1250_t16_e64, DSTCLAMP.NONE>; + defm : Cvt_PK_F8_F32_E5M3_t16_Pat<int_amdgcn_cvt_pk_fp8_f32_e5m3, V_CVT_PK_FP8_F32_gfx1250_t16_e64, DSTCLAMP.ENABLE>; + } } let SubtargetPredicate = isGFX940Plus in { @@ -852,7 +906,12 @@ let SubtargetPredicate = isGFX940Plus in { } let SubtargetPredicate = isGFX12Plus in { - def : Cvt_SR_F8_ByteSel_Pat<int_amdgcn_cvt_sr_fp8_f32, V_CVT_SR_FP8_F32_gfx12_e64, f32>; + let OtherPredicates = [HasFP8ConversionInsts, NotHasFP8E5M3Insts] in + def : Cvt_SR_F8_ByteSel_Pat<int_amdgcn_cvt_sr_fp8_f32, V_CVT_SR_FP8_F32_gfx12_e64, f32>; + let OtherPredicates = [HasFP8ConversionInsts, HasFP8E5M3Insts] in { + def : Cvt_SR_F8_ByteSel_E5M3_Pat<int_amdgcn_cvt_sr_fp8_f32, V_CVT_SR_FP8_F32_gfx1250_e64, f32, DSTCLAMP.NONE>; + def : Cvt_SR_F8_ByteSel_E5M3_Pat<int_amdgcn_cvt_sr_fp8_f32_e5m3, V_CVT_SR_FP8_F32_gfx1250_e64, f32, DSTCLAMP.ENABLE>; + } def : Cvt_SR_F8_ByteSel_Pat<int_amdgcn_cvt_sr_bf8_f32, V_CVT_SR_BF8_F32_gfx12_e64, f32>; } } @@ -885,6 +944,13 @@ def : GCNPat< 
(V_LSHL_ADD_U64_e64 VSrc_b64:$src0, VSrc_b32:$src1, VSrc_b64:$src2) >; +let SubtargetPredicate = HasAddMinMaxInsts in { +def : ThreeOp_i32_Pats<add, smax, V_ADD_MAX_I32_e64>; +def : ThreeOp_i32_Pats<add, umax, V_ADD_MAX_U32_e64>; +def : ThreeOp_i32_Pats<add, smin, V_ADD_MIN_I32_e64>; +def : ThreeOp_i32_Pats<add, umin, V_ADD_MIN_U32_e64>; +} + def : VOPBinOpClampPat<saddsat, V_ADD_I32_e64, i32>; def : VOPBinOpClampPat<ssubsat, V_SUB_I32_e64, i32>; @@ -987,6 +1053,14 @@ def VOP3_PERMLANE_VAR_Profile : VOP3_Profile<VOPProfile <[i32, i32, i32, untyped let HasExtDPP = 0; } +class VOP3_PERMLANE_NOOPSEL_Profile<VOPProfile P> : VOP3_Profile<P> { + let Ins64 = !con((ins VRegSrc_32:$src0, SSrc_b32:$src1), + !if(P.HasSrc2, (ins SSrc_b32:$src2), (ins))); + let HasClamp = 0; + let HasExtVOP3DPP = 0; + let HasExtDPP = 0; +} + def opsel_i1timm : SDNodeXForm<timm, [{ return CurDAG->getTargetConstant( N->getZExtValue() ? SISrcMods::OP_SEL_0 : SISrcMods::NONE, @@ -999,10 +1073,12 @@ class SrcAndDstSelToOpSelXForm<int modifier_idx, bit dest_sel> : SDNodeXForm<tim unsigned Val = N->getZExtValue(); unsigned New = 0; if (}] # modifier_idx # [{ == 0) { - New = (}] # dest_sel # [{ == 1) ? ((Val & 0x2) ? (SISrcMods::OP_SEL_0 | SISrcMods::DST_OP_SEL) : SISrcMods::DST_OP_SEL) - : ((Val & 0x2) ? SISrcMods::OP_SEL_0 : SISrcMods::NONE); - } else if (}] # modifier_idx # [{== 1 || }] # modifier_idx # [{ == 2) { - New = (Val & 0x1) ? SISrcMods::OP_SEL_0 : SISrcMods::NONE; + New = (}] # dest_sel # [{ == 1) ? ((Val & 0x1) ? (SISrcMods::OP_SEL_0 | SISrcMods::DST_OP_SEL) : SISrcMods::DST_OP_SEL) + : ((Val & 0x1) ? SISrcMods::OP_SEL_0 : SISrcMods::NONE); + } else if (}] # modifier_idx # [{== 1) { + New = (Val & 0x2) ? SISrcMods::OP_SEL_0 : SISrcMods::NONE; + } if (}] # modifier_idx # [{== 2) { + New = (Val & 0x1) ? 
SISrcMods::OP_SEL_0 : SISrcMods::NONE; } return CurDAG->getTargetConstant(New, SDLoc(N), MVT::i32); }]>; @@ -1068,6 +1144,18 @@ class PermlaneVarPat<SDPatternOperator permlane, VGPR_32:$src1, VGPR_32:$vdst_in) >; +class PermlaneNoDppPat3Src<SDPatternOperator permlane, + Instruction inst> : GCNPat< + (permlane i32:$src0, i32:$src1, i32:$src2), + (inst VGPR_32:$src0, SCSrc_b32:$src1, SCSrc_b32:$src2) +>; + +class PermlaneNoDppPat2Src<SDPatternOperator permlane, + Instruction inst> : GCNPat< + (permlane i32:$src0, i32:$src1), + (inst VGPR_32:$src0, SCSrc_b32:$src1) +>; + class VOP3_BITOP3_Profile<VOPProfile pfl, VOP3Features f> : VOP3_Profile<pfl, f> { let HasClamp = 0; let HasOMod = 0; @@ -1454,6 +1542,20 @@ let SubtargetPredicate = isGFX12Plus in { } // End SubtargetPredicate = isGFX12Plus +let SubtargetPredicate = isGFX1250Plus, WaveSizePredicate = isWave32 in { + defm V_PERMLANE_BCAST_B32 : VOP3Inst<"v_permlane_bcast_b32", VOP3_PERMLANE_NOOPSEL_Profile<VOP_I32_I32_I32_I32>>; + defm V_PERMLANE_UP_B32 : VOP3Inst<"v_permlane_up_b32", VOP3_PERMLANE_NOOPSEL_Profile<VOP_I32_I32_I32_I32>>; + defm V_PERMLANE_DOWN_B32 : VOP3Inst<"v_permlane_down_b32", VOP3_PERMLANE_NOOPSEL_Profile<VOP_I32_I32_I32_I32>>; + defm V_PERMLANE_XOR_B32 : VOP3Inst<"v_permlane_xor_b32", VOP3_PERMLANE_NOOPSEL_Profile<VOP_I32_I32_I32_I32>>; + defm V_PERMLANE_IDX_GEN_B32 : VOP3Inst<"v_permlane_idx_gen_b32", VOP3_PERMLANE_NOOPSEL_Profile<VOP_I32_I32_I32>>; + + def : PermlaneNoDppPat3Src<int_amdgcn_permlane_bcast, V_PERMLANE_BCAST_B32_e64>; + def : PermlaneNoDppPat3Src<int_amdgcn_permlane_up, V_PERMLANE_UP_B32_e64>; + def : PermlaneNoDppPat3Src<int_amdgcn_permlane_down, V_PERMLANE_DOWN_B32_e64>; + def : PermlaneNoDppPat3Src<int_amdgcn_permlane_xor, V_PERMLANE_XOR_B32_e64>; + def : PermlaneNoDppPat2Src<int_amdgcn_permlane_idx_gen, V_PERMLANE_IDX_GEN_B32_e64>; +} // End SubtargetPredicate = isGFX1250Plus, WaveSizePredicate = isWave32 + let HasClamp = 0, HasModifiers = 1 in { def BitOp3_B16_Profile : VOP3_BITOP3_Profile<VOPProfile <[i16, i16, i16, i16, i32]>, VOP3_OPSEL>; def BitOp3_B16_t16_Profile : VOP3_Profile_True16<BitOp3_B16_Profile>; @@ -1596,6 +1698,7 @@ def bf16_fpround : PatFrag <(ops node:$src0), (fpround $src0), [{ return true; let SubtargetPredicate = HasBF16ConversionInsts in { let ReadsModeReg = 0 in { defm V_CVT_PK_BF16_F32 : VOP3Inst<"v_cvt_pk_bf16_f32", VOP3_Profile<VOP_V2BF16_F32_F32>>; + defm V_CVT_SR_PK_BF16_F32 : VOP3Inst<"v_cvt_sr_pk_bf16_f32", VOP3_Profile<VOP_V2BF16_F32_F32_I32>, int_amdgcn_cvt_sr_pk_bf16_f32>; } def : GCNPat<(v2bf16 (bf16_fpround v2f32:$src)), (V_CVT_PK_BF16_F32_e64 0, (EXTRACT_SUBREG VReg_64:$src, sub0), 0, (EXTRACT_SUBREG VReg_64:$src, sub1))>; @@ -1606,6 +1709,141 @@ let SubtargetPredicate = HasBF16ConversionInsts in { (V_CVT_PK_BF16_F32_e64 $src0_modifiers, $src0, 0, (f32 (IMPLICIT_DEF)))>; } +class VOP3_CVT_SCALE_PK_F16_F864_Profile<VOPProfile P> : VOP3_CVT_SCALEF32_PK_F864_Profile<P> { + let Src0RC64 = getVOP3VRegSrcForVT<Src0VT>.ret; + let Ins64 = !con(getIns64<Src0RC64, Src1RC64, Src2RC64, NumSrcArgs, + HasClamp, HasModifiers, HasSrc2Mods, + HasOMod, Src0Mod, Src1Mod, Src2Mod>.ret, + (ins ScaleSel:$scale_sel)); + let Asm64 = getAsmVOP3Base<NumSrcArgs, HasDst, HasClamp, + HasOpSel, HasOMod, IsVOP3P, HasNeg, HasSrc0Mods, HasSrc1Mods, + HasSrc2Mods, DstVT>.ret # "$scale_sel"; +} + +multiclass VOP3CvtScaleSelInst<string OpName, VOPProfile P, SDPatternOperator node> { + def _e64 : VOP3InstBase<OpName, VOP3_CVT_SCALE_PK_F16_F864_Profile<P>> { + let Pattern = [(set P.DstVT:$vdst, (node 
(P.Src0VT (VOP3Mods0 P.Src0VT:$src0)), i32:$src1, i32:$scale_sel))]; + } +} + +let HasExtVOP3DPP = 0, HasModifiers = 0 in { +def VOP3_V2I32_I32_I32_V2I32 : VOP3_Profile<VOPProfile<[v2i32, i32, i32, v2i32]>>; +def VOP3_V3I32_I32_I64_V2I32 : VOP3_Profile<VOPProfile<[v3i32, i32, i64, v2i32]>>; +def VOP3_V4I32_I64_I64_V2I32 : VOP3_Profile<VOPProfile<[v4i32, i64, i64, v2i32]>>; +} + +let Src0RC64 = VSrc_NoInline_v2f16 in { +def VOP3_CVT_PK_F8_F16_Profile : VOP3_Profile<VOP_I16_V2F16>; +def VOP3_CVT_PK_F8_F16_True16_Profile : VOP3_Profile_True16<VOP3_CVT_PK_F8_F16_Profile>; +def VOP3_CVT_PK_F8_F16_Fake16_Profile : VOP3_Profile_Fake16<VOP3_CVT_PK_F8_F16_Profile>; +} + +let ReadsModeReg = 0, IsPacked = 0, SubtargetPredicate = isGFX125xOnly in { + defm V_CVT_PK_FP8_F16_gfx1250 : VOP3Inst_t16_with_profiles<"v_cvt_pk_fp8_f16_gfx1250", + VOP3_CVT_PK_F8_F16_Profile, + VOP3_CVT_PK_F8_F16_True16_Profile, + VOP3_CVT_PK_F8_F16_Fake16_Profile, + int_amdgcn_cvt_pk_fp8_f16>; + defm V_CVT_PK_BF8_F16_gfx1250 : VOP3Inst_t16_with_profiles<"v_cvt_pk_bf8_f16_gfx1250", + VOP3_CVT_PK_F8_F16_Profile, + VOP3_CVT_PK_F8_F16_True16_Profile, + VOP3_CVT_PK_F8_F16_Fake16_Profile, + int_amdgcn_cvt_pk_bf8_f16>; +} + +let HasClamp = 0, HasOpSel = 1 in { +def VOP3_CVT_SR_F8_F16_Profile : VOP3_CVT_SR_F8_ByteSel_Profile<f16>; +def VOP3_CVT_SR_F8_F16_True16_Profile : VOP3_Profile_True16<VOP3_CVT_SR_F8_F16_Profile>; +def VOP3_CVT_SR_F8_F16_Fake16_Profile : VOP3_Profile_Fake16<VOP3_CVT_SR_F8_F16_Profile>; +} + +let SubtargetPredicate = isGFX1250Plus in { + let ReadsModeReg = 0 in { + defm V_CVT_SR_PK_F16_F32 : VOP3Inst<"v_cvt_sr_pk_f16_f32", VOP3_Profile<VOP_V2F16_F32_F32_I32>, int_amdgcn_cvt_sr_pk_f16_f32>; + + // These instructions have non-standard use of op_sel. They are using bits 2 and 3 of opsel + // to select a byte in the vdst. Bits 0 and 1 are unused. 
+ let Constraints = "$vdst = $vdst_in", DisableEncoding = "$vdst_in" in { + defm V_CVT_SR_FP8_F16 : VOP3Inst_t16_with_profiles<"v_cvt_sr_fp8_f16", VOP3_CVT_SR_F8_F16_Profile, + VOP3_CVT_SR_F8_F16_True16_Profile, VOP3_CVT_SR_F8_F16_Fake16_Profile>; + defm V_CVT_SR_BF8_F16 : VOP3Inst_t16_with_profiles<"v_cvt_sr_bf8_f16", VOP3_CVT_SR_F8_F16_Profile, + VOP3_CVT_SR_F8_F16_True16_Profile, VOP3_CVT_SR_F8_F16_Fake16_Profile>; + } + + let Constraints = "@earlyclobber $vdst" in { + defm V_CVT_SCALE_PK8_F16_FP8 : VOP3CvtScaleSelInst<"v_cvt_scale_pk8_f16_fp8", VOP_V8F16_V2I32_I32, int_amdgcn_cvt_scale_pk8_f16_fp8>; + defm V_CVT_SCALE_PK8_BF16_FP8 : VOP3CvtScaleSelInst<"v_cvt_scale_pk8_bf16_fp8", VOP_V8BF16_V2I32_I32, int_amdgcn_cvt_scale_pk8_bf16_fp8>; + defm V_CVT_SCALE_PK8_F16_BF8 : VOP3CvtScaleSelInst<"v_cvt_scale_pk8_f16_bf8", VOP_V8F16_V2I32_I32, int_amdgcn_cvt_scale_pk8_f16_bf8>; + defm V_CVT_SCALE_PK8_BF16_BF8 : VOP3CvtScaleSelInst<"v_cvt_scale_pk8_bf16_bf8", VOP_V8BF16_V2I32_I32, int_amdgcn_cvt_scale_pk8_bf16_bf8>; + defm V_CVT_SCALE_PK8_F32_FP8 : VOP3CvtScaleSelInst<"v_cvt_scale_pk8_f32_fp8", VOP_V8F32_V2I32_I32, int_amdgcn_cvt_scale_pk8_f32_fp8>; + defm V_CVT_SCALE_PK8_F32_BF8 : VOP3CvtScaleSelInst<"v_cvt_scale_pk8_f32_bf8", VOP_V8F32_V2I32_I32, int_amdgcn_cvt_scale_pk8_f32_bf8>; + defm V_CVT_SCALE_PK16_F16_FP6 : VOP3CvtScaleSelInst<"v_cvt_scale_pk16_f16_fp6", VOP_V16F16_V3I32_I32, int_amdgcn_cvt_scale_pk16_f16_fp6>; + defm V_CVT_SCALE_PK16_BF16_FP6 : VOP3CvtScaleSelInst<"v_cvt_scale_pk16_bf16_fp6", VOP_V16BF16_V3I32_I32, int_amdgcn_cvt_scale_pk16_bf16_fp6>; + defm V_CVT_SCALE_PK16_F16_BF6 : VOP3CvtScaleSelInst<"v_cvt_scale_pk16_f16_bf6", VOP_V16F16_V3I32_I32, int_amdgcn_cvt_scale_pk16_f16_bf6>; + defm V_CVT_SCALE_PK16_BF16_BF6 : VOP3CvtScaleSelInst<"v_cvt_scale_pk16_bf16_bf6", VOP_V16BF16_V3I32_I32, int_amdgcn_cvt_scale_pk16_bf16_bf6>; + defm V_CVT_SCALE_PK16_F32_FP6 : VOP3CvtScaleSelInst<"v_cvt_scale_pk16_f32_fp6", VOP_V16F32_V3I32_I32, int_amdgcn_cvt_scale_pk16_f32_fp6>; + defm V_CVT_SCALE_PK16_F32_BF6 : VOP3CvtScaleSelInst<"v_cvt_scale_pk16_f32_bf6", VOP_V16F32_V3I32_I32, int_amdgcn_cvt_scale_pk16_f32_bf6>; + } // End Constraints = "@earlyclobber $vdst" + + defm V_CVT_SCALE_PK8_F16_FP4 : VOP3CvtScaleSelInst<"v_cvt_scale_pk8_f16_fp4", VOP_V8F16_I32_I32, int_amdgcn_cvt_scale_pk8_f16_fp4>; + defm V_CVT_SCALE_PK8_BF16_FP4 : VOP3CvtScaleSelInst<"v_cvt_scale_pk8_bf16_fp4", VOP_V8BF16_I32_I32, int_amdgcn_cvt_scale_pk8_bf16_fp4>; + defm V_CVT_SCALE_PK8_F32_FP4 : VOP3CvtScaleSelInst<"v_cvt_scale_pk8_f32_fp4", VOP_V8F32_I32_I32, int_amdgcn_cvt_scale_pk8_f32_fp4>; + } // End ReadsModeReg = 0 + + let Constraints = "@earlyclobber $vdst" in { + let WaveSizePredicate = isWave32 in { + defm V_CVT_SCALEF32_PK8_FP8_BF16 : VOP3Inst<"v_cvt_scalef32_pk8_fp8_bf16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V2I32_V8BF16_F32>, int_amdgcn_cvt_scalef32_pk8_fp8_bf16>; + defm V_CVT_SCALEF32_PK8_BF8_BF16 : VOP3Inst<"v_cvt_scalef32_pk8_bf8_bf16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V2I32_V8BF16_F32>, int_amdgcn_cvt_scalef32_pk8_bf8_bf16>; + defm V_CVT_SCALEF32_PK8_FP8_F16 : VOP3Inst<"v_cvt_scalef32_pk8_fp8_f16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V2I32_V8F16_F32>, int_amdgcn_cvt_scalef32_pk8_fp8_f16>; + defm V_CVT_SCALEF32_PK8_BF8_F16 : VOP3Inst<"v_cvt_scalef32_pk8_bf8_f16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V2I32_V8F16_F32>, int_amdgcn_cvt_scalef32_pk8_bf8_f16>; + defm V_CVT_SCALEF32_PK8_FP8_F32 : VOP3Inst<"v_cvt_scalef32_pk8_fp8_f32", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V2I32_V8F32_F32>, 
int_amdgcn_cvt_scalef32_pk8_fp8_f32>; + defm V_CVT_SCALEF32_PK8_BF8_F32 : VOP3Inst<"v_cvt_scalef32_pk8_bf8_f32", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V2I32_V8F32_F32>, int_amdgcn_cvt_scalef32_pk8_bf8_f32>; + defm V_CVT_SCALEF32_PK8_FP4_F32 : VOP3Inst<"v_cvt_scalef32_pk8_fp4_f32", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_I32_V8F32_F32>, int_amdgcn_cvt_scalef32_pk8_fp4_f32>; + defm V_CVT_SCALEF32_PK8_FP4_F16 : VOP3Inst<"v_cvt_scalef32_pk8_fp4_f16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_I32_V8F16_F32>, int_amdgcn_cvt_scalef32_pk8_fp4_f16>; + defm V_CVT_SCALEF32_PK8_FP4_BF16 : VOP3Inst<"v_cvt_scalef32_pk8_fp4_bf16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_I32_V8BF16_F32>, int_amdgcn_cvt_scalef32_pk8_fp4_bf16>; + } // End WaveSizePredicate = isWave32 + defm V_CVT_SCALEF32_PK16_FP6_F32 : VOP3Inst<"v_cvt_scalef32_pk16_fp6_f32", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V3I32_V16F32_F32>, int_amdgcn_cvt_scalef32_pk16_fp6_f32>; + defm V_CVT_SCALEF32_PK16_BF6_F32 : VOP3Inst<"v_cvt_scalef32_pk16_bf6_f32", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V3I32_V16F32_F32>, int_amdgcn_cvt_scalef32_pk16_bf6_f32>; + defm V_CVT_SCALEF32_PK16_FP6_F16 : VOP3Inst<"v_cvt_scalef32_pk16_fp6_f16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V3I32_V16F16_F32>, int_amdgcn_cvt_scalef32_pk16_fp6_f16>; + defm V_CVT_SCALEF32_PK16_BF6_F16 : VOP3Inst<"v_cvt_scalef32_pk16_bf6_f16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V3I32_V16F16_F32>, int_amdgcn_cvt_scalef32_pk16_bf6_f16>; + defm V_CVT_SCALEF32_PK16_FP6_BF16 : VOP3Inst<"v_cvt_scalef32_pk16_fp6_bf16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V3I32_V16BF16_F32>, int_amdgcn_cvt_scalef32_pk16_fp6_bf16>; + defm V_CVT_SCALEF32_PK16_BF6_BF16 : VOP3Inst<"v_cvt_scalef32_pk16_bf6_bf16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V3I32_V16BF16_F32>, int_amdgcn_cvt_scalef32_pk16_bf6_bf16>; + + let WaveSizePredicate = isWave32 in { + defm V_CVT_SCALEF32_SR_PK8_FP8_BF16 : VOP3Inst<"v_cvt_scalef32_sr_pk8_fp8_bf16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V2I32_V8BF16_I32_F32>, int_amdgcn_cvt_scalef32_sr_pk8_fp8_bf16>; + defm V_CVT_SCALEF32_SR_PK8_BF8_BF16 : VOP3Inst<"v_cvt_scalef32_sr_pk8_bf8_bf16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V2I32_V8BF16_I32_F32>, int_amdgcn_cvt_scalef32_sr_pk8_bf8_bf16>; + defm V_CVT_SCALEF32_SR_PK8_FP8_F16 : VOP3Inst<"v_cvt_scalef32_sr_pk8_fp8_f16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V2I32_V8F16_I32_F32>, int_amdgcn_cvt_scalef32_sr_pk8_fp8_f16>; + defm V_CVT_SCALEF32_SR_PK8_BF8_F16 : VOP3Inst<"v_cvt_scalef32_sr_pk8_bf8_f16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V2I32_V8F16_I32_F32>, int_amdgcn_cvt_scalef32_sr_pk8_bf8_f16>; + defm V_CVT_SCALEF32_SR_PK8_FP8_F32 : VOP3Inst<"v_cvt_scalef32_sr_pk8_fp8_f32", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V2I32_V8F32_I32_F32>, int_amdgcn_cvt_scalef32_sr_pk8_fp8_f32>; + defm V_CVT_SCALEF32_SR_PK8_BF8_F32 : VOP3Inst<"v_cvt_scalef32_sr_pk8_bf8_f32", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V2I32_V8F32_I32_F32>, int_amdgcn_cvt_scalef32_sr_pk8_bf8_f32>; + defm V_CVT_SCALEF32_SR_PK8_FP4_F32 : VOP3Inst<"v_cvt_scalef32_sr_pk8_fp4_f32", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_I32_V8F32_I32_F32>, int_amdgcn_cvt_scalef32_sr_pk8_fp4_f32>; + defm V_CVT_SCALEF32_SR_PK8_FP4_F16 : VOP3Inst<"v_cvt_scalef32_sr_pk8_fp4_f16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_I32_V8F16_I32_F32>, int_amdgcn_cvt_scalef32_sr_pk8_fp4_f16>; + defm V_CVT_SCALEF32_SR_PK8_FP4_BF16 : VOP3Inst<"v_cvt_scalef32_sr_pk8_fp4_bf16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_I32_V8BF16_I32_F32>, int_amdgcn_cvt_scalef32_sr_pk8_fp4_bf16>; + } // End WaveSizePredicate = isWave32 + defm 
V_CVT_SCALEF32_SR_PK16_BF6_BF16 : VOP3Inst<"v_cvt_scalef32_sr_pk16_bf6_bf16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V3I32_V16BF16_I32_F32>, int_amdgcn_cvt_scalef32_sr_pk16_bf6_bf16>; + defm V_CVT_SCALEF32_SR_PK16_BF6_F16 : VOP3Inst<"v_cvt_scalef32_sr_pk16_bf6_f16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V3I32_V16F16_I32_F32>, int_amdgcn_cvt_scalef32_sr_pk16_bf6_f16>; + defm V_CVT_SCALEF32_SR_PK16_BF6_F32 : VOP3Inst<"v_cvt_scalef32_sr_pk16_bf6_f32", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V3I32_V16F32_I32_F32>, int_amdgcn_cvt_scalef32_sr_pk16_bf6_f32>; + defm V_CVT_SCALEF32_SR_PK16_FP6_BF16 : VOP3Inst<"v_cvt_scalef32_sr_pk16_fp6_bf16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V3I32_V16BF16_I32_F32>, int_amdgcn_cvt_scalef32_sr_pk16_fp6_bf16>; + defm V_CVT_SCALEF32_SR_PK16_FP6_F16 : VOP3Inst<"v_cvt_scalef32_sr_pk16_fp6_f16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V3I32_V16F16_I32_F32>, int_amdgcn_cvt_scalef32_sr_pk16_fp6_f16>; + defm V_CVT_SCALEF32_SR_PK16_FP6_F32 : VOP3Inst<"v_cvt_scalef32_sr_pk16_fp6_f32", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V3I32_V16F32_I32_F32>, int_amdgcn_cvt_scalef32_sr_pk16_fp6_f32>; + } // End Constraints = "@earlyclobber $vdst" + + let True16Predicate = UseRealTrue16Insts in { + def : Cvt_SR_F8_ByteSel_Pat<int_amdgcn_cvt_sr_fp8_f16, V_CVT_SR_FP8_F16_t16_e64, f16>; + def : Cvt_SR_F8_ByteSel_Pat<int_amdgcn_cvt_sr_bf8_f16, V_CVT_SR_BF8_F16_t16_e64, f16>; + } + let True16Predicate = UseFakeTrue16Insts in { + def : Cvt_SR_F8_ByteSel_Pat<int_amdgcn_cvt_sr_fp8_f16, V_CVT_SR_FP8_F16_fake16_e64, f16>; + def : Cvt_SR_F8_ByteSel_Pat<int_amdgcn_cvt_sr_bf8_f16, V_CVT_SR_BF8_F16_fake16_e64, f16>; + } +} // End SubtargetPredicate = isGFX1250Plus + +let SubtargetPredicate = HasTensorCvtLutInsts in { + defm V_PERM_PK16_B4_U4 : VOP3Inst<"v_perm_pk16_b4_u4", VOP3_V2I32_I32_I32_V2I32, int_amdgcn_perm_pk16_b4_u4>; + defm V_PERM_PK16_B6_U4 : VOP3Inst<"v_perm_pk16_b6_u4", VOP3_V3I32_I32_I64_V2I32, int_amdgcn_perm_pk16_b6_u4>; + defm V_PERM_PK16_B8_U4 : VOP3Inst<"v_perm_pk16_b8_u4", VOP3_V4I32_I64_I64_V2I32, int_amdgcn_perm_pk16_b8_u4>; +} // End SubtargetPredicate = HasTensorCvtLutInsts + class Cvt_Scale_Sr_F32ToBF16F16_Pat<SDPatternOperator node, VOP3_Pseudo inst, ValueType DstTy> : GCNPat< (DstTy (node DstTy:$vdst_in, f32:$src0, i32:$src1, timm:$word_sel)), (inst (DstSelToOpSelXForm $word_sel), $src0, 0, $src1, VGPR_32:$vdst_in) @@ -1821,11 +2059,15 @@ defm V_MIN_U64 : VOP3Only_Realtriple_gfx1250<0x318>; defm V_MAX_U64 : VOP3Only_Realtriple_gfx1250<0x319>; defm V_MIN_I64 : VOP3Only_Realtriple_gfx1250<0x31a>; defm V_MAX_I64 : VOP3Only_Realtriple_gfx1250<0x31b>; - -defm V_CVT_PK_FP8_F32 : VOP3Only_Realtriple_t16_and_fake16_gfx12<0x369, "v_cvt_pk_fp8_f32">; -defm V_CVT_PK_BF8_F32 : VOP3Only_Realtriple_t16_and_fake16_gfx12<0x36a, "v_cvt_pk_bf8_f32">; -defm V_CVT_SR_FP8_F32_gfx12 : VOP3_Realtriple_with_name_gfx12<0x36b, "V_CVT_SR_FP8_F32_gfx12", "v_cvt_sr_fp8_f32" >; -defm V_CVT_SR_BF8_F32_gfx12 : VOP3_Realtriple_with_name_gfx12<0x36c, "V_CVT_SR_BF8_F32_gfx12", "v_cvt_sr_bf8_f32">; +defm V_ADD_MAX_I32 : VOP3Only_Realtriple_gfx1250<0x25e>; +defm V_ADD_MAX_U32 : VOP3Only_Realtriple_gfx1250<0x25f>; +defm V_ADD_MIN_I32 : VOP3Only_Realtriple_gfx1250<0x260>; +defm V_ADD_MIN_U32 : VOP3Only_Realtriple_gfx1250<0x261>; +defm V_PERMLANE_BCAST_B32 : VOP3Only_Real_Base_gfx12<0x270>; +defm V_PERMLANE_UP_B32 : VOP3Only_Real_Base_gfx12<0x271>; +defm V_PERMLANE_DOWN_B32 : VOP3Only_Real_Base_gfx12<0x272>; +defm V_PERMLANE_XOR_B32 : VOP3Only_Real_Base_gfx12<0x273>; +defm V_PERMLANE_IDX_GEN_B32 : 
VOP3Only_Real_Base_gfx12<0x314>; //===----------------------------------------------------------------------===// // GFX11, GFX12 @@ -1987,14 +2229,78 @@ defm V_AND_B16 : VOP3Only_Realtriple_t16_and_fake16_gfx11_gfx12<0x36 defm V_OR_B16 : VOP3Only_Realtriple_t16_and_fake16_gfx11_gfx12<0x363, "v_or_b16">; defm V_XOR_B16 : VOP3Only_Realtriple_t16_and_fake16_gfx11_gfx12<0x364, "v_xor_b16">; +defm V_CVT_PK_FP8_F32 : VOP3Only_Realtriple_t16_and_fake16_gfx11_gfx12_not_gfx1250<0x369, "v_cvt_pk_fp8_f32">; +defm V_CVT_PK_FP8_F32_gfx1250 : VOP3Only_Realtriple_t16_and_fake16_gfx1250<0x369, "v_cvt_pk_fp8_f32">; +defm V_CVT_PK_BF8_F32 : VOP3Only_Realtriple_t16_and_fake16_gfx11_gfx12<0x36a, "v_cvt_pk_bf8_f32">; +defm V_CVT_SR_FP8_F32_gfx12 : VOP3_Realtriple_with_name_gfx11_gfx12_not_gfx1250<0x36b, "V_CVT_SR_FP8_F32_gfx12", "v_cvt_sr_fp8_f32">; +defm V_CVT_SR_FP8_F32_gfx1250 : VOP3Only_Realtriple_with_name_gfx1250<0x36b, "V_CVT_SR_FP8_F32_gfx1250", "v_cvt_sr_fp8_f32">; +defm V_CVT_SR_BF8_F32_gfx12 : VOP3_Realtriple_with_name_gfx11_gfx12<0x36c, "V_CVT_SR_BF8_F32_gfx12", "v_cvt_sr_bf8_f32">; + let AssemblerPredicate = isGFX11Plus in { def : AMDGPUMnemonicAlias<"v_add3_nc_u32", "v_add3_u32">; def : AMDGPUMnemonicAlias<"v_xor_add_u32", "v_xad_u32">; } // These instructions differ from GFX12 variant by supporting DPP: +defm V_PERM_PK16_B4_U4 : VOP3Only_Real_Base_gfx1250<0x23f>; +defm V_PERM_PK16_B6_U4 : VOP3Only_Real_Base_gfx1250<0x242>; +defm V_PERM_PK16_B8_U4 : VOP3Only_Real_Base_gfx1250<0x243>; defm V_LSHL_ADD_U64 : VOP3Only_Realtriple_gfx1250<0x252>; +defm V_ASHR_PK_I8_I32 : VOP3Only_Realtriple_gfx1250<0x290>; +defm V_ASHR_PK_U8_I32 : VOP3Only_Realtriple_gfx1250<0x291>; +defm V_CVT_SCALE_PK8_F16_FP4 : VOP3Only_ScaleSel_Real_gfx1250<0x29f>; +defm V_CVT_SCALE_PK8_BF16_FP4 : VOP3Only_ScaleSel_Real_gfx1250<0x2a0>; +defm V_CVT_SCALE_PK8_F32_FP4 : VOP3Only_ScaleSel_Real_gfx1250<0x2a1>; +defm V_CVT_SCALE_PK8_F16_FP8 : VOP3Only_ScaleSel_Real_gfx1250<0x2a8>; +defm V_CVT_SCALE_PK8_BF16_FP8 : VOP3Only_ScaleSel_Real_gfx1250<0x2a9>; +defm V_CVT_SCALE_PK8_F32_FP8 : VOP3Only_ScaleSel_Real_gfx1250<0x2aa>; +defm V_CVT_SCALE_PK8_F16_BF8 : VOP3Only_ScaleSel_Real_gfx1250<0x2ab>; +defm V_CVT_SCALE_PK8_BF16_BF8 : VOP3Only_ScaleSel_Real_gfx1250<0x2ac>; +defm V_CVT_SCALE_PK8_F32_BF8 : VOP3Only_ScaleSel_Real_gfx1250<0x2ad>; +defm V_CVT_SCALEF32_PK8_FP4_F32 : VOP3Only_Real_Base_gfx1250<0x2b0>; +defm V_CVT_SCALEF32_PK8_FP4_F16 : VOP3Only_Real_Base_gfx1250<0x2b3>; +defm V_CVT_SCALEF32_PK8_FP8_BF16 : VOP3Only_Real_Base_gfx1250<0x2b4>; +defm V_CVT_SCALEF32_PK8_BF8_BF16 : VOP3Only_Real_Base_gfx1250<0x2b5>; +defm V_CVT_SCALEF32_PK8_FP4_BF16 : VOP3Only_Real_Base_gfx1250<0x2b8>; +defm V_CVT_SCALEF32_PK8_FP8_F32 : VOP3Only_Real_Base_gfx1250<0x2c3>; +defm V_CVT_SCALEF32_PK8_FP8_F16 : VOP3Only_Real_Base_gfx1250<0x2c4>; +defm V_CVT_SCALEF32_PK8_BF8_F32 : VOP3Only_Real_Base_gfx1250<0x2c5>; +defm V_CVT_SCALEF32_PK8_BF8_F16 : VOP3Only_Real_Base_gfx1250<0x2c6>; +defm V_CVT_SCALE_PK16_F16_FP6 : VOP3Only_ScaleSel_Real_gfx1250<0x2c7>; +defm V_CVT_SCALE_PK16_BF16_FP6 : VOP3Only_ScaleSel_Real_gfx1250<0x2c8>; +defm V_CVT_SCALE_PK16_F32_FP6 : VOP3Only_ScaleSel_Real_gfx1250<0x2c9>; +defm V_CVT_SCALE_PK16_F16_BF6 : VOP3Only_ScaleSel_Real_gfx1250<0x2ca>; +defm V_CVT_SCALE_PK16_BF16_BF6 : VOP3Only_ScaleSel_Real_gfx1250<0x2cb>; +defm V_CVT_SCALE_PK16_F32_BF6 : VOP3Only_ScaleSel_Real_gfx1250<0x2cc>; +defm V_CVT_SCALEF32_PK16_FP6_F32 : VOP3Only_Real_Base_gfx1250<0x2cd>; +defm V_CVT_SCALEF32_PK16_BF6_F32 : VOP3Only_Real_Base_gfx1250<0x2ce>; +defm 
V_CVT_SCALEF32_PK16_FP6_F16 : VOP3Only_Real_Base_gfx1250<0x2cf>; +defm V_CVT_SCALEF32_PK16_BF6_F16 : VOP3Only_Real_Base_gfx1250<0x2d0>; +defm V_CVT_SCALEF32_PK16_FP6_BF16 : VOP3Only_Real_Base_gfx1250<0x2d1>; +defm V_CVT_SCALEF32_PK16_BF6_BF16 : VOP3Only_Real_Base_gfx1250<0x2d2>; +defm V_CVT_SCALEF32_SR_PK16_FP6_F32 : VOP3Only_Real_Base_gfx1250<0x2d3>; +defm V_CVT_SCALEF32_SR_PK16_BF6_F32 : VOP3Only_Real_Base_gfx1250<0x2d4>; +defm V_CVT_SCALEF32_SR_PK16_FP6_F16 : VOP3Only_Real_Base_gfx1250<0x2d5>; +defm V_CVT_SCALEF32_SR_PK16_BF6_F16 : VOP3Only_Real_Base_gfx1250<0x2d6>; +defm V_CVT_SCALEF32_SR_PK16_FP6_BF16 : VOP3Only_Real_Base_gfx1250<0x2d7>; +defm V_CVT_SCALEF32_SR_PK16_BF6_BF16 : VOP3Only_Real_Base_gfx1250<0x2d8>; +defm V_CVT_SCALEF32_SR_PK8_FP4_F32 : VOP3Only_Real_Base_gfx1250<0x297>; +defm V_CVT_SCALEF32_SR_PK8_FP8_F32 : VOP3Only_Real_Base_gfx1250<0x298>; +defm V_CVT_SCALEF32_SR_PK8_BF8_F32 : VOP3Only_Real_Base_gfx1250<0x299>; +defm V_CVT_SCALEF32_SR_PK8_FP4_F16 : VOP3Only_Real_Base_gfx1250<0x2b9>; +defm V_CVT_SCALEF32_SR_PK8_FP4_BF16 : VOP3Only_Real_Base_gfx1250<0x2bc>; +defm V_CVT_SCALEF32_SR_PK8_FP8_F16 : VOP3Only_Real_Base_gfx1250<0x2bf>; +defm V_CVT_SCALEF32_SR_PK8_FP8_BF16 : VOP3Only_Real_Base_gfx1250<0x2c0>; +defm V_CVT_SCALEF32_SR_PK8_BF8_F16 : VOP3Only_Real_Base_gfx1250<0x2c1>; +defm V_CVT_SCALEF32_SR_PK8_BF8_BF16 : VOP3Only_Real_Base_gfx1250<0x2c2>; defm V_CVT_PK_BF16_F32 : VOP3Only_Realtriple_gfx1250<0x36d>; +defm V_CVT_SR_PK_BF16_F32 : VOP3Only_Realtriple_gfx1250<0x36e>; +defm V_CVT_PK_F16_F32 : VOP3Only_Realtriple_gfx1250<0x36f>; +defm V_CVT_SR_PK_F16_F32 : VOP3Only_Realtriple_gfx1250<0x370>; +defm V_CVT_PK_FP8_F16_gfx1250 : VOP3Only_Realtriple_t16_and_fake16_gfx1250<0x372, "v_cvt_pk_fp8_f16">; +defm V_CVT_PK_BF8_F16_gfx1250 : VOP3Only_Realtriple_t16_and_fake16_gfx1250<0x373, "v_cvt_pk_bf8_f16">; +defm V_CVT_SR_FP8_F16 : VOP3Only_Realtriple_t16_and_fake16_gfx1250<0x374>; +defm V_CVT_SR_BF8_F16 : VOP3Only_Realtriple_t16_and_fake16_gfx1250<0x375>; //===----------------------------------------------------------------------===// // GFX10. 
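The v_cvt_scale_* opcodes realized above all carry the ScaleSel operand introduced earlier in SIInstrInfo.td, whose isUInt<3> validator restricts it to a 3-bit selector; the VOPInstructions.td diff that follows encodes it into bits 13..11 of the VOP3 word and forces bit 14 to zero. A small sketch of the implied check and bit placement (helper names are illustrative, not in-tree API):

```cpp
#include <cstdint>

// scale_sel is validated with isUInt<3>, i.e. it must fit in 3 bits.
bool isValidScaleSel(uint64_t V) { return V < (1u << 3); }

// Per the VOP3a_ScaleSel_gfx1250 format below: Inst{13-11} = scale_sel,
// Inst{14} = 0.
uint64_t encodeScaleSel(uint64_t Inst, uint64_t ScaleSel) {
  Inst &= ~(0xFull << 11);                // clear bits 14..11
  return Inst | ((ScaleSel & 0x7) << 11); // place selector in 13..11
}
```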
diff --git a/llvm/lib/Target/AMDGPU/VOPInstructions.td b/llvm/lib/Target/AMDGPU/VOPInstructions.td index badbba9..f027ab0 100644 --- a/llvm/lib/Target/AMDGPU/VOPInstructions.td +++ b/llvm/lib/Target/AMDGPU/VOPInstructions.td @@ -414,6 +414,13 @@ class VOP3a_BITOP3_gfx12<bits<10> op, VOPProfile p> : VOP3e_gfx11_gfx12<op, p> { let Inst{14} = !if(p.HasOpSel, src0_modifiers{3}, 0); } +class VOP3a_ScaleSel_gfx1250<bits<10> op, VOPProfile p> : VOP3e_gfx11_gfx12<op, p> { + bits<3> scale_sel; + + let Inst{13-11} = scale_sel; + let Inst{14} = 0; +} + class VOP3Interp_gfx10<bits<10> op, VOPProfile p> : VOP3e_gfx10<op, p> { bits<6> attr; bits<2> attrchan; @@ -2010,6 +2017,30 @@ multiclass VOP3_BITOP3_Real_Base<GFXGen Gen, bits<10> op, string asmName> { } } +multiclass VOP3Only_ScaleSel_Real_gfx1250<bits<10> op> { + defvar ps = !cast<VOP_Pseudo>(NAME#"_e64"); + def _e64_gfx1250 : + VOP3_Real_Gen<ps, GFX1250Gen>, + VOP3a_ScaleSel_gfx1250<op, ps.Pfl>; +} + +multiclass VOP3Only_Realtriple_t16_gfx11_gfx12_not_gfx1250<bits<10> op, string asmName, string opName = NAME, + string pseudo_mnemonic = "", bit isSingle = 0> : + VOP3_Realtriple_with_name<GFX11Gen, op, opName, asmName, pseudo_mnemonic, isSingle>, + VOP3_Realtriple_with_name<GFX12Not12_50Gen, op, opName, asmName, pseudo_mnemonic, isSingle>; + +multiclass VOP3Only_Realtriple_t16_and_fake16_gfx11_gfx12_not_gfx1250<bits<10> op, string asmName, + string opName = NAME, string pseudo_mnemonic = ""> { + defm _t16 : VOP3Only_Realtriple_t16_gfx11_gfx12_not_gfx1250<op, asmName, opName#"_t16", pseudo_mnemonic, 1>; + defm _fake16 : VOP3Only_Realtriple_t16_gfx11_gfx12_not_gfx1250<op, asmName, opName#"_fake16", pseudo_mnemonic, 1>; +} + +multiclass VOP3_Realtriple_with_name_gfx11_gfx12_not_gfx1250<bits<10> op, string opName, + string asmName, string pseudo_mnemonic = "", + bit isSingle = 0> : + VOP3_Realtriple_with_name<GFX11Gen, op, opName, asmName, pseudo_mnemonic, isSingle>, + VOP3_Realtriple_with_name<GFX12Not12_50Gen, op, opName, asmName, pseudo_mnemonic, isSingle>; + //===----------------------------------------------------------------------===// // VOP3 GFX11 //===----------------------------------------------------------------------===// @@ -2071,6 +2102,15 @@ multiclass VOP3Only_Real_Base_gfx1250<bits<10> op> : multiclass VOP3Only_Realtriple_gfx1250<bits<10> op, bit isSingle = 0> : VOP3_Realtriple<GFX1250Gen, op, isSingle>; +multiclass VOP3Only_Realtriple_with_name_gfx1250<bits<10> op, string opName, + string asmName, string pseudo_mnemonic = "", + bit isSingle = 0> : + VOP3_Realtriple_with_name<GFX1250Gen, op, opName, asmName, pseudo_mnemonic, isSingle>; + +multiclass VOP3Only_Realtriple_t16_gfx1250<bits<10> op, string asmName = !cast<VOP3_Pseudo>(NAME#"_e64").Mnemonic, + string opName = NAME, string pseudo_mnemonic = "", bit isSingle = 0> : + VOP3Only_Realtriple_with_name_gfx1250<op, opName, asmName, pseudo_mnemonic, isSingle>; + multiclass VOP3_Realtriple_t16_gfx12<bits<10> op, string asmName, string opName = NAME, string pseudo_mnemonic = "", bit isSingle = 0> : VOP3_Realtriple_with_name<GFX12Gen, op, opName, asmName, pseudo_mnemonic, isSingle>; @@ -2091,6 +2131,13 @@ multiclass VOP3Only_Realtriple_t16_and_fake16_gfx12<bits<10> op, string asmName, defm _fake16 : VOP3Only_Realtriple_t16_gfx12<op, asmName, opName#"_fake16", pseudo_mnemonic>; } +multiclass VOP3Only_Realtriple_t16_and_fake16_gfx1250<bits<10> op, + string asmName = !cast<VOP3_Pseudo>(NAME#"_e64").Mnemonic, + string opName = NAME, string pseudo_mnemonic = ""> { + defm _t16 : 
VOP3Only_Realtriple_t16_gfx1250<op, asmName, opName#"_t16", pseudo_mnemonic>; + defm _fake16 : VOP3Only_Realtriple_t16_gfx1250<op, asmName, opName#"_fake16", pseudo_mnemonic>; +} + multiclass VOP3be_Real_with_name_gfx12<bits<10> op, string opName, string asmName, bit isSingle = 0> { defvar ps = !cast<VOP3_Pseudo>(opName#"_e64"); diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp index 066b392..9366256 100644 --- a/llvm/lib/Target/ARM/ARMISelLowering.cpp +++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp @@ -2423,6 +2423,7 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, CallingConv::ID CallConv = CLI.CallConv; bool doesNotRet = CLI.DoesNotReturn; bool isVarArg = CLI.IsVarArg; + const CallBase *CB = CLI.CB; MachineFunction &MF = DAG.getMachineFunction(); ARMFunctionInfo *AFI = MF.getInfo<ARMFunctionInfo>(); @@ -2446,6 +2447,10 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, !Subtarget->noBTIAtReturnTwice()) GuardWithBTI = AFI->branchTargetEnforcement(); + // Set type id for call site info. + if (MF.getTarget().Options.EmitCallGraphSection && CB && CB->isIndirectCall()) + CSInfo = MachineFunction::CallSiteInfo(*CB); + // Determine whether this is a non-secure function call. if (CLI.CB && CLI.CB->getAttributes().hasFnAttr("cmse_nonsecure_call")) isCmseNSCall = true; @@ -5516,18 +5521,6 @@ SDValue ARMTargetLowering::LowerSELECT_CC(SDValue Op, SelectionDAG &DAG) const { ConstantSDNode *CTVal = dyn_cast<ConstantSDNode>(TrueVal); ConstantSDNode *RHSC = dyn_cast<ConstantSDNode>(RHS); if (Op.getValueType().isInteger()) { - // Check for sign pattern (SELECT_CC setgt, iN lhs, -1, 1, -1) and transform - // into (OR (ASR lhs, N-1), 1), which requires less instructions for the - // supported types. - if (CC == ISD::SETGT && RHSC && RHSC->isAllOnes() && CTVal && CFVal && - CTVal->isOne() && CFVal->isAllOnes() && - LHS.getValueType() == TrueVal.getValueType()) { - EVT VT = LHS.getValueType(); - SDValue Shift = - DAG.getNode(ISD::SRA, dl, VT, LHS, - DAG.getConstant(VT.getSizeInBits() - 1, dl, VT)); - return DAG.getNode(ISD::OR, dl, VT, Shift, DAG.getConstant(1, dl, VT)); - } // Check for SMAX(lhs, 0) and SMIN(lhs, 0) patterns. // (SELECT_CC setgt, lhs, 0, lhs, 0) -> (BIC lhs, (SRA lhs, typesize-1)) diff --git a/llvm/lib/Target/ARM/MCTargetDesc/ARMAsmBackend.cpp b/llvm/lib/Target/ARM/MCTargetDesc/ARMAsmBackend.cpp index 146fc67..c221d22 100644 --- a/llvm/lib/Target/ARM/MCTargetDesc/ARMAsmBackend.cpp +++ b/llvm/lib/Target/ARM/MCTargetDesc/ARMAsmBackend.cpp @@ -1108,9 +1108,8 @@ std::optional<bool> ARMAsmBackend::evaluateFixup(const MCFragment &F, } void ARMAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, - const MCValue &Target, - MutableArrayRef<char> Data, uint64_t Value, - bool IsResolved) { + const MCValue &Target, uint8_t *Data, + uint64_t Value, bool IsResolved) { if (IsResolved && shouldForceRelocation(Fixup, Target)) IsResolved = false; maybeAddReloc(F, Fixup, Target, Value, IsResolved); @@ -1124,14 +1123,15 @@ void ARMAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, return; // Doesn't change encoding. const unsigned NumBytes = getFixupKindNumBytes(Kind); - unsigned Offset = Fixup.getOffset(); - assert(Offset + NumBytes <= Data.size() && "Invalid fixup offset!"); + assert(Fixup.getOffset() + NumBytes <= F.getSize() && + "Invalid fixup offset!"); // Used to point to big endian bytes. 
unsigned FullSizeBytes; if (Endian == llvm::endianness::big) { FullSizeBytes = getFixupKindContainerSizeBytes(Kind); - assert((Offset + FullSizeBytes) <= Data.size() && "Invalid fixup size!"); + assert(Fixup.getOffset() + FullSizeBytes <= F.getSize() && + "Invalid fixup size!"); assert(NumBytes <= FullSizeBytes && "Invalid fixup size!"); } @@ -1141,7 +1141,7 @@ void ARMAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, for (unsigned i = 0; i != NumBytes; ++i) { unsigned Idx = Endian == llvm::endianness::little ? i : (FullSizeBytes - 1 - i); - Data[Offset + Idx] |= uint8_t((Value >> (i * 8)) & 0xff); + Data[Idx] |= uint8_t((Value >> (i * 8)) & 0xff); } } diff --git a/llvm/lib/Target/ARM/MCTargetDesc/ARMAsmBackend.h b/llvm/lib/Target/ARM/MCTargetDesc/ARMAsmBackend.h index 07d2cf7..2844232 100644 --- a/llvm/lib/Target/ARM/MCTargetDesc/ARMAsmBackend.h +++ b/llvm/lib/Target/ARM/MCTargetDesc/ARMAsmBackend.h @@ -40,8 +40,7 @@ public: std::optional<bool> evaluateFixup(const MCFragment &, MCFixup &, MCValue &, uint64_t &) override; void applyFixup(const MCFragment &, const MCFixup &, const MCValue &Target, - MutableArrayRef<char> Data, uint64_t Value, - bool IsResolved) override; + uint8_t *Data, uint64_t Value, bool IsResolved) override; unsigned getRelaxedOpcode(unsigned Op, const MCSubtargetInfo &STI) const; diff --git a/llvm/lib/Target/AVR/MCTargetDesc/AVRAsmBackend.cpp b/llvm/lib/Target/AVR/MCTargetDesc/AVRAsmBackend.cpp index 128cc0b..05a7d03 100644 --- a/llvm/lib/Target/AVR/MCTargetDesc/AVRAsmBackend.cpp +++ b/llvm/lib/Target/AVR/MCTargetDesc/AVRAsmBackend.cpp @@ -368,9 +368,8 @@ AVRAsmBackend::createObjectTargetWriter() const { } void AVRAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, - const MCValue &Target, - MutableArrayRef<char> Data, uint64_t Value, - bool IsResolved) { + const MCValue &Target, uint8_t *Data, + uint64_t Value, bool IsResolved) { // AVR sets the fixup value to bypass the assembly time overflow with a // relocation. if (IsResolved) { @@ -397,14 +396,14 @@ void AVRAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, // Shift the value into position. Value <<= Info.TargetOffset; - unsigned Offset = Fixup.getOffset(); - assert(Offset + NumBytes <= Data.size() && "Invalid fixup offset!"); + assert(Fixup.getOffset() + NumBytes <= F.getSize() && + "Invalid fixup offset!"); // For each byte of the fragment that the fixup touches, mask in the // bits from the fixup value. 
for (unsigned i = 0; i < NumBytes; ++i) { uint8_t mask = (((Value >> (i * 8)) & 0xff)); - Data[Offset + i] |= mask; + Data[i] |= mask; } } diff --git a/llvm/lib/Target/AVR/MCTargetDesc/AVRAsmBackend.h b/llvm/lib/Target/AVR/MCTargetDesc/AVRAsmBackend.h index 68c839e..9633669 100644 --- a/llvm/lib/Target/AVR/MCTargetDesc/AVRAsmBackend.h +++ b/llvm/lib/Target/AVR/MCTargetDesc/AVRAsmBackend.h @@ -38,8 +38,7 @@ public: createObjectTargetWriter() const override; void applyFixup(const MCFragment &, const MCFixup &, const MCValue &Target, - MutableArrayRef<char> Data, uint64_t Value, - bool IsResolved) override; + uint8_t *Data, uint64_t Value, bool IsResolved) override; std::optional<MCFixupKind> getFixupKind(StringRef Name) const override; MCFixupKindInfo getFixupKindInfo(MCFixupKind Kind) const override; diff --git a/llvm/lib/Target/BPF/MCTargetDesc/BPFAsmBackend.cpp b/llvm/lib/Target/BPF/MCTargetDesc/BPFAsmBackend.cpp index dda8753..53933f9 100644 --- a/llvm/lib/Target/BPF/MCTargetDesc/BPFAsmBackend.cpp +++ b/llvm/lib/Target/BPF/MCTargetDesc/BPFAsmBackend.cpp @@ -27,8 +27,7 @@ public: ~BPFAsmBackend() override = default; void applyFixup(const MCFragment &, const MCFixup &, const MCValue &Target, - MutableArrayRef<char> Data, uint64_t Value, - bool IsResolved) override; + uint8_t *Data, uint64_t Value, bool IsResolved) override; std::unique_ptr<MCObjectTargetWriter> createObjectTargetWriter() const override; @@ -66,35 +65,32 @@ bool BPFAsmBackend::writeNopData(raw_ostream &OS, uint64_t Count, } void BPFAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, - const MCValue &Target, - MutableArrayRef<char> Data, uint64_t Value, - bool IsResolved) { + const MCValue &Target, uint8_t *Data, + uint64_t Value, bool IsResolved) { maybeAddReloc(F, Fixup, Target, Value, IsResolved); if (Fixup.getKind() == FK_SecRel_8) { // The Value is 0 for global variables, and the in-section offset // for static variables. Write to the immediate field of the inst. assert(Value <= UINT32_MAX); - support::endian::write<uint32_t>(&Data[Fixup.getOffset() + 4], - static_cast<uint32_t>(Value), + support::endian::write<uint32_t>(Data + 4, static_cast<uint32_t>(Value), Endian); } else if (Fixup.getKind() == FK_Data_4 && !Fixup.isPCRel()) { - support::endian::write<uint32_t>(&Data[Fixup.getOffset()], Value, Endian); + support::endian::write<uint32_t>(Data, Value, Endian); } else if (Fixup.getKind() == FK_Data_8) { - support::endian::write<uint64_t>(&Data[Fixup.getOffset()], Value, Endian); + support::endian::write<uint64_t>(Data, Value, Endian); } else if (Fixup.getKind() == FK_Data_4 && Fixup.isPCRel()) { Value = (uint32_t)((Value - 8) / 8); if (Endian == llvm::endianness::little) { - Data[Fixup.getOffset() + 1] = 0x10; - support::endian::write32le(&Data[Fixup.getOffset() + 4], Value); + Data[1] = 0x10; + support::endian::write32le(Data + 4, Value); } else { - Data[Fixup.getOffset() + 1] = 0x1; - support::endian::write32be(&Data[Fixup.getOffset() + 4], Value); + Data[1] = 0x1; + support::endian::write32be(Data + 4, Value); } } else if (Fixup.getKind() == BPF::FK_BPF_PCRel_4) { // The input Value represents the number of bytes. 
Value = (uint32_t)((Value - 8) / 8); - support::endian::write<uint32_t>(&Data[Fixup.getOffset() + 4], Value, - Endian); + support::endian::write<uint32_t>(Data + 4, Value, Endian); } else { assert(Fixup.getKind() == FK_Data_2 && Fixup.isPCRel()); @@ -103,8 +99,7 @@ void BPFAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, report_fatal_error("Branch target out of insn range"); Value = (uint16_t)((Value - 8) / 8); - support::endian::write<uint16_t>(&Data[Fixup.getOffset() + 2], Value, - Endian); + support::endian::write<uint16_t>(Data + 2, Value, Endian); } } diff --git a/llvm/lib/Target/CSKY/MCTargetDesc/CSKYAsmBackend.cpp b/llvm/lib/Target/CSKY/MCTargetDesc/CSKYAsmBackend.cpp index 694d9ea..6964998 100644 --- a/llvm/lib/Target/CSKY/MCTargetDesc/CSKYAsmBackend.cpp +++ b/llvm/lib/Target/CSKY/MCTargetDesc/CSKYAsmBackend.cpp @@ -197,9 +197,8 @@ std::optional<bool> CSKYAsmBackend::evaluateFixup(const MCFragment &F, } void CSKYAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, - const MCValue &Target, - MutableArrayRef<char> Data, uint64_t Value, - bool IsResolved) { + const MCValue &Target, uint8_t *Data, + uint64_t Value, bool IsResolved) { if (IsResolved && shouldForceRelocation(Fixup, Target)) IsResolved = false; maybeAddReloc(F, Fixup, Target, Value, IsResolved); @@ -217,10 +216,10 @@ void CSKYAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, // Shift the value into position. Value <<= Info.TargetOffset; - unsigned Offset = Fixup.getOffset(); unsigned NumBytes = alignTo(Info.TargetSize + Info.TargetOffset, 8) / 8; - assert(Offset + NumBytes <= Data.size() && "Invalid fixup offset!"); + assert(Fixup.getOffset() + NumBytes <= F.getSize() && + "Invalid fixup offset!"); // For each byte of the fragment that the fixup touches, mask in the // bits from the fixup value. @@ -228,14 +227,14 @@ void CSKYAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, bool IsInstFixup = (Kind >= FirstTargetFixupKind); if (IsLittleEndian && IsInstFixup && (NumBytes == 4)) { - Data[Offset + 0] |= uint8_t((Value >> 16) & 0xff); - Data[Offset + 1] |= uint8_t((Value >> 24) & 0xff); - Data[Offset + 2] |= uint8_t(Value & 0xff); - Data[Offset + 3] |= uint8_t((Value >> 8) & 0xff); + Data[0] |= uint8_t((Value >> 16) & 0xff); + Data[1] |= uint8_t((Value >> 24) & 0xff); + Data[2] |= uint8_t(Value & 0xff); + Data[3] |= uint8_t((Value >> 8) & 0xff); } else { for (unsigned I = 0; I != NumBytes; I++) { unsigned Idx = IsLittleEndian ? 
I : (NumBytes - 1 - I); - Data[Offset + Idx] |= uint8_t((Value >> (I * 8)) & 0xff); + Data[Idx] |= uint8_t((Value >> (I * 8)) & 0xff); } } } diff --git a/llvm/lib/Target/CSKY/MCTargetDesc/CSKYAsmBackend.h b/llvm/lib/Target/CSKY/MCTargetDesc/CSKYAsmBackend.h index 1c8516f..5d8826a 100644 --- a/llvm/lib/Target/CSKY/MCTargetDesc/CSKYAsmBackend.h +++ b/llvm/lib/Target/CSKY/MCTargetDesc/CSKYAsmBackend.h @@ -25,8 +25,7 @@ public: std::optional<bool> evaluateFixup(const MCFragment &, MCFixup &, MCValue &, uint64_t &) override; void applyFixup(const MCFragment &, const MCFixup &, const MCValue &Target, - MutableArrayRef<char> Data, uint64_t Value, - bool IsResolved) override; + uint8_t *Data, uint64_t Value, bool IsResolved) override; MCFixupKindInfo getFixupKindInfo(MCFixupKind Kind) const override; diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp index ebdfcaa..a4f5086 100644 --- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp +++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp @@ -17,7 +17,6 @@ #include "llvm/Analysis/DXILMetadataAnalysis.h" #include "llvm/BinaryFormat/DXContainer.h" #include "llvm/Frontend/HLSL/RootSignatureMetadata.h" -#include "llvm/Frontend/HLSL/RootSignatureValidations.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Function.h" @@ -111,14 +110,25 @@ analyzeModule(Module &M) { reportError(Ctx, "Root Element is not a metadata node."); continue; } - mcdxbc::RootSignatureDesc RSD; - if (std::optional<uint32_t> Version = extractMdIntValue(RSDefNode, 2)) - RSD.Version = *Version; - else { + std::optional<uint32_t> V = extractMdIntValue(RSDefNode, 2); + if (!V.has_value()) { reportError(Ctx, "Invalid RSDefNode value, expected constant int"); continue; } + llvm::hlsl::rootsig::MetadataParser MDParser(RootElementListNode); + llvm::Expected<mcdxbc::RootSignatureDesc> RSDOrErr = + MDParser.ParseRootSignature(V.value()); + + if (!RSDOrErr) { + handleAllErrors(RSDOrErr.takeError(), [&](ErrorInfoBase &EIB) { + Ctx->emitError(EIB.message()); + }); + continue; + } + + auto &RSD = *RSDOrErr; + // Clang emits the root signature data in dxcontainer following a specific // sequence. First the header, then the root parameters. So the header // offset will always equal to the header size. @@ -127,12 +137,6 @@ analyzeModule(Module &M) { // static sampler offset is calculated when writting dxcontainer. 
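The root-signature hunk here replaces the old in-place ParseRootSignature with one returning llvm::Expected<mcdxbc::RootSignatureDesc>, surfacing parse failures through handleAllErrors. A compilable toy of that error-handling flow, assuming LLVM's Support headers are available; parseVersion is a made-up stand-in for the metadata parser:

    #include "llvm/Support/Error.h"
    #include "llvm/Support/raw_ostream.h"

    // Hypothetical parse step; only the Expected plumbing mirrors the hunk.
    static llvm::Expected<unsigned> parseVersion(bool Valid) {
      if (!Valid)
        return llvm::createStringError(llvm::inconvertibleErrorCode(),
                                       "Invalid RSDefNode value");
      return 2u;
    }

    int main() {
      llvm::Expected<unsigned> V = parseVersion(false);
      if (!V) {
        // Equivalent of the Ctx->emitError(EIB.message()) callback above.
        llvm::handleAllErrors(V.takeError(), [](llvm::ErrorInfoBase &EIB) {
          llvm::errs() << EIB.message() << "\n";
        });
        return 1;
      }
      llvm::outs() << "version " << *V << "\n";
      return 0;
    }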
RSD.StaticSamplersOffset = 0u; - hlsl::rootsig::MetadataParser MDParser(RootElementListNode); - - if (MDParser.ParseRootSignature(Ctx, RSD)) { - return RSDMap; - } - RSDMap.insert(std::make_pair(F, RSD)); } diff --git a/llvm/lib/Target/DirectX/MCTargetDesc/DirectXMCTargetDesc.cpp b/llvm/lib/Target/DirectX/MCTargetDesc/DirectXMCTargetDesc.cpp index 5323be6..9a14c01 100644 --- a/llvm/lib/Target/DirectX/MCTargetDesc/DirectXMCTargetDesc.cpp +++ b/llvm/lib/Target/DirectX/MCTargetDesc/DirectXMCTargetDesc.cpp @@ -78,8 +78,7 @@ public: ~DXILAsmBackend() override = default; void applyFixup(const MCFragment &, const MCFixup &, const MCValue &Target, - MutableArrayRef<char> Data, uint64_t Value, - bool IsResolved) override {} + uint8_t *Data, uint64_t Value, bool IsResolved) override {} std::unique_ptr<MCObjectTargetWriter> createObjectTargetWriter() const override { diff --git a/llvm/lib/Target/Hexagon/MCTargetDesc/HexagonAsmBackend.cpp b/llvm/lib/Target/Hexagon/MCTargetDesc/HexagonAsmBackend.cpp index 7d3074b..1a0f1ab 100644 --- a/llvm/lib/Target/Hexagon/MCTargetDesc/HexagonAsmBackend.cpp +++ b/llvm/lib/Target/Hexagon/MCTargetDesc/HexagonAsmBackend.cpp @@ -402,8 +402,7 @@ public: } void applyFixup(const MCFragment &, const MCFixup &, const MCValue &, - MutableArrayRef<char> Data, uint64_t FixupValue, - bool IsResolved) override; + uint8_t *Data, uint64_t FixupValue, bool IsResolved) override; bool isInstRelaxable(MCInst const &HMI) const { const MCInstrDesc &MCID = HexagonMCInstrInfo::getDesc(*MCII, HMI); @@ -649,8 +648,7 @@ public: } // namespace void HexagonAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, - const MCValue &Target, - MutableArrayRef<char> Data, + const MCValue &Target, uint8_t *InstAddr, uint64_t FixupValue, bool IsResolved) { if (IsResolved && shouldForceRelocation(Fixup)) IsResolved = false; @@ -667,10 +665,9 @@ void HexagonAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, // LLVM gives us an encoded value, we have to convert it back // to a real offset before we can use it. 
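The Lanai and Mips hunks below apply fixups read-modify-write style: load the covered bytes, merge the fixup bits under a mask, and store them back in target byte order. A standalone sketch of that pattern (big-endian, with a hypothetical 16-bit immediate field; the real backends can plain-OR because the field bits start out zero, so the explicit clear here is a safe generalization):

    #include <cassert>
    #include <cstdint>

    static void patchFieldBE(uint8_t *Data, unsigned FullSize,
                             unsigned NumBytes, uint64_t Mask, uint64_t Value) {
      // Read the low NumBytes bytes of the big-endian word...
      uint64_t CurVal = 0;
      for (unsigned I = 0; I != NumBytes; ++I)
        CurVal |= uint64_t(Data[FullSize - 1 - I]) << (I * 8);
      // ...splice the fixup bits into the field covered by Mask...
      CurVal = (CurVal & ~Mask) | (Value & Mask);
      // ...and write them back in the same byte order.
      for (unsigned I = 0; I != NumBytes; ++I)
        Data[FullSize - 1 - I] = uint8_t((CurVal >> (I * 8)) & 0xff);
    }

    int main() {
      uint8_t Insn[4] = {0x9F, 0x00, 0x00, 0x00}; // hypothetical BE encoding
      patchFieldBE(Insn, 4, 2, 0xFFFF, 0x1234);   // 16-bit immediate field
      assert(Insn[2] == 0x12 && Insn[3] == 0x34 && Insn[0] == 0x9F);
      return 0;
    }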
- uint32_t Offset = Fixup.getOffset(); unsigned NumBytes = getFixupKindNumBytes(Kind); - assert(Offset + NumBytes <= Data.size() && "Invalid fixup offset!"); - char *InstAddr = Data.data() + Offset; + assert(Fixup.getOffset() + NumBytes <= F.getSize() && + "Invalid fixup offset!"); Value = adjustFixupValue(Kind, FixupValue); if (!Value) @@ -757,8 +754,8 @@ void HexagonAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, uint32_t OldData = 0; for (unsigned i = 0; i < NumBytes; i++) OldData |= (InstAddr[i] << (i * 8)) & (0xff << (i * 8)); dbgs() << "\tBValue=0x"; dbgs().write_hex(Value) << ": AValue=0x"; - dbgs().write_hex(FixupValue) - << ": Offset=" << Offset << ": Size=" << Data.size() << ": OInst=0x"; + dbgs().write_hex(FixupValue) << ": Offset=" << Fixup.getOffset() + << ": Size=" << F.getSize() << ": OInst=0x"; dbgs().write_hex(OldData) << ": Reloc=0x"; dbgs().write_hex(Reloc);); // For each byte of the fragment that the fixup touches, mask in the diff --git a/llvm/lib/Target/Lanai/MCTargetDesc/LanaiAsmBackend.cpp b/llvm/lib/Target/Lanai/MCTargetDesc/LanaiAsmBackend.cpp index 83d1697..3112dea 100644 --- a/llvm/lib/Target/Lanai/MCTargetDesc/LanaiAsmBackend.cpp +++ b/llvm/lib/Target/Lanai/MCTargetDesc/LanaiAsmBackend.cpp @@ -48,8 +48,7 @@ public: : MCAsmBackend(llvm::endianness::big), OSType(OST) {} void applyFixup(const MCFragment &, const MCFixup &, const MCValue &Target, - MutableArrayRef<char> Data, uint64_t Value, - bool IsResolved) override; + uint8_t *Data, uint64_t Value, bool IsResolved) override; std::unique_ptr<MCObjectTargetWriter> createObjectTargetWriter() const override; @@ -72,9 +71,8 @@ bool LanaiAsmBackend::writeNopData(raw_ostream &OS, uint64_t Count, } void LanaiAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, - const MCValue &Target, - MutableArrayRef<char> Data, uint64_t Value, - bool IsResolved) { + const MCValue &Target, uint8_t *Data, + uint64_t Value, bool IsResolved) { if (!IsResolved) Asm->getWriter().recordRelocation(F, Fixup, Target, Value); @@ -85,7 +83,6 @@ void LanaiAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, // Where in the object and where the number of bytes that need // fixing up - unsigned Offset = Fixup.getOffset(); unsigned NumBytes = (getFixupKindInfo(Kind).TargetSize + 7) / 8; unsigned FullSize = 4; @@ -95,8 +92,7 @@ void LanaiAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, // Load instruction and apply value for (unsigned i = 0; i != NumBytes; ++i) { unsigned Idx = (FullSize - 1 - i); - CurVal |= static_cast<uint64_t>(static_cast<uint8_t>(Data[Offset + Idx])) - << (i * 8); + CurVal |= static_cast<uint64_t>(static_cast<uint8_t>(Data[Idx])) << (i * 8); } uint64_t Mask = @@ -106,7 +102,7 @@ void LanaiAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, // Write out the fixed up bytes back to the code/data bits. 
for (unsigned i = 0; i != NumBytes; ++i) { unsigned Idx = (FullSize - 1 - i); - Data[Offset + Idx] = static_cast<uint8_t>((CurVal >> (i * 8)) & 0xff); + Data[Idx] = static_cast<uint8_t>((CurVal >> (i * 8)) & 0xff); } } diff --git a/llvm/lib/Target/LoongArch/MCTargetDesc/LoongArchAsmBackend.cpp b/llvm/lib/Target/LoongArch/MCTargetDesc/LoongArchAsmBackend.cpp index d9ea88c..fda9d97 100644 --- a/llvm/lib/Target/LoongArch/MCTargetDesc/LoongArchAsmBackend.cpp +++ b/llvm/lib/Target/LoongArch/MCTargetDesc/LoongArchAsmBackend.cpp @@ -131,19 +131,18 @@ static uint64_t adjustFixupValue(const MCFixup &Fixup, uint64_t Value, } } -static void fixupLeb128(MCContext &Ctx, const MCFixup &Fixup, - MutableArrayRef<char> Data, uint64_t Value) { +static void fixupLeb128(MCContext &Ctx, const MCFixup &Fixup, uint8_t *Data, + uint64_t Value) { unsigned I; - for (I = 0; I != Data.size() && Value; ++I, Value >>= 7) + for (I = 0; Value; ++I, Value >>= 7) Data[I] |= uint8_t(Value & 0x7f); if (Value) Ctx.reportError(Fixup.getLoc(), "Invalid uleb128 value!"); } void LoongArchAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, - const MCValue &Target, - MutableArrayRef<char> Data, uint64_t Value, - bool IsResolved) { + const MCValue &Target, uint8_t *Data, + uint64_t Value, bool IsResolved) { if (IsResolved && shouldForceRelocation(Fixup, Target)) IsResolved = false; IsResolved = addReloc(F, Fixup, Target, Value, IsResolved); @@ -166,14 +165,14 @@ void LoongArchAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, // Shift the value into position. Value <<= Info.TargetOffset; - unsigned Offset = Fixup.getOffset(); unsigned NumBytes = alignTo(Info.TargetSize + Info.TargetOffset, 8) / 8; - assert(Offset + NumBytes <= Data.size() && "Invalid fixup offset!"); + assert(Fixup.getOffset() + NumBytes <= F.getSize() && + "Invalid fixup offset!"); // For each byte of the fragment that the fixup touches, mask in the // bits from the fixup value. for (unsigned I = 0; I != NumBytes; ++I) { - Data[Offset + I] |= uint8_t((Value >> (I * 8)) & 0xff); + Data[I] |= uint8_t((Value >> (I * 8)) & 0xff); } } @@ -274,15 +273,14 @@ bool LoongArchAsmBackend::relaxDwarfLineAddr(MCFragment &F, int64_t LineDelta = F.getDwarfLineDelta(); const MCExpr &AddrDelta = F.getDwarfAddrDelta(); - SmallVector<MCFixup, 1> Fixups; size_t OldSize = F.getVarSize(); int64_t Value; if (AddrDelta.evaluateAsAbsolute(Value, *Asm)) return false; - bool IsAbsolute = AddrDelta.evaluateKnownAbsolute(Value, *Asm); - assert(IsAbsolute && "CFA with invalid expression"); - (void)IsAbsolute; + [[maybe_unused]] bool IsAbsolute = + AddrDelta.evaluateKnownAbsolute(Value, *Asm); + assert(IsAbsolute); SmallVector<char> Data; raw_svector_ostream OS(Data); @@ -293,33 +291,23 @@ bool LoongArchAsmBackend::relaxDwarfLineAddr(MCFragment &F, encodeSLEB128(LineDelta, OS); } - unsigned Offset; - std::pair<MCFixupKind, MCFixupKind> FK; - // According to the DWARF specification, the `DW_LNS_fixed_advance_pc` opcode // takes a single unsigned half (unencoded) operand. The maximum encodable // value is therefore 65535. Set a conservative upper bound for relaxation. + unsigned PCBytes; if (Value > 60000) { unsigned PtrSize = C.getAsmInfo()->getCodePointerSize(); - - OS << uint8_t(dwarf::DW_LNS_extended_op); - encodeULEB128(PtrSize + 1, OS); - - OS << uint8_t(dwarf::DW_LNE_set_address); - Offset = OS.tell(); assert((PtrSize == 4 || PtrSize == 8) && "Unexpected pointer size"); - FK = getRelocPairForSize(PtrSize == 4 ? 
32 : 64); + PCBytes = PtrSize; + OS << uint8_t(dwarf::DW_LNS_extended_op) << uint8_t(PtrSize + 1) + << uint8_t(dwarf::DW_LNE_set_address); OS.write_zeros(PtrSize); } else { + PCBytes = 2; OS << uint8_t(dwarf::DW_LNS_fixed_advance_pc); - Offset = OS.tell(); - FK = getRelocPairForSize(16); support::endian::write<uint16_t>(OS, 0, llvm::endianness::little); } - - const MCBinaryExpr &MBE = cast<MCBinaryExpr>(AddrDelta); - Fixups.push_back(MCFixup::create(Offset, MBE.getLHS(), std::get<0>(FK))); - Fixups.push_back(MCFixup::create(Offset, MBE.getRHS(), std::get<1>(FK))); + auto Offset = OS.tell() - PCBytes; if (LineDelta == INT64_MAX) { OS << uint8_t(dwarf::DW_LNS_extended_op); @@ -330,7 +318,8 @@ bool LoongArchAsmBackend::relaxDwarfLineAddr(MCFragment &F, } F.setVarContents(Data); - F.setVarFixups(Fixups); + F.setVarFixups({MCFixup::create(Offset, &AddrDelta, + MCFixup::getDataKindForSize(PCBytes))}); WasRelaxed = OldSize != Data.size(); return true; } diff --git a/llvm/lib/Target/LoongArch/MCTargetDesc/LoongArchAsmBackend.h b/llvm/lib/Target/LoongArch/MCTargetDesc/LoongArchAsmBackend.h index 3d929fc..1f13601 100644 --- a/llvm/lib/Target/LoongArch/MCTargetDesc/LoongArchAsmBackend.h +++ b/llvm/lib/Target/LoongArch/MCTargetDesc/LoongArchAsmBackend.h @@ -42,8 +42,7 @@ public: uint64_t &FixedValue, bool IsResolved); void applyFixup(const MCFragment &, const MCFixup &, const MCValue &Target, - MutableArrayRef<char> Data, uint64_t Value, - bool IsResolved) override; + uint8_t *Data, uint64_t Value, bool IsResolved) override; bool shouldForceRelocation(const MCFixup &Fixup, const MCValue &Target); diff --git a/llvm/lib/Target/M68k/MCTargetDesc/M68kAsmBackend.cpp b/llvm/lib/Target/M68k/MCTargetDesc/M68kAsmBackend.cpp index 5e03903..fe83dc6 100644 --- a/llvm/lib/Target/M68k/MCTargetDesc/M68kAsmBackend.cpp +++ b/llvm/lib/Target/M68k/MCTargetDesc/M68kAsmBackend.cpp @@ -53,8 +53,7 @@ public: .Default(false)) {} void applyFixup(const MCFragment &, const MCFixup &, const MCValue &, - MutableArrayRef<char> Data, uint64_t Value, - bool IsResolved) override; + uint8_t *Data, uint64_t Value, bool IsResolved) override; bool mayNeedRelaxation(unsigned Opcode, ArrayRef<MCOperand> Operands, const MCSubtargetInfo &STI) const override; @@ -78,14 +77,13 @@ public: } // end anonymous namespace void M68kAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, - const MCValue &Target, - MutableArrayRef<char> Data, uint64_t Value, - bool IsResolved) { + const MCValue &Target, uint8_t *Data, + uint64_t Value, bool IsResolved) { if (!IsResolved) Asm->getWriter().recordRelocation(F, Fixup, Target, Value); unsigned Size = 1 << getFixupKindLog2Size(Fixup.getKind()); - assert(Fixup.getOffset() + Size <= Data.size() && "Invalid fixup offset!"); + assert(Fixup.getOffset() + Size <= F.getSize() && "Invalid fixup offset!"); // Check that uppper bits are either all zeros or all ones. // Specifically ignore overflow/underflow as long as the leakage is // limited to the lower bits. 
This is to remain compatible with @@ -95,8 +93,7 @@ void M68kAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, // Write in Big Endian for (unsigned i = 0; i != Size; ++i) - Data[Fixup.getOffset() + i] = - uint8_t(static_cast<int64_t>(Value) >> ((Size - i - 1) * 8)); + Data[i] = uint8_t(static_cast<int64_t>(Value) >> ((Size - i - 1) * 8)); } /// cc—Carry clear GE—Greater than or equal diff --git a/llvm/lib/Target/MSP430/MCTargetDesc/MSP430AsmBackend.cpp b/llvm/lib/Target/MSP430/MCTargetDesc/MSP430AsmBackend.cpp index 29e5bfa..d892b3a 100644 --- a/llvm/lib/Target/MSP430/MCTargetDesc/MSP430AsmBackend.cpp +++ b/llvm/lib/Target/MSP430/MCTargetDesc/MSP430AsmBackend.cpp @@ -36,8 +36,7 @@ public: ~MSP430AsmBackend() override = default; void applyFixup(const MCFragment &, const MCFixup &, const MCValue &Target, - MutableArrayRef<char> Data, uint64_t Value, - bool IsResolved) override; + uint8_t *Data, uint64_t Value, bool IsResolved) override; std::unique_ptr<MCObjectTargetWriter> createObjectTargetWriter() const override { @@ -105,9 +104,8 @@ uint64_t MSP430AsmBackend::adjustFixupValue(const MCFixup &Fixup, } void MSP430AsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, - const MCValue &Target, - MutableArrayRef<char> Data, uint64_t Value, - bool IsResolved) { + const MCValue &Target, uint8_t *Data, + uint64_t Value, bool IsResolved) { maybeAddReloc(F, Fixup, Target, Value, IsResolved); Value = adjustFixupValue(Fixup, Value, getContext()); MCFixupKindInfo Info = getFixupKindInfo(Fixup.getKind()); @@ -117,15 +115,14 @@ void MSP430AsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, // Shift the value into position. Value <<= Info.TargetOffset; - unsigned Offset = Fixup.getOffset(); unsigned NumBytes = alignTo(Info.TargetSize + Info.TargetOffset, 8) / 8; - - assert(Offset + NumBytes <= Data.size() && "Invalid fixup offset!"); + assert(Fixup.getOffset() + NumBytes <= F.getSize() && + "Invalid fixup offset!"); // For each byte of the fragment that the fixup touches, mask in the // bits from the fixup value. for (unsigned i = 0; i != NumBytes; ++i) { - Data[Offset + i] |= uint8_t((Value >> (i * 8)) & 0xff); + Data[i] |= uint8_t((Value >> (i * 8)) & 0xff); } } diff --git a/llvm/lib/Target/Mips/MCTargetDesc/MipsAsmBackend.cpp b/llvm/lib/Target/Mips/MCTargetDesc/MipsAsmBackend.cpp index c2169be..33aab71 100644 --- a/llvm/lib/Target/Mips/MCTargetDesc/MipsAsmBackend.cpp +++ b/llvm/lib/Target/Mips/MCTargetDesc/MipsAsmBackend.cpp @@ -283,9 +283,8 @@ static bool shouldForceRelocation(const MCFixup &Fixup) { /// data fragment, at the offset specified by the fixup and following the /// fixup kind as appropriate. void MipsAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, - const MCValue &Target, - MutableArrayRef<char> Data, uint64_t Value, - bool IsResolved) { + const MCValue &Target, uint8_t *Data, + uint64_t Value, bool IsResolved) { if (shouldForceRelocation(Fixup)) IsResolved = false; maybeAddReloc(F, Fixup, Target, Value, IsResolved); @@ -297,7 +296,6 @@ void MipsAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, return; // Doesn't change encoding. // Where do we start in the object - unsigned Offset = Fixup.getOffset(); // Number of bytes we need to fixup unsigned NumBytes = (getFixupKindInfo(Kind).TargetSize + 7) / 8; // Used to point to big endian bytes @@ -328,7 +326,7 @@ void MipsAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, unsigned Idx = Endian == llvm::endianness::little ? (microMipsLEByteOrder ? 
calculateMMLEIndex(i) : i) : (FullSize - 1 - i); - CurVal |= (uint64_t)((uint8_t)Data[Offset + Idx]) << (i*8); + CurVal |= (uint64_t)((uint8_t)Data[Idx]) << (i * 8); } uint64_t Mask = ((uint64_t)(-1) >> @@ -340,7 +338,7 @@ void MipsAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, unsigned Idx = Endian == llvm::endianness::little ? (microMipsLEByteOrder ? calculateMMLEIndex(i) : i) : (FullSize - 1 - i); - Data[Offset + Idx] = (uint8_t)((CurVal >> (i*8)) & 0xff); + Data[Idx] = (uint8_t)((CurVal >> (i * 8)) & 0xff); } } diff --git a/llvm/lib/Target/Mips/MCTargetDesc/MipsAsmBackend.h b/llvm/lib/Target/Mips/MCTargetDesc/MipsAsmBackend.h index 816626d..40b5853 100644 --- a/llvm/lib/Target/Mips/MCTargetDesc/MipsAsmBackend.h +++ b/llvm/lib/Target/Mips/MCTargetDesc/MipsAsmBackend.h @@ -40,8 +40,7 @@ public: createObjectTargetWriter() const override; void applyFixup(const MCFragment &, const MCFixup &, const MCValue &Target, - MutableArrayRef<char> Data, uint64_t Value, - bool IsResolved) override; + uint8_t *Data, uint64_t Value, bool IsResolved) override; std::optional<MCFixupKind> getFixupKind(StringRef Name) const override; MCFixupKindInfo getFixupKindInfo(MCFixupKind Kind) const override; diff --git a/llvm/lib/Target/Mips/MipsISelLowering.cpp b/llvm/lib/Target/Mips/MipsISelLowering.cpp index ec6b382..881ba8e 100644 --- a/llvm/lib/Target/Mips/MipsISelLowering.cpp +++ b/llvm/lib/Target/Mips/MipsISelLowering.cpp @@ -3341,6 +3341,7 @@ MipsTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, bool &IsTailCall = CLI.IsTailCall; CallingConv::ID CallConv = CLI.CallConv; bool IsVarArg = CLI.IsVarArg; + const CallBase *CB = CLI.CB; MachineFunction &MF = DAG.getMachineFunction(); MachineFrameInfo &MFI = MF.getFrameInfo(); @@ -3397,8 +3398,11 @@ MipsTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, // Get a count of how many bytes are to be pushed on the stack. unsigned StackSize = CCInfo.getStackSize(); - // Call site info for function parameters tracking. + // Call site info for function parameters tracking and call base type info. MachineFunction::CallSiteInfo CSInfo; + // Set type id for call site info. + if (MF.getTarget().Options.EmitCallGraphSection && CB && CB->isIndirectCall()) + CSInfo = MachineFunction::CallSiteInfo(*CB); // Check if it's really possible to do a tail call. Restrict it to functions // that are part of this compilation unit. 
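The MipsISelLowering hunk above only attaches the new call-site type info when a call-graph section is requested and the call is indirect. A minimal sketch of that guard, assuming only llvm::CallBase; the EmitCallGraphSection flag and the MachineFunction::CallSiteInfo plumbing are taken on faith from the hunk:

    #include "llvm/IR/InstrTypes.h"

    // Only indirect calls need a type id recorded for the call-graph section;
    // direct calls are already resolvable from the IR.
    static bool wantsCallSiteTypeId(const llvm::CallBase *CB,
                                    bool EmitCallGraphSection) {
      return EmitCallGraphSection && CB && CB->isIndirectCall();
    }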
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp index 8eec915..ee1ca45 100644 --- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp +++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp @@ -391,16 +391,6 @@ void NVPTXInstPrinter::printMemOperand(const MCInst *MI, int OpNum, } } -void NVPTXInstPrinter::printOffseti32imm(const MCInst *MI, int OpNum, - raw_ostream &O) { - auto &Op = MI->getOperand(OpNum); - assert(Op.isImm() && "Invalid operand"); - if (Op.getImm() != 0) { - O << "+"; - printOperand(MI, OpNum, O); - } -} - void NVPTXInstPrinter::printHexu32imm(const MCInst *MI, int OpNum, raw_ostream &O) { int64_t Imm = MI->getOperand(OpNum).getImm(); diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h index c3ff346..92155b0 100644 --- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h +++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h @@ -46,7 +46,6 @@ public: StringRef Modifier = {}); void printMemOperand(const MCInst *MI, int OpNum, raw_ostream &O, StringRef Modifier = {}); - void printOffseti32imm(const MCInst *MI, int OpNum, raw_ostream &O); void printHexu32imm(const MCInst *MI, int OpNum, raw_ostream &O); void printProtoIdent(const MCInst *MI, int OpNum, raw_ostream &O); void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O); diff --git a/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp b/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp index cd40481..a349609 100644 --- a/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp @@ -56,15 +56,12 @@ static bool traverseMoveUse(MachineInstr &U, const MachineRegisterInfo &MRI, case NVPTX::LD_i16: case NVPTX::LD_i32: case NVPTX::LD_i64: - case NVPTX::LD_i8: case NVPTX::LDV_i16_v2: case NVPTX::LDV_i16_v4: case NVPTX::LDV_i32_v2: case NVPTX::LDV_i32_v4: case NVPTX::LDV_i64_v2: - case NVPTX::LDV_i64_v4: - case NVPTX::LDV_i8_v2: - case NVPTX::LDV_i8_v4: { + case NVPTX::LDV_i64_v4: { LoadInsts.push_back(&U); return true; } diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp index 96f52275..6068035 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp @@ -56,9 +56,7 @@ INITIALIZE_PASS(NVPTXDAGToDAGISelLegacy, DEBUG_TYPE, PASS_NAME, false, false) NVPTXDAGToDAGISel::NVPTXDAGToDAGISel(NVPTXTargetMachine &tm, CodeGenOptLevel OptLevel) - : SelectionDAGISel(tm, OptLevel), TM(tm) { - doMulWide = (OptLevel > CodeGenOptLevel::None); -} + : SelectionDAGISel(tm, OptLevel), TM(tm) {} bool NVPTXDAGToDAGISel::runOnMachineFunction(MachineFunction &MF) { Subtarget = &MF.getSubtarget<NVPTXSubtarget>(); @@ -1005,14 +1003,10 @@ void NVPTXDAGToDAGISel::SelectAddrSpaceCast(SDNode *N) { // Helper function template to reduce amount of boilerplate code for // opcode selection. 
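The hunk that follows drops pickOpcodeForVT's i8 column, matching the LD_i8/LDV_i8 opcodes removed above. A toy of the reduced dispatch shape; the enum is a stand-in for MVT::SimpleValueType, and the f16/bf16 rows the real code folds onto the 16-bit case are omitted:

    #include <optional>

    enum class VT { i16, i32, i64, Other };

    // One opcode candidate per remaining width; std::nullopt marks a hole,
    // e.g. the {/* no v8i64 */} arguments in the callers below.
    static std::optional<unsigned> pickOpcodeForVT(VT V,
                                                   std::optional<unsigned> Op16,
                                                   std::optional<unsigned> Op32,
                                                   std::optional<unsigned> Op64) {
      switch (V) {
      case VT::i16: return Op16;
      case VT::i32: return Op32;
      case VT::i64: return Op64;
      default:      return std::nullopt;
      }
    }

    int main() {
      // Asking for a missing row yields no opcode and the caller bails out.
      return pickOpcodeForVT(VT::i64, 1u, 2u, std::nullopt).has_value() ? 1 : 0;
    }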
static std::optional<unsigned> -pickOpcodeForVT(MVT::SimpleValueType VT, std::optional<unsigned> Opcode_i8, - std::optional<unsigned> Opcode_i16, +pickOpcodeForVT(MVT::SimpleValueType VT, std::optional<unsigned> Opcode_i16, std::optional<unsigned> Opcode_i32, std::optional<unsigned> Opcode_i64) { switch (VT) { - case MVT::i1: - case MVT::i8: - return Opcode_i8; case MVT::f16: case MVT::i16: case MVT::bf16: @@ -1080,8 +1074,8 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) { Chain}; const MVT::SimpleValueType TargetVT = LD->getSimpleValueType(0).SimpleTy; - const std::optional<unsigned> Opcode = pickOpcodeForVT( - TargetVT, NVPTX::LD_i8, NVPTX::LD_i16, NVPTX::LD_i32, NVPTX::LD_i64); + const std::optional<unsigned> Opcode = + pickOpcodeForVT(TargetVT, NVPTX::LD_i16, NVPTX::LD_i32, NVPTX::LD_i64); if (!Opcode) return false; @@ -1166,17 +1160,15 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) { default: llvm_unreachable("Unexpected opcode"); case NVPTXISD::LoadV2: - Opcode = - pickOpcodeForVT(EltVT.SimpleTy, NVPTX::LDV_i8_v2, NVPTX::LDV_i16_v2, - NVPTX::LDV_i32_v2, NVPTX::LDV_i64_v2); + Opcode = pickOpcodeForVT(EltVT.SimpleTy, NVPTX::LDV_i16_v2, + NVPTX::LDV_i32_v2, NVPTX::LDV_i64_v2); break; case NVPTXISD::LoadV4: - Opcode = - pickOpcodeForVT(EltVT.SimpleTy, NVPTX::LDV_i8_v4, NVPTX::LDV_i16_v4, - NVPTX::LDV_i32_v4, NVPTX::LDV_i64_v4); + Opcode = pickOpcodeForVT(EltVT.SimpleTy, NVPTX::LDV_i16_v4, + NVPTX::LDV_i32_v4, NVPTX::LDV_i64_v4); break; case NVPTXISD::LoadV8: - Opcode = pickOpcodeForVT(EltVT.SimpleTy, {/* no v8i8 */}, {/* no v8i16 */}, + Opcode = pickOpcodeForVT(EltVT.SimpleTy, {/* no v8i16 */}, NVPTX::LDV_i32_v8, {/* no v8i64 */}); break; } @@ -1232,22 +1224,21 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) { default: llvm_unreachable("Unexpected opcode"); case ISD::LOAD: - Opcode = pickOpcodeForVT(TargetVT, NVPTX::LD_GLOBAL_NC_i8, - NVPTX::LD_GLOBAL_NC_i16, NVPTX::LD_GLOBAL_NC_i32, - NVPTX::LD_GLOBAL_NC_i64); + Opcode = pickOpcodeForVT(TargetVT, NVPTX::LD_GLOBAL_NC_i16, + NVPTX::LD_GLOBAL_NC_i32, NVPTX::LD_GLOBAL_NC_i64); break; case NVPTXISD::LoadV2: - Opcode = pickOpcodeForVT( - TargetVT, NVPTX::LD_GLOBAL_NC_v2i8, NVPTX::LD_GLOBAL_NC_v2i16, - NVPTX::LD_GLOBAL_NC_v2i32, NVPTX::LD_GLOBAL_NC_v2i64); + Opcode = + pickOpcodeForVT(TargetVT, NVPTX::LD_GLOBAL_NC_v2i16, + NVPTX::LD_GLOBAL_NC_v2i32, NVPTX::LD_GLOBAL_NC_v2i64); break; case NVPTXISD::LoadV4: - Opcode = pickOpcodeForVT( - TargetVT, NVPTX::LD_GLOBAL_NC_v4i8, NVPTX::LD_GLOBAL_NC_v4i16, - NVPTX::LD_GLOBAL_NC_v4i32, NVPTX::LD_GLOBAL_NC_v4i64); + Opcode = + pickOpcodeForVT(TargetVT, NVPTX::LD_GLOBAL_NC_v4i16, + NVPTX::LD_GLOBAL_NC_v4i32, NVPTX::LD_GLOBAL_NC_v4i64); break; case NVPTXISD::LoadV8: - Opcode = pickOpcodeForVT(TargetVT, {/* no v8i8 */}, {/* no v8i16 */}, + Opcode = pickOpcodeForVT(TargetVT, {/* no v8i16 */}, NVPTX::LD_GLOBAL_NC_v8i32, {/* no v8i64 */}); break; } @@ -1278,8 +1269,9 @@ bool NVPTXDAGToDAGISel::tryLDU(SDNode *N) { break; } - const MVT::SimpleValueType SelectVT = - MVT::getIntegerVT(LD->getMemoryVT().getSizeInBits() / NumElts).SimpleTy; + SDLoc DL(N); + const unsigned FromTypeWidth = LD->getMemoryVT().getSizeInBits() / NumElts; + const MVT::SimpleValueType TargetVT = LD->getSimpleValueType(0).SimpleTy; // If this is an LDU intrinsic, the address is the third operand. 
If its an // LDU SD node (from custom vector handling), then its the second operand @@ -1288,32 +1280,28 @@ bool NVPTXDAGToDAGISel::tryLDU(SDNode *N) { SDValue Base, Offset; SelectADDR(Addr, Base, Offset); - SDValue Ops[] = {Base, Offset, LD->getChain()}; + SDValue Ops[] = {getI32Imm(FromTypeWidth, DL), Base, Offset, LD->getChain()}; std::optional<unsigned> Opcode; switch (N->getOpcode()) { default: llvm_unreachable("Unexpected opcode"); case ISD::INTRINSIC_W_CHAIN: - Opcode = - pickOpcodeForVT(SelectVT, NVPTX::LDU_GLOBAL_i8, NVPTX::LDU_GLOBAL_i16, - NVPTX::LDU_GLOBAL_i32, NVPTX::LDU_GLOBAL_i64); + Opcode = pickOpcodeForVT(TargetVT, NVPTX::LDU_GLOBAL_i16, + NVPTX::LDU_GLOBAL_i32, NVPTX::LDU_GLOBAL_i64); break; case NVPTXISD::LDUV2: - Opcode = pickOpcodeForVT(SelectVT, NVPTX::LDU_GLOBAL_v2i8, - NVPTX::LDU_GLOBAL_v2i16, NVPTX::LDU_GLOBAL_v2i32, - NVPTX::LDU_GLOBAL_v2i64); + Opcode = pickOpcodeForVT(TargetVT, NVPTX::LDU_GLOBAL_v2i16, + NVPTX::LDU_GLOBAL_v2i32, NVPTX::LDU_GLOBAL_v2i64); break; case NVPTXISD::LDUV4: - Opcode = pickOpcodeForVT(SelectVT, NVPTX::LDU_GLOBAL_v4i8, - NVPTX::LDU_GLOBAL_v4i16, NVPTX::LDU_GLOBAL_v4i32, - {/* no v4i64 */}); + Opcode = pickOpcodeForVT(TargetVT, NVPTX::LDU_GLOBAL_v4i16, + NVPTX::LDU_GLOBAL_v4i32, {/* no v4i64 */}); break; } if (!Opcode) return false; - SDLoc DL(N); SDNode *NVPTXLDU = CurDAG->getMachineNode(*Opcode, DL, LD->getVTList(), Ops); ReplaceNode(LD, NVPTXLDU); @@ -1364,8 +1352,8 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) { Chain}; const std::optional<unsigned> Opcode = - pickOpcodeForVT(Value.getSimpleValueType().SimpleTy, NVPTX::ST_i8, - NVPTX::ST_i16, NVPTX::ST_i32, NVPTX::ST_i64); + pickOpcodeForVT(Value.getSimpleValueType().SimpleTy, NVPTX::ST_i16, + NVPTX::ST_i32, NVPTX::ST_i64); if (!Opcode) return false; @@ -1425,16 +1413,16 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) { default: return false; case NVPTXISD::StoreV2: - Opcode = pickOpcodeForVT(EltVT, NVPTX::STV_i8_v2, NVPTX::STV_i16_v2, - NVPTX::STV_i32_v2, NVPTX::STV_i64_v2); + Opcode = pickOpcodeForVT(EltVT, NVPTX::STV_i16_v2, NVPTX::STV_i32_v2, + NVPTX::STV_i64_v2); break; case NVPTXISD::StoreV4: - Opcode = pickOpcodeForVT(EltVT, NVPTX::STV_i8_v4, NVPTX::STV_i16_v4, - NVPTX::STV_i32_v4, NVPTX::STV_i64_v4); + Opcode = pickOpcodeForVT(EltVT, NVPTX::STV_i16_v4, NVPTX::STV_i32_v4, + NVPTX::STV_i64_v4); break; case NVPTXISD::StoreV8: - Opcode = pickOpcodeForVT(EltVT, {/* no v8i8 */}, {/* no v8i16 */}, - NVPTX::STV_i32_v8, {/* no v8i64 */}); + Opcode = pickOpcodeForVT(EltVT, {/* no v8i16 */}, NVPTX::STV_i32_v8, + {/* no v8i64 */}); break; } @@ -1689,10 +1677,11 @@ bool NVPTXDAGToDAGISel::tryBF16ArithToFMA(SDNode *N) { auto API = APF.bitcastToAPInt(); API = API.concat(API); auto Const = CurDAG->getTargetConstant(API, DL, MVT::i32); - return SDValue(CurDAG->getMachineNode(NVPTX::IMOV32i, DL, VT, Const), 0); + return SDValue(CurDAG->getMachineNode(NVPTX::MOV_B32_i, DL, VT, Const), + 0); } auto Const = CurDAG->getTargetConstantFP(APF, DL, VT); - return SDValue(CurDAG->getMachineNode(NVPTX::BFMOV16i, DL, VT, Const), 0); + return SDValue(CurDAG->getMachineNode(NVPTX::MOV_BF16_i, DL, VT, Const), 0); }; switch (N->getOpcode()) { diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h index e504a8f..9e0f88e5 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h +++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h @@ -40,9 +40,6 @@ private: class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel { const NVPTXTargetMachine &TM; - // If 
true, generate mul.wide from sext and mul - bool doMulWide; - NVPTX::DivPrecisionLevel getDivF32Level(const SDNode *N) const; bool usePrecSqrtF32(const SDNode *N) const; bool useF32FTZ() const; diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index f79b862..15f45a1 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -382,6 +382,54 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL, } } +// We return an EVT that can hold N VTs +// If the VT is a vector, the resulting EVT is a flat vector with the same +// element type as VT's element type. +static EVT getVectorizedVT(EVT VT, unsigned N, LLVMContext &C) { + if (N == 1) + return VT; + + return VT.isVector() ? EVT::getVectorVT(C, VT.getScalarType(), + VT.getVectorNumElements() * N) + : EVT::getVectorVT(C, VT, N); +} + +static SDValue getExtractVectorizedValue(SDValue V, unsigned I, EVT VT, + const SDLoc &dl, SelectionDAG &DAG) { + if (V.getValueType() == VT) { + assert(I == 0 && "Index must be 0 for scalar value"); + return V; + } + + if (!VT.isVector()) + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT, V, + DAG.getVectorIdxConstant(I, dl)); + + return DAG.getNode( + ISD::EXTRACT_SUBVECTOR, dl, VT, V, + DAG.getVectorIdxConstant(I * VT.getVectorNumElements(), dl)); +} + +template <typename T> +static inline SDValue getBuildVectorizedValue(unsigned N, const SDLoc &dl, + SelectionDAG &DAG, T GetElement) { + if (N == 1) + return GetElement(0); + + SmallVector<SDValue, 8> Values; + for (const unsigned I : llvm::seq(N)) { + SDValue Val = GetElement(I); + if (Val.getValueType().isVector()) + DAG.ExtractVectorElements(Val, Values); + else + Values.push_back(Val); + } + + EVT VT = EVT::getVectorVT(*DAG.getContext(), Values[0].getValueType(), + Values.size()); + return DAG.getBuildVector(VT, dl, Values); +} + /// PromoteScalarIntegerPTX /// Used to make sure the arguments/returns are suitable for passing /// and promote them to a larger size if they're not. @@ -420,9 +468,10 @@ static EVT promoteScalarIntegerPTX(const EVT VT) { // parameter starting at index Idx using a single vectorized op of // size AccessSize. If so, it returns the number of param pieces // covered by the vector op. Otherwise, it returns 1. -static unsigned CanMergeParamLoadStoresStartingAt( +template <typename T> +static unsigned canMergeParamLoadStoresStartingAt( unsigned Idx, uint32_t AccessSize, const SmallVectorImpl<EVT> &ValueVTs, - const SmallVectorImpl<uint64_t> &Offsets, Align ParamAlignment) { + const SmallVectorImpl<T> &Offsets, Align ParamAlignment) { // Can't vectorize if param alignment is not sufficient. if (ParamAlignment < AccessSize) @@ -472,10 +521,11 @@ static unsigned CanMergeParamLoadStoresStartingAt( // of the same size as ValueVTs indicating how each piece should be // loaded/stored (i.e. as a scalar, or as part of a vector // load/store). +template <typename T> static SmallVector<unsigned, 16> VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs, - const SmallVectorImpl<uint64_t> &Offsets, - Align ParamAlignment, bool IsVAArg = false) { + const SmallVectorImpl<T> &Offsets, Align ParamAlignment, + bool IsVAArg = false) { // Set vector size to match ValueVTs and mark all elements as // scalars by default. 
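The getVectorizedVT helper added in the hunk above computes the flat vector type that holds N pieces of a given VT. A tiny self-contained model of that computation, using an element name plus a lane count in place of a real EVT:

    #include <cassert>
    #include <string>

    struct VTModel {
      std::string Elt;
      unsigned Lanes; // 0 => scalar
    };

    // Mirrors getVectorizedVT: N == 1 keeps the type; vectors stay flat with
    // the same element type and multiplied lanes; scalars become N-vectors.
    static VTModel getVectorizedVT(VTModel VT, unsigned N) {
      if (N == 1)
        return VT;
      return {VT.Elt, (VT.Lanes ? VT.Lanes : 1) * N};
    }

    int main() {
      assert(getVectorizedVT({"i32", 0}, 4).Lanes == 4); // i32   x4 -> v4i32
      assert(getVectorizedVT({"f16", 2}, 2).Lanes == 4); // v2f16 x2 -> v4f16
      assert(getVectorizedVT({"i64", 0}, 1).Lanes == 0); // N == 1 unchanged
      return 0;
    }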
@@ -486,7 +536,7 @@ VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs, const auto GetNumElts = [&](unsigned I) -> unsigned { for (const unsigned AccessSize : {16, 8, 4, 2}) { - const unsigned NumElts = CanMergeParamLoadStoresStartingAt( + const unsigned NumElts = canMergeParamLoadStoresStartingAt( I, AccessSize, ValueVTs, Offsets, ParamAlignment); assert((NumElts == 1 || NumElts == 2 || NumElts == 4) && "Unexpected vectorization size"); @@ -843,7 +893,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD, ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT, ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD, - ISD::STORE}); + ISD::STORE, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND}); // setcc for f16x2 and bf16x2 needs special handling to prevent // legalizer's attempt to scalarize it due to v2i1 not being legal. @@ -1384,6 +1434,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, Type *RetTy = CLI.RetTy; const CallBase *CB = CLI.CB; const DataLayout &DL = DAG.getDataLayout(); + LLVMContext &Ctx = *DAG.getContext(); const auto GetI32 = [&](const unsigned I) { return DAG.getConstant(I, dl, MVT::i32); @@ -1476,15 +1527,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, const SDValue ParamSymbol = getCallParamSymbol(DAG, IsVAArg ? FirstVAArg : ArgI, MVT::i32); - SmallVector<EVT, 16> VTs; - SmallVector<uint64_t, 16> Offsets; - assert((!IsByVal || Arg.IndirectType) && "byval arg must have indirect type"); Type *ETy = (IsByVal ? Arg.IndirectType : Arg.Ty); - ComputePTXValueVTs(*this, DL, ETy, VTs, &Offsets, IsByVal ? 0 : VAOffset); - assert(VTs.size() == Offsets.size() && "Size mismatch"); - assert((IsByVal || VTs.size() == ArgOuts.size()) && "Size mismatch"); const Align ArgAlign = [&]() { if (IsByVal) { @@ -1492,17 +1537,14 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, // so we don't need to worry whether it's naturally aligned or not. // See TargetLowering::LowerCallTo(). 
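canMergeParamLoadStoresStartingAt and VectorizePTXValueVTs are now templated over the offset element type, so the byval path below can feed them the TypeSize offsets produced by ComputeValueVTs while the scalar path keeps uint64_t. A sketch of why that works, with a made-up fixed-size TypeSize stand-in; all the helpers need is conversion to a plain byte count:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Hypothetical stand-in for llvm::TypeSize, fixed sizes only.
    struct FixedTS {
      uint64_t Bytes;
      operator uint64_t() const { return Bytes; }
    };

    // One definition serves both instantiations, like the templated
    // canMergeParamLoadStoresStartingAt in the hunk above.
    template <typename T>
    static bool canStartVectorAccess(uint32_t AccessSize, const T &Offset,
                                     uint64_t ParamAlign) {
      return ParamAlign >= AccessSize && uint64_t(Offset) % AccessSize == 0;
    }

    int main() {
      std::vector<uint64_t> A{0, 4, 8};
      std::vector<FixedTS> B{{0}, {4}, {8}};
      assert(canStartVectorAccess(8u, A[2], 8));  // uint64_t instantiation
      assert(canStartVectorAccess(8u, B[2], 8));  // TypeSize-like instantiation
      assert(!canStartVectorAccess(8u, A[1], 8)); // misaligned piece: no merge
      return 0;
    }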
const Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign(); - const Align ByValAlign = getFunctionByValParamAlign( - CB->getCalledFunction(), ETy, InitialAlign, DL); - if (IsVAArg) - VAOffset = alignTo(VAOffset, ByValAlign); - return ByValAlign; + return getFunctionByValParamAlign(CB->getCalledFunction(), ETy, + InitialAlign, DL); } return getArgumentAlignment(CB, Arg.Ty, ArgI + 1, DL); }(); - const unsigned TypeSize = DL.getTypeAllocSize(ETy); - assert((!IsByVal || TypeSize == ArgOuts[0].Flags.getByValSize()) && + const unsigned TySize = DL.getTypeAllocSize(ETy); + assert((!IsByVal || TySize == ArgOuts[0].Flags.getByValSize()) && "type size mismatch"); const SDValue ArgDeclare = [&]() { @@ -1510,105 +1552,120 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, return VADeclareParam; if (IsByVal || shouldPassAsArray(Arg.Ty)) - return MakeDeclareArrayParam(ParamSymbol, ArgAlign, TypeSize); + return MakeDeclareArrayParam(ParamSymbol, ArgAlign, TySize); assert(ArgOuts.size() == 1 && "We must pass only one value as non-array"); assert((ArgOuts[0].VT.isInteger() || ArgOuts[0].VT.isFloatingPoint()) && "Only int and float types are supported as non-array arguments"); - return MakeDeclareScalarParam(ParamSymbol, TypeSize); + return MakeDeclareScalarParam(ParamSymbol, TySize); }(); - // PTX Interoperability Guide 3.3(A): [Integer] Values shorter - // than 32-bits are sign extended or zero extended, depending on - // whether they are signed or unsigned types. This case applies - // only to scalar parameters and not to aggregate values. - const bool ExtendIntegerParam = - Arg.Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Arg.Ty) < 32; - - const auto GetStoredValue = [&](const unsigned I, EVT EltVT, - const MaybeAlign PartAlign) { - if (IsByVal) { - SDValue Ptr = ArgOutVals[0]; - auto MPI = refinePtrAS(Ptr, DAG, DL, *this); - SDValue SrcAddr = - DAG.getObjectPtrOffset(dl, Ptr, TypeSize::getFixed(Offsets[I])); + if (IsByVal) { + assert(ArgOutVals.size() == 1 && "We must pass only one value as byval"); + SDValue SrcPtr = ArgOutVals[0]; + const auto PointerInfo = refinePtrAS(SrcPtr, DAG, DL, *this); + const Align BaseSrcAlign = ArgOuts[0].Flags.getNonZeroByValAlign(); - return DAG.getLoad(EltVT, dl, CallChain, SrcAddr, MPI, PartAlign); + if (IsVAArg) + VAOffset = alignTo(VAOffset, ArgAlign); + + SmallVector<EVT, 4> ValueVTs, MemVTs; + SmallVector<TypeSize, 4> Offsets; + ComputeValueVTs(*this, DL, ETy, ValueVTs, &MemVTs, &Offsets); + + unsigned J = 0; + const auto VI = VectorizePTXValueVTs(MemVTs, Offsets, ArgAlign, IsVAArg); + for (const unsigned NumElts : VI) { + EVT LoadVT = getVectorizedVT(MemVTs[J], NumElts, Ctx); + Align SrcAlign = commonAlignment(BaseSrcAlign, Offsets[J]); + SDValue SrcAddr = DAG.getObjectPtrOffset(dl, SrcPtr, Offsets[J]); + SDValue SrcLoad = + DAG.getLoad(LoadVT, dl, CallChain, SrcAddr, PointerInfo, SrcAlign); + + TypeSize ParamOffset = Offsets[J].getWithIncrement(VAOffset); + Align ParamAlign = commonAlignment(ArgAlign, ParamOffset); + SDValue ParamAddr = + DAG.getObjectPtrOffset(dl, ParamSymbol, ParamOffset); + SDValue StoreParam = + DAG.getStore(ArgDeclare, dl, SrcLoad, ParamAddr, + MachinePointerInfo(ADDRESS_SPACE_PARAM), ParamAlign); + CallPrereqs.push_back(StoreParam); + + J += NumElts; } - SDValue StVal = ArgOutVals[I]; - assert(promoteScalarIntegerPTX(StVal.getValueType()) == - StVal.getValueType() && - "OutVal type should always be legal"); - - const EVT VTI = promoteScalarIntegerPTX(VTs[I]); - const EVT StoreVT = - ExtendIntegerParam 
? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI); - - return correctParamType(StVal, StoreVT, ArgOuts[I].Flags, DAG, dl); - }; - - const auto VectorInfo = - VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg); - - unsigned J = 0; - for (const unsigned NumElts : VectorInfo) { - const int CurOffset = Offsets[J]; - const EVT EltVT = promoteScalarIntegerPTX(VTs[J]); - - if (IsVAArg && !IsByVal) - // Align each part of the variadic argument to their type. - VAOffset = alignTo(VAOffset, DAG.getEVTAlign(EltVT)); - - assert((IsVAArg || VAOffset == 0) && - "VAOffset must be 0 for non-VA args"); + if (IsVAArg) + VAOffset += TySize; + } else { + SmallVector<EVT, 16> VTs; + SmallVector<uint64_t, 16> Offsets; + ComputePTXValueVTs(*this, DL, Arg.Ty, VTs, &Offsets, VAOffset); + assert(VTs.size() == Offsets.size() && "Size mismatch"); + assert(VTs.size() == ArgOuts.size() && "Size mismatch"); + + // PTX Interoperability Guide 3.3(A): [Integer] Values shorter + // than 32-bits are sign extended or zero extended, depending on + // whether they are signed or unsigned types. This case applies + // only to scalar parameters and not to aggregate values. + const bool ExtendIntegerParam = + Arg.Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Arg.Ty) < 32; + + const auto GetStoredValue = [&](const unsigned I) { + SDValue StVal = ArgOutVals[I]; + assert(promoteScalarIntegerPTX(StVal.getValueType()) == + StVal.getValueType() && + "OutVal type should always be legal"); + + const EVT VTI = promoteScalarIntegerPTX(VTs[I]); + const EVT StoreVT = + ExtendIntegerParam ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI); + + return correctParamType(StVal, StoreVT, ArgOuts[I].Flags, DAG, dl); + }; + + unsigned J = 0; + const auto VI = VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg); + for (const unsigned NumElts : VI) { + const EVT EltVT = promoteScalarIntegerPTX(VTs[J]); + + unsigned Offset; + if (IsVAArg) { + // TODO: We may need to support vector types that can be passed + // as scalars in variadic arguments. + assert(NumElts == 1 && + "Vectorization should be disabled for vaargs."); + + // Align each part of the variadic argument to their type. + VAOffset = alignTo(VAOffset, DAG.getEVTAlign(EltVT)); + Offset = VAOffset; + + const EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT; + VAOffset += DL.getTypeAllocSize(TheStoreType.getTypeForEVT(Ctx)); + } else { + assert(VAOffset == 0 && "VAOffset must be 0 for non-VA args"); + Offset = Offsets[J]; + } - const unsigned Offset = - (VAOffset + ((IsVAArg && !IsByVal) ? 0 : CurOffset)); - SDValue Ptr = - DAG.getObjectPtrOffset(dl, ParamSymbol, TypeSize::getFixed(Offset)); + SDValue Ptr = + DAG.getObjectPtrOffset(dl, ParamSymbol, TypeSize::getFixed(Offset)); - const MaybeAlign CurrentAlign = ExtendIntegerParam - ? MaybeAlign(std::nullopt) - : commonAlignment(ArgAlign, Offset); + const MaybeAlign CurrentAlign = ExtendIntegerParam + ? 
MaybeAlign(std::nullopt) + : commonAlignment(ArgAlign, Offset); - SDValue Val; - if (NumElts == 1) { - Val = GetStoredValue(J, EltVT, CurrentAlign); - } else { - SmallVector<SDValue, 8> StoreVals; - for (const unsigned K : llvm::seq(NumElts)) { - SDValue ValJ = GetStoredValue(J + K, EltVT, CurrentAlign); - if (ValJ.getValueType().isVector()) - DAG.ExtractVectorElements(ValJ, StoreVals); - else - StoreVals.push_back(ValJ); - } + SDValue Val = + getBuildVectorizedValue(NumElts, dl, DAG, [&](unsigned K) { + return GetStoredValue(J + K); + }); - EVT VT = EVT::getVectorVT( - *DAG.getContext(), StoreVals[0].getValueType(), StoreVals.size()); - Val = DAG.getBuildVector(VT, dl, StoreVals); - } + SDValue StoreParam = + DAG.getStore(ArgDeclare, dl, Val, Ptr, + MachinePointerInfo(ADDRESS_SPACE_PARAM), CurrentAlign); + CallPrereqs.push_back(StoreParam); - SDValue StoreParam = - DAG.getStore(ArgDeclare, dl, Val, Ptr, - MachinePointerInfo(ADDRESS_SPACE_PARAM), CurrentAlign); - CallPrereqs.push_back(StoreParam); - - // TODO: We may need to support vector types that can be passed - // as scalars in variadic arguments. - if (IsVAArg && !IsByVal) { - assert(NumElts == 1 && - "Vectorization is expected to be disabled for variadics."); - const EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT; - VAOffset += - DL.getTypeAllocSize(TheStoreType.getTypeForEVT(*DAG.getContext())); + J += NumElts; } - - J += NumElts; } - if (IsVAArg && IsByVal) - VAOffset += TypeSize; } // Handle Result @@ -1676,17 +1733,6 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, CallPrereqs.push_back(PrototypeDeclare); } - if (ConvertToIndirectCall) { - // Copy the function ptr to a ptx register and use the register to call the - // function. - const MVT DestVT = Callee.getValueType().getSimpleVT(); - MachineRegisterInfo &MRI = DAG.getMachineFunction().getRegInfo(); - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - Register DestReg = MRI.createVirtualRegister(TLI.getRegClassFor(DestVT)); - auto RegCopy = DAG.getCopyToReg(DAG.getEntryNode(), dl, DestReg, Callee); - Callee = DAG.getCopyFromReg(RegCopy, dl, DestReg, DestVT); - } - const unsigned Proto = IsIndirectCall ? UniqueCallSite : 0; const unsigned NumArgs = std::min<unsigned>(CLI.NumFixedArgs + 1, Args.size()); @@ -1703,10 +1749,11 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, if (!Ins.empty()) { SmallVector<EVT, 16> VTs; SmallVector<uint64_t, 16> Offsets; - ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets, 0); + ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets); assert(VTs.size() == Ins.size() && "Bad value decomposition"); const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL); + const SDValue RetSymbol = DAG.getExternalSymbol("retval0", MVT::i32); // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than // 32-bits are sign extended or zero extended, depending on whether @@ -1714,9 +1761,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, const bool ExtendIntegerRetVal = RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32; - const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign); unsigned I = 0; - for (const unsigned NumElts : VectorInfo) { + const auto VI = VectorizePTXValueVTs(VTs, Offsets, RetAlign); + for (const unsigned NumElts : VI) { const MaybeAlign CurrentAlign = ExtendIntegerRetVal ? 
MaybeAlign(std::nullopt) : commonAlignment(RetAlign, Offsets[I]); @@ -1724,16 +1771,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, const EVT VTI = promoteScalarIntegerPTX(VTs[I]); const EVT LoadVT = ExtendIntegerRetVal ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI); - - const unsigned PackingAmt = - LoadVT.isVector() ? LoadVT.getVectorNumElements() : 1; - - const EVT VecVT = NumElts == 1 ? LoadVT - : EVT::getVectorVT(*DAG.getContext(), - LoadVT.getScalarType(), - NumElts * PackingAmt); - - const SDValue RetSymbol = DAG.getExternalSymbol("retval0", MVT::i32); + const EVT VecVT = getVectorizedVT(LoadVT, NumElts, Ctx); SDValue Ptr = DAG.getObjectPtrOffset(dl, RetSymbol, TypeSize::getFixed(Offsets[I])); @@ -1742,17 +1780,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, MachinePointerInfo(ADDRESS_SPACE_PARAM), CurrentAlign); LoadChains.push_back(R.getValue(1)); - - if (NumElts == 1) - ProxyRegOps.push_back(R); - else - for (const unsigned J : llvm::seq(NumElts)) { - SDValue Elt = DAG.getNode( - LoadVT.isVector() ? ISD::EXTRACT_SUBVECTOR - : ISD::EXTRACT_VECTOR_ELT, - dl, LoadVT, R, DAG.getVectorIdxConstant(J * PackingAmt, dl)); - ProxyRegOps.push_back(Elt); - } + for (const unsigned J : llvm::seq(NumElts)) + ProxyRegOps.push_back(getExtractVectorizedValue(R, J, LoadVT, dl, DAG)); I += NumElts; } } @@ -3227,11 +3256,10 @@ SDValue NVPTXTargetLowering::LowerFormalArguments( SDValue Chain, CallingConv::ID CallConv, bool isVarArg, const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &dl, SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const { - MachineFunction &MF = DAG.getMachineFunction(); const DataLayout &DL = DAG.getDataLayout(); auto PtrVT = getPointerTy(DAG.getDataLayout()); - const Function *F = &MF.getFunction(); + const Function &F = DAG.getMachineFunction().getFunction(); SDValue Root = DAG.getRoot(); SmallVector<SDValue, 16> OutChains; @@ -3247,7 +3275,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments( // See similar issue in LowerCall. auto AllIns = ArrayRef(Ins); - for (const auto &Arg : F->args()) { + for (const auto &Arg : F.args()) { const auto ArgIns = AllIns.take_while( [&](auto I) { return I.OrigArgIndex == Arg.getArgNo(); }); AllIns = AllIns.drop_front(ArgIns.size()); @@ -3287,7 +3315,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments( assert(ByvalIn.VT == PtrVT && "ByVal argument must be a pointer"); SDValue P; - if (isKernelFunction(*F)) { + if (isKernelFunction(F)) { P = ArgSymbol; P.getNode()->setIROrder(Arg.getArgNo() + 1); } else { @@ -3305,43 +3333,27 @@ SDValue NVPTXTargetLowering::LowerFormalArguments( assert(VTs.size() == Offsets.size() && "Size mismatch"); const Align ArgAlign = getFunctionArgumentAlignment( - F, Ty, Arg.getArgNo() + AttributeList::FirstArgIndex, DL); + &F, Ty, Arg.getArgNo() + AttributeList::FirstArgIndex, DL); - const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign); unsigned I = 0; - for (const unsigned NumElts : VectorInfo) { + const auto VI = VectorizePTXValueVTs(VTs, Offsets, ArgAlign); + for (const unsigned NumElts : VI) { // i1 is loaded/stored as i8 const EVT LoadVT = VTs[I] == MVT::i1 ? MVT::i8 : VTs[I]; - // If the element is a packed type (ex. v2f16, v4i8, etc) holding - // multiple elements. - const unsigned PackingAmt = - LoadVT.isVector() ? LoadVT.getVectorNumElements() : 1; - - const EVT VecVT = - NumElts == 1 - ? 
LoadVT - : EVT::getVectorVT(F->getContext(), LoadVT.getScalarType(), - NumElts * PackingAmt); + const EVT VecVT = getVectorizedVT(LoadVT, NumElts, *DAG.getContext()); SDValue VecAddr = DAG.getObjectPtrOffset( dl, ArgSymbol, TypeSize::getFixed(Offsets[I])); - const MaybeAlign PartAlign = commonAlignment(ArgAlign, Offsets[I]); + const Align PartAlign = commonAlignment(ArgAlign, Offsets[I]); SDValue P = DAG.getLoad(VecVT, dl, Root, VecAddr, MachinePointerInfo(ADDRESS_SPACE_PARAM), PartAlign, MachineMemOperand::MODereferenceable | MachineMemOperand::MOInvariant); - if (P.getNode()) - P.getNode()->setIROrder(Arg.getArgNo() + 1); + P.getNode()->setIROrder(Arg.getArgNo() + 1); for (const unsigned J : llvm::seq(NumElts)) { - SDValue Elt = - NumElts == 1 - ? P - : DAG.getNode(LoadVT.isVector() ? ISD::EXTRACT_SUBVECTOR - : ISD::EXTRACT_VECTOR_ELT, - dl, LoadVT, P, - DAG.getVectorIdxConstant(J * PackingAmt, dl)); + SDValue Elt = getExtractVectorizedValue(P, J, LoadVT, dl, DAG); Elt = correctParamType(Elt, ArgIns[I + J].VT, ArgIns[I + J].Flags, DAG, dl); @@ -3364,9 +3376,8 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv, const SmallVectorImpl<ISD::OutputArg> &Outs, const SmallVectorImpl<SDValue> &OutVals, const SDLoc &dl, SelectionDAG &DAG) const { - const MachineFunction &MF = DAG.getMachineFunction(); - const Function &F = MF.getFunction(); - Type *RetTy = MF.getFunction().getReturnType(); + const Function &F = DAG.getMachineFunction().getFunction(); + Type *RetTy = F.getReturnType(); if (RetTy->isVoidTy()) { assert(OutVals.empty() && Outs.empty() && "Return value expected for void"); @@ -3374,10 +3385,9 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv, } const DataLayout &DL = DAG.getDataLayout(); - SmallVector<EVT, 16> VTs; - SmallVector<uint64_t, 16> Offsets; - ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets); - assert(VTs.size() == OutVals.size() && "Bad return value decomposition"); + + const SDValue RetSymbol = DAG.getExternalSymbol("func_retval0", MVT::i32); + const auto RetAlign = getFunctionParamOptimizedAlign(&F, RetTy, DL); // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than // 32-bits are sign extended or zero extended, depending on whether @@ -3385,6 +3395,11 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv, const bool ExtendIntegerRetVal = RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32; + SmallVector<EVT, 16> VTs; + SmallVector<uint64_t, 16> Offsets; + ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets); + assert(VTs.size() == OutVals.size() && "Bad return value decomposition"); + const auto GetRetVal = [&](unsigned I) -> SDValue { SDValue RetVal = OutVals[I]; assert(promoteScalarIntegerPTX(RetVal.getValueType()) == @@ -3397,33 +3412,16 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv, return correctParamType(RetVal, StoreVT, Outs[I].Flags, DAG, dl); }; - const auto RetAlign = getFunctionParamOptimizedAlign(&F, RetTy, DL); - const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign); unsigned I = 0; - for (const unsigned NumElts : VectorInfo) { + const auto VI = VectorizePTXValueVTs(VTs, Offsets, RetAlign); + for (const unsigned NumElts : VI) { const MaybeAlign CurrentAlign = ExtendIntegerRetVal ? 
MaybeAlign(std::nullopt) : commonAlignment(RetAlign, Offsets[I]); - SDValue Val; - if (NumElts == 1) { - Val = GetRetVal(I); - } else { - SmallVector<SDValue, 4> StoreVals; - for (const unsigned J : llvm::seq(NumElts)) { - SDValue ValJ = GetRetVal(I + J); - if (ValJ.getValueType().isVector()) - DAG.ExtractVectorElements(ValJ, StoreVals); - else - StoreVals.push_back(ValJ); - } - - EVT VT = EVT::getVectorVT(F.getContext(), StoreVals[0].getValueType(), - StoreVals.size()); - Val = DAG.getBuildVector(VT, dl, StoreVals); - } + SDValue Val = getBuildVectorizedValue( + NumElts, dl, DAG, [&](unsigned K) { return GetRetVal(I + K); }); - const SDValue RetSymbol = DAG.getExternalSymbol("func_retval0", MVT::i32); SDValue Ptr = DAG.getObjectPtrOffset(dl, RetSymbol, TypeSize::getFixed(Offsets[I])); @@ -4917,7 +4915,6 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { return SDValue(); auto *LD = cast<MemSDNode>(N); - EVT MemVT = LD->getMemoryVT(); SDLoc DL(LD); // the new opcode after we double the number of operands @@ -4958,9 +4955,9 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { NewVTs.append(LD->value_begin() + OldNumOutputs, LD->value_end()); // Create the new load - SDValue NewLoad = - DCI.DAG.getMemIntrinsicNode(Opcode, DL, DCI.DAG.getVTList(NewVTs), - Operands, MemVT, LD->getMemOperand()); + SDValue NewLoad = DCI.DAG.getMemIntrinsicNode( + Opcode, DL, DCI.DAG.getVTList(NewVTs), Operands, LD->getMemoryVT(), + LD->getMemOperand()); // Now we use a combination of BUILD_VECTORs and a MERGE_VALUES node to keep // the outputs the same. These nodes will be optimized away in later @@ -5002,7 +4999,6 @@ static SDValue combinePackingMovIntoStore(SDNode *N, return SDValue(); auto *ST = cast<MemSDNode>(N); - EVT MemVT = ElementVT.getVectorElementType(); // The new opcode after we double the number of operands. NVPTXISD::NodeType Opcode; @@ -5011,11 +5007,9 @@ static SDValue combinePackingMovIntoStore(SDNode *N, // Any packed type is legal, so the legalizer will not have lowered // ISD::STORE -> NVPTXISD::Store (unless it's under-aligned). We have to do // it here. 
- MemVT = ST->getMemoryVT(); Opcode = NVPTXISD::StoreV2; break; case NVPTXISD::StoreV2: - MemVT = ST->getMemoryVT(); Opcode = NVPTXISD::StoreV4; break; case NVPTXISD::StoreV4: @@ -5066,7 +5060,7 @@ static SDValue combinePackingMovIntoStore(SDNode *N, // Now we replace the store return DCI.DAG.getMemIntrinsicNode(Opcode, SDLoc(N), N->getVTList(), Operands, - MemVT, ST->getMemOperand()); + ST->getMemoryVT(), ST->getMemOperand()); } static SDValue PerformStoreCombine(SDNode *N, @@ -5219,6 +5213,42 @@ static SDValue PerformREMCombine(SDNode *N, return SDValue(); } +// (sign_extend|zero_extend (mul|shl) x, y) -> (mul.wide x, y) +static SDValue combineMulWide(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, + CodeGenOptLevel OptLevel) { + if (OptLevel == CodeGenOptLevel::None) + return SDValue(); + + SDValue Op = N->getOperand(0); + if (!Op.hasOneUse()) + return SDValue(); + EVT ToVT = N->getValueType(0); + EVT FromVT = Op.getValueType(); + if (!((ToVT == MVT::i32 && FromVT == MVT::i16) || + (ToVT == MVT::i64 && FromVT == MVT::i32))) + return SDValue(); + if (!(Op.getOpcode() == ISD::MUL || + (Op.getOpcode() == ISD::SHL && isa<ConstantSDNode>(Op.getOperand(1))))) + return SDValue(); + + SDLoc DL(N); + unsigned ExtOpcode = N->getOpcode(); + unsigned Opcode = 0; + if (ExtOpcode == ISD::SIGN_EXTEND && Op->getFlags().hasNoSignedWrap()) + Opcode = NVPTXISD::MUL_WIDE_SIGNED; + else if (ExtOpcode == ISD::ZERO_EXTEND && Op->getFlags().hasNoUnsignedWrap()) + Opcode = NVPTXISD::MUL_WIDE_UNSIGNED; + else + return SDValue(); + SDValue RHS = Op.getOperand(1); + if (Op.getOpcode() == ISD::SHL) { + const auto ShiftAmt = Op.getConstantOperandVal(1); + const auto MulVal = APInt(ToVT.getSizeInBits(), 1) << ShiftAmt; + RHS = DCI.DAG.getConstant(MulVal, DL, ToVT); + } + return DCI.DAG.getNode(Opcode, DL, ToVT, Op.getOperand(0), RHS); +} + enum OperandSignedness { Signed = 0, Unsigned, @@ -5825,6 +5855,9 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N, return combineADDRSPACECAST(N, DCI); case ISD::AND: return PerformANDCombine(N, DCI); + case ISD::SIGN_EXTEND: + case ISD::ZERO_EXTEND: + return combineMulWide(N, DCI, OptLevel); case ISD::BUILD_VECTOR: return PerformBUILD_VECTORCombine(N, DCI); case ISD::EXTRACT_VECTOR_ELT: diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrFormats.td b/llvm/lib/Target/NVPTX/NVPTXInstrFormats.td index 86dcb4a..719be03 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrFormats.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrFormats.td @@ -11,15 +11,9 @@ // //===----------------------------------------------------------------------===// -// Vector instruction type enum -class VecInstTypeEnum<bits<4> val> { - bits<4> Value=val; -} -def VecNOP : VecInstTypeEnum<0>; - // Generic NVPTX Format -class NVPTXInst<dag outs, dag ins, string asmstr, list<dag> pattern> +class NVPTXInst<dag outs, dag ins, string asmstr, list<dag> pattern = []> : Instruction { field bits<14> Inst; @@ -30,7 +24,6 @@ class NVPTXInst<dag outs, dag ins, string asmstr, list<dag> pattern> let Pattern = pattern; // TSFlagFields - bits<4> VecInstType = VecNOP.Value; bit IsLoad = false; bit IsStore = false; @@ -45,7 +38,6 @@ class NVPTXInst<dag outs, dag ins, string asmstr, list<dag> pattern> // 2**(2-1) = 2. 
bits<2> IsSuld = 0; - let TSFlags{3...0} = VecInstType; let TSFlags{4} = IsLoad; let TSFlags{5} = IsStore; let TSFlags{6} = IsTex; diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp index e218ef1..34fe467 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp @@ -35,23 +35,23 @@ void NVPTXInstrInfo::copyPhysReg(MachineBasicBlock &MBB, const TargetRegisterClass *DestRC = MRI.getRegClass(DestReg); const TargetRegisterClass *SrcRC = MRI.getRegClass(SrcReg); - if (RegInfo.getRegSizeInBits(*DestRC) != RegInfo.getRegSizeInBits(*SrcRC)) + if (DestRC != SrcRC) report_fatal_error("Copy one register into another with a different width"); unsigned Op; - if (DestRC == &NVPTX::B1RegClass) { - Op = NVPTX::IMOV1r; - } else if (DestRC == &NVPTX::B16RegClass) { - Op = NVPTX::MOV16r; - } else if (DestRC == &NVPTX::B32RegClass) { - Op = NVPTX::IMOV32r; - } else if (DestRC == &NVPTX::B64RegClass) { - Op = NVPTX::IMOV64r; - } else if (DestRC == &NVPTX::B128RegClass) { - Op = NVPTX::IMOV128r; - } else { + if (DestRC == &NVPTX::B1RegClass) + Op = NVPTX::MOV_B1_r; + else if (DestRC == &NVPTX::B16RegClass) + Op = NVPTX::MOV_B16_r; + else if (DestRC == &NVPTX::B32RegClass) + Op = NVPTX::MOV_B32_r; + else if (DestRC == &NVPTX::B64RegClass) + Op = NVPTX::MOV_B64_r; + else if (DestRC == &NVPTX::B128RegClass) + Op = NVPTX::MOV_B128_r; + else llvm_unreachable("Bad register copy"); - } + BuildMI(MBB, I, DL, get(Op), DestReg) .addReg(SrcReg, getKillRegState(KillSrc)); } diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index 86d6f7c..2ae7520 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -15,19 +15,8 @@ include "NVPTXInstrFormats.td" let OperandType = "OPERAND_IMMEDIATE" in { def f16imm : Operand<f16>; def bf16imm : Operand<bf16>; - } -// List of vector specific properties -def isVecLD : VecInstTypeEnum<1>; -def isVecST : VecInstTypeEnum<2>; -def isVecBuild : VecInstTypeEnum<3>; -def isVecShuffle : VecInstTypeEnum<4>; -def isVecExtract : VecInstTypeEnum<5>; -def isVecInsert : VecInstTypeEnum<6>; -def isVecDest : VecInstTypeEnum<7>; -def isVecOther : VecInstTypeEnum<15>; - //===----------------------------------------------------------------------===// // NVPTX Operand Definitions. //===----------------------------------------------------------------------===// @@ -125,8 +114,6 @@ def doF32FTZ : Predicate<"useF32FTZ()">; def doNoF32FTZ : Predicate<"!useF32FTZ()">; def doRsqrtOpt : Predicate<"doRsqrtOpt()">; -def doMulWide : Predicate<"doMulWide">; - def hasHWROT32 : Predicate<"Subtarget->hasHWROT32()">; def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">; def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">; @@ -486,46 +473,28 @@ let hasSideEffects = false in { // takes a CvtMode immediate that defines the conversion mode to use. It can // be CvtNONE to omit a conversion mode. multiclass CVT_FROM_ALL<string ToType, RegisterClass RC, list<Predicate> Preds = []> { - def _s8 : - BasicFlagsNVPTXInst<(outs RC:$dst), - (ins B16:$src), (ins CvtMode:$mode), - "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".s8">, - Requires<Preds>; - def _u8 : - BasicFlagsNVPTXInst<(outs RC:$dst), - (ins B16:$src), (ins CvtMode:$mode), - "cvt${mode:base}${mode:ftz}${mode:sat}." 
# ToType # ".u8">, - Requires<Preds>; - def _s16 : - BasicFlagsNVPTXInst<(outs RC:$dst), - (ins B16:$src), (ins CvtMode:$mode), - "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".s16">, - Requires<Preds>; - def _u16 : - BasicFlagsNVPTXInst<(outs RC:$dst), - (ins B16:$src), (ins CvtMode:$mode), - "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".u16">, - Requires<Preds>; - def _s32 : - BasicFlagsNVPTXInst<(outs RC:$dst), - (ins B32:$src), (ins CvtMode:$mode), - "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".s32">, - Requires<Preds>; - def _u32 : - BasicFlagsNVPTXInst<(outs RC:$dst), - (ins B32:$src), (ins CvtMode:$mode), - "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".u32">, - Requires<Preds>; - def _s64 : - BasicFlagsNVPTXInst<(outs RC:$dst), - (ins B64:$src), (ins CvtMode:$mode), - "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".s64">, - Requires<Preds>; - def _u64 : - BasicFlagsNVPTXInst<(outs RC:$dst), - (ins B64:$src), (ins CvtMode:$mode), - "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".u64">, - Requires<Preds>; + foreach sign = ["s", "u"] in { + def _ # sign # "8" : + BasicFlagsNVPTXInst<(outs RC:$dst), + (ins B16:$src), (ins CvtMode:$mode), + "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # "." # sign # "8">, + Requires<Preds>; + def _ # sign # "16" : + BasicFlagsNVPTXInst<(outs RC:$dst), + (ins B16:$src), (ins CvtMode:$mode), + "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # "." # sign # "16">, + Requires<Preds>; + def _ # sign # "32" : + BasicFlagsNVPTXInst<(outs RC:$dst), + (ins B32:$src), (ins CvtMode:$mode), + "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # "." # sign # "32">, + Requires<Preds>; + def _ # sign # "64" : + BasicFlagsNVPTXInst<(outs RC:$dst), + (ins B64:$src), (ins CvtMode:$mode), + "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # "." # sign # "64">, + Requires<Preds>; + } def _f16 : BasicFlagsNVPTXInst<(outs RC:$dst), (ins B16:$src), (ins CvtMode:$mode), @@ -556,14 +525,12 @@ let hasSideEffects = false in { } // Generate cvts from all types to all types. - defm CVT_s8 : CVT_FROM_ALL<"s8", B16>; - defm CVT_u8 : CVT_FROM_ALL<"u8", B16>; - defm CVT_s16 : CVT_FROM_ALL<"s16", B16>; - defm CVT_u16 : CVT_FROM_ALL<"u16", B16>; - defm CVT_s32 : CVT_FROM_ALL<"s32", B32>; - defm CVT_u32 : CVT_FROM_ALL<"u32", B32>; - defm CVT_s64 : CVT_FROM_ALL<"s64", B64>; - defm CVT_u64 : CVT_FROM_ALL<"u64", B64>; + foreach sign = ["s", "u"] in { + defm CVT_ # sign # "8" : CVT_FROM_ALL<sign # "8", B16>; + defm CVT_ # sign # "16" : CVT_FROM_ALL<sign # "16", B16>; + defm CVT_ # sign # "32" : CVT_FROM_ALL<sign # "32", B32>; + defm CVT_ # sign # "64" : CVT_FROM_ALL<sign # "64", B64>; + } defm CVT_f16 : CVT_FROM_ALL<"f16", B16>; defm CVT_bf16 : CVT_FROM_ALL<"bf16", B16, [hasPTX<78>, hasSM<90>]>; defm CVT_f32 : CVT_FROM_ALL<"f32", B32>; @@ -571,18 +538,12 @@ let hasSideEffects = false in { // These cvts are different from those above: The source and dest registers // are of the same type. 
- def CVT_INREG_s16_s8 : BasicNVPTXInst<(outs B16:$dst), (ins B16:$src), - "cvt.s16.s8">; - def CVT_INREG_s32_s8 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src), - "cvt.s32.s8">; - def CVT_INREG_s32_s16 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src), - "cvt.s32.s16">; - def CVT_INREG_s64_s8 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), - "cvt.s64.s8">; - def CVT_INREG_s64_s16 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), - "cvt.s64.s16">; - def CVT_INREG_s64_s32 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), - "cvt.s64.s32">; + def CVT_INREG_s16_s8 : BasicNVPTXInst<(outs B16:$dst), (ins B16:$src), "cvt.s16.s8">; + def CVT_INREG_s32_s8 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src), "cvt.s32.s8">; + def CVT_INREG_s32_s16 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src), "cvt.s32.s16">; + def CVT_INREG_s64_s8 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), "cvt.s64.s8">; + def CVT_INREG_s64_s16 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), "cvt.s64.s16">; + def CVT_INREG_s64_s32 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), "cvt.s64.s32">; multiclass CVT_FROM_FLOAT_V2_SM80<string FromName, RegisterClass RC> { def _f32 : @@ -784,7 +745,7 @@ defm SUB : I3<"sub.s", sub, commutative = false>; def ADD16x2 : I16x2<"add.s", add>; -// in32 and int64 addition and subtraction with carry-out. +// int32 and int64 addition and subtraction with carry-out. defm ADDCC : ADD_SUB_INT_CARRY<"add.cc", addc, commutative = true>; defm SUBCC : ADD_SUB_INT_CARRY<"sub.cc", subc, commutative = false>; @@ -805,17 +766,17 @@ defm UDIV : I3<"div.u", udiv, commutative = false>; defm SREM : I3<"rem.s", srem, commutative = false>; defm UREM : I3<"rem.u", urem, commutative = false>; -// Integer absolute value. NumBits should be one minus the bit width of RC. -// This idiom implements the algorithm at -// http://graphics.stanford.edu/~seander/bithacks.html#IntegerAbs. -multiclass ABS<ValueType T, RegisterClass RC, string SizeName> { - def : BasicNVPTXInst<(outs RC:$dst), (ins RC:$a), - "abs" # SizeName, - [(set T:$dst, (abs T:$a))]>; +foreach t = [I16RT, I32RT, I64RT] in { + def ABS_S # t.Size : + BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a), + "abs.s" # t.Size, + [(set t.Ty:$dst, (abs t.Ty:$a))]>; + + def NEG_S # t.Size : + BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src), + "neg.s" # t.Size, + [(set t.Ty:$dst, (ineg t.Ty:$src))]>; } -defm ABS_16 : ABS<i16, B16, ".s16">; -defm ABS_32 : ABS<i32, B32, ".s32">; -defm ABS_64 : ABS<i64, B64, ".s64">; // Integer min/max. 
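As a usage-level illustration of both the abs/neg definitions above and the min/max instructions that follow, a hypothetical CUDA kernel (not part of this patch):

  __global__ void abs_min_max(int *out, int a, int b) {
    out[0] = abs(a);    // llvm.abs.i32 selects to abs.s32
    out[1] = max(a, b); // smax selects to max.s32
    out[2] = min(a, b); // smin selects to min.s32
  }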
defm SMAX : I3<"max.s", smax, commutative = true>; @@ -832,170 +793,63 @@ def UMIN16x2 : I16x2<"min.u", umin>; // // Wide multiplication // -def MULWIDES64 : - BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, B32:$b), "mul.wide.s32">; -def MULWIDES64Imm : - BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i32imm:$b), "mul.wide.s32">; -def MULWIDES64Imm64 : - BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i64imm:$b), "mul.wide.s32">; - -def MULWIDEU64 : - BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, B32:$b), "mul.wide.u32">; -def MULWIDEU64Imm : - BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i32imm:$b), "mul.wide.u32">; -def MULWIDEU64Imm64 : - BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i64imm:$b), "mul.wide.u32">; - -def MULWIDES32 : - BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B16:$b), "mul.wide.s16">; -def MULWIDES32Imm : - BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i16imm:$b), "mul.wide.s16">; -def MULWIDES32Imm32 : - BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i32imm:$b), "mul.wide.s16">; - -def MULWIDEU32 : - BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B16:$b), "mul.wide.u16">; -def MULWIDEU32Imm : - BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i16imm:$b), "mul.wide.u16">; -def MULWIDEU32Imm32 : - BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i32imm:$b), "mul.wide.u16">; - -def SDTMulWide : SDTypeProfile<1, 2, [SDTCisSameAs<1, 2>]>; -def mul_wide_signed : SDNode<"NVPTXISD::MUL_WIDE_SIGNED", SDTMulWide>; -def mul_wide_unsigned : SDNode<"NVPTXISD::MUL_WIDE_UNSIGNED", SDTMulWide>; - -// Matchers for signed, unsigned mul.wide ISD nodes. -let Predicates = [doMulWide] in { - def : Pat<(i32 (mul_wide_signed i16:$a, i16:$b)), (MULWIDES32 $a, $b)>; - def : Pat<(i32 (mul_wide_signed i16:$a, imm:$b)), (MULWIDES32Imm $a, imm:$b)>; - def : Pat<(i32 (mul_wide_unsigned i16:$a, i16:$b)), (MULWIDEU32 $a, $b)>; - def : Pat<(i32 (mul_wide_unsigned i16:$a, imm:$b)), (MULWIDEU32Imm $a, imm:$b)>; - - def : Pat<(i64 (mul_wide_signed i32:$a, i32:$b)), (MULWIDES64 $a, $b)>; - def : Pat<(i64 (mul_wide_signed i32:$a, imm:$b)), (MULWIDES64Imm $a, imm:$b)>; - def : Pat<(i64 (mul_wide_unsigned i32:$a, i32:$b)), (MULWIDEU64 $a, $b)>; - def : Pat<(i64 (mul_wide_unsigned i32:$a, imm:$b)), (MULWIDEU64Imm $a, imm:$b)>; -} - -// Predicates used for converting some patterns to mul.wide. -def SInt32Const : PatLeaf<(imm), [{ - const APInt &v = N->getAPIntValue(); - return v.isSignedIntN(32); -}]>; - -def UInt32Const : PatLeaf<(imm), [{ - const APInt &v = N->getAPIntValue(); - return v.isIntN(32); -}]>; - -def SInt16Const : PatLeaf<(imm), [{ - const APInt &v = N->getAPIntValue(); - return v.isSignedIntN(16); -}]>; - -def UInt16Const : PatLeaf<(imm), [{ - const APInt &v = N->getAPIntValue(); - return v.isIntN(16); -}]>; - -def IntConst_0_30 : PatLeaf<(imm), [{ - // Check if 0 <= v < 31; only then will the result of (x << v) be an int32. - const APInt &v = N->getAPIntValue(); - return v.sge(0) && v.slt(31); -}]>; - -def IntConst_0_14 : PatLeaf<(imm), [{ - // Check if 0 <= v < 15; only then will the result of (x << v) be an int16. 
- const APInt &v = N->getAPIntValue(); - return v.sge(0) && v.slt(15); -}]>; - -def SHL2MUL32 : SDNodeXForm<imm, [{ - const APInt &v = N->getAPIntValue(); - APInt temp(32, 1); - return CurDAG->getTargetConstant(temp.shl(v), SDLoc(N), MVT::i32); -}]>; -def SHL2MUL16 : SDNodeXForm<imm, [{ - const APInt &v = N->getAPIntValue(); - APInt temp(16, 1); - return CurDAG->getTargetConstant(temp.shl(v), SDLoc(N), MVT::i16); -}]>; +def SDTMulWide : SDTypeProfile<1, 2, [SDTCisInt<0>, SDTCisInt<1>, SDTCisSameAs<1, 2>]>; +def smul_wide : SDNode<"NVPTXISD::MUL_WIDE_SIGNED", SDTMulWide, [SDNPCommutative]>; +def umul_wide : SDNode<"NVPTXISD::MUL_WIDE_UNSIGNED", SDTMulWide, [SDNPCommutative]>; -// Convert "sign/zero-extend, then shift left by an immediate" to mul.wide. -let Predicates = [doMulWide] in { - def : Pat<(shl (sext i32:$a), (i32 IntConst_0_30:$b)), - (MULWIDES64Imm $a, (SHL2MUL32 $b))>; - def : Pat<(shl (zext i32:$a), (i32 IntConst_0_30:$b)), - (MULWIDEU64Imm $a, (SHL2MUL32 $b))>; - def : Pat<(shl (sext i16:$a), (i16 IntConst_0_14:$b)), - (MULWIDES32Imm $a, (SHL2MUL16 $b))>; - def : Pat<(shl (zext i16:$a), (i16 IntConst_0_14:$b)), - (MULWIDEU32Imm $a, (SHL2MUL16 $b))>; - - // Convert "sign/zero-extend then multiply" to mul.wide. - def : Pat<(mul (sext i32:$a), (sext i32:$b)), - (MULWIDES64 $a, $b)>; - def : Pat<(mul (sext i32:$a), (i64 SInt32Const:$b)), - (MULWIDES64Imm64 $a, (i64 SInt32Const:$b))>; - - def : Pat<(mul (zext i32:$a), (zext i32:$b)), - (MULWIDEU64 $a, $b)>; - def : Pat<(mul (zext i32:$a), (i64 UInt32Const:$b)), - (MULWIDEU64Imm64 $a, (i64 UInt32Const:$b))>; - - def : Pat<(mul (sext i16:$a), (sext i16:$b)), - (MULWIDES32 $a, $b)>; - def : Pat<(mul (sext i16:$a), (i32 SInt16Const:$b)), - (MULWIDES32Imm32 $a, (i32 SInt16Const:$b))>; - - def : Pat<(mul (zext i16:$a), (zext i16:$b)), - (MULWIDEU32 $a, $b)>; - def : Pat<(mul (zext i16:$a), (i32 UInt16Const:$b)), - (MULWIDEU32Imm32 $a, (i32 UInt16Const:$b))>; +multiclass MULWIDEInst<string suffix, SDPatternOperator op, RegTyInfo big_t, RegTyInfo small_t> { + def suffix # _rr : + BasicNVPTXInst<(outs big_t.RC:$dst), (ins small_t.RC:$a, small_t.RC:$b), + "mul.wide." # suffix, + [(set big_t.Ty:$dst, (op small_t.Ty:$a, small_t.Ty:$b))]>; + def suffix # _ri : + BasicNVPTXInst<(outs big_t.RC:$dst), (ins small_t.RC:$a, small_t.Imm:$b), + "mul.wide." # suffix, + [(set big_t.Ty:$dst, (op small_t.Ty:$a, imm:$b))]>; } +defm MUL_WIDE : MULWIDEInst<"s32", smul_wide, I64RT, I32RT>; +defm MUL_WIDE : MULWIDEInst<"u32", umul_wide, I64RT, I32RT>; +defm MUL_WIDE : MULWIDEInst<"s16", smul_wide, I32RT, I16RT>; +defm MUL_WIDE : MULWIDEInst<"u16", umul_wide, I32RT, I16RT>; + // // Integer multiply-add // -def mul_oneuse : OneUse2<mul>; - -multiclass MAD<string Ptx, ValueType VT, NVPTXRegClass Reg, Operand Imm> { +multiclass MADInst<string suffix, SDPatternOperator op, RegTyInfo big_t, RegTyInfo small_t> { def rrr: - BasicNVPTXInst<(outs Reg:$dst), - (ins Reg:$a, Reg:$b, Reg:$c), - Ptx, - [(set VT:$dst, (add (mul_oneuse VT:$a, VT:$b), VT:$c))]>; - - def rir: - BasicNVPTXInst<(outs Reg:$dst), - (ins Reg:$a, Imm:$b, Reg:$c), - Ptx, - [(set VT:$dst, (add (mul_oneuse VT:$a, imm:$b), VT:$c))]>; + BasicNVPTXInst<(outs big_t.RC:$dst), + (ins small_t.RC:$a, small_t.RC:$b, big_t.RC:$c), + "mad." 
# suffix, + [(set big_t.Ty:$dst, (add (OneUse2<op> small_t.Ty:$a, small_t.Ty:$b), big_t.Ty:$c))]>; def rri: - BasicNVPTXInst<(outs Reg:$dst), - (ins Reg:$a, Reg:$b, Imm:$c), - Ptx, - [(set VT:$dst, (add (mul_oneuse VT:$a, VT:$b), imm:$c))]>; + BasicNVPTXInst<(outs big_t.RC:$dst), + (ins small_t.RC:$a, small_t.RC:$b, big_t.Imm:$c), + "mad." # suffix, + [(set big_t.Ty:$dst, (add (OneUse2<op> small_t.Ty:$a, small_t.Ty:$b), imm:$c))]>; + def rir: + BasicNVPTXInst<(outs big_t.RC:$dst), + (ins small_t.RC:$a, small_t.Imm:$b, big_t.RC:$c), + "mad." # suffix, + [(set big_t.Ty:$dst, (add (OneUse2<op> small_t.Ty:$a, imm:$b), big_t.Ty:$c))]>; def rii: - BasicNVPTXInst<(outs Reg:$dst), - (ins Reg:$a, Imm:$b, Imm:$c), - Ptx, - [(set VT:$dst, (add (mul_oneuse VT:$a, imm:$b), imm:$c))]>; + BasicNVPTXInst<(outs big_t.RC:$dst), + (ins small_t.RC:$a, small_t.Imm:$b, big_t.Imm:$c), + "mad." # suffix, + [(set big_t.Ty:$dst, (add (OneUse2<op> small_t.Ty:$a, imm:$b), imm:$c))]>; } let Predicates = [hasOptEnabled] in { -defm MAD16 : MAD<"mad.lo.s16", i16, B16, i16imm>; -defm MAD32 : MAD<"mad.lo.s32", i32, B32, i32imm>; -defm MAD64 : MAD<"mad.lo.s64", i64, B64, i64imm>; -} + defm MAD_LO_S16 : MADInst<"lo.s16", mul, I16RT, I16RT>; + defm MAD_LO_S32 : MADInst<"lo.s32", mul, I32RT, I32RT>; + defm MAD_LO_S64 : MADInst<"lo.s64", mul, I64RT, I64RT>; -foreach t = [I16RT, I32RT, I64RT] in { - def NEG_S # t.Size : - BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src), - "neg.s" # t.Size, - [(set t.Ty:$dst, (ineg t.Ty:$src))]>; + defm MAD_WIDE_U16 : MADInst<"wide.u16", umul_wide, I32RT, I16RT>; + defm MAD_WIDE_S16 : MADInst<"wide.s16", smul_wide, I32RT, I16RT>; + defm MAD_WIDE_U32 : MADInst<"wide.u32", umul_wide, I64RT, I32RT>; + defm MAD_WIDE_S32 : MADInst<"wide.s32", smul_wide, I64RT, I32RT>; } //----------------------------------- @@ -1106,8 +960,7 @@ def fdiv_approx : PatFrag<(ops node:$a, node:$b), def FRCP32_approx_r : BasicFlagsNVPTXInst<(outs B32:$dst), - (ins B32:$b), - (ins FTZFlag:$ftz), + (ins B32:$b), (ins FTZFlag:$ftz), "rcp.approx$ftz.f32", [(set f32:$dst, (fdiv_approx f32imm_1, f32:$b))]>; @@ -1116,14 +969,12 @@ def FRCP32_approx_r : // def FDIV32_approx_rr : BasicFlagsNVPTXInst<(outs B32:$dst), - (ins B32:$a, B32:$b), - (ins FTZFlag:$ftz), + (ins B32:$a, B32:$b), (ins FTZFlag:$ftz), "div.approx$ftz.f32", [(set f32:$dst, (fdiv_approx f32:$a, f32:$b))]>; def FDIV32_approx_ri : BasicFlagsNVPTXInst<(outs B32:$dst), - (ins B32:$a, f32imm:$b), - (ins FTZFlag:$ftz), + (ins B32:$a, f32imm:$b), (ins FTZFlag:$ftz), "div.approx$ftz.f32", [(set f32:$dst, (fdiv_approx f32:$a, fpimm:$b))]>; // @@ -1146,14 +997,12 @@ def : Pat<(fdiv_full f32imm_1, f32:$b), // def FDIV32rr : BasicFlagsNVPTXInst<(outs B32:$dst), - (ins B32:$a, B32:$b), - (ins FTZFlag:$ftz), + (ins B32:$a, B32:$b), (ins FTZFlag:$ftz), "div.full$ftz.f32", [(set f32:$dst, (fdiv_full f32:$a, f32:$b))]>; def FDIV32ri : BasicFlagsNVPTXInst<(outs B32:$dst), - (ins B32:$a, f32imm:$b), - (ins FTZFlag:$ftz), + (ins B32:$a, f32imm:$b), (ins FTZFlag:$ftz), "div.full$ftz.f32", [(set f32:$dst, (fdiv_full f32:$a, fpimm:$b))]>; // @@ -1167,8 +1016,7 @@ def fdiv_ftz : PatFrag<(ops node:$a, node:$b), def FRCP32r_prec : BasicFlagsNVPTXInst<(outs B32:$dst), - (ins B32:$b), - (ins FTZFlag:$ftz), + (ins B32:$b), (ins FTZFlag:$ftz), "rcp.rn$ftz.f32", [(set f32:$dst, (fdiv_ftz f32imm_1, f32:$b))]>; // @@ -1176,14 +1024,12 @@ def FRCP32r_prec : // def FDIV32rr_prec : BasicFlagsNVPTXInst<(outs B32:$dst), - (ins B32:$a, B32:$b), - (ins FTZFlag:$ftz), + (ins B32:$a, B32:$b), (ins 
FTZFlag:$ftz), "div.rn$ftz.f32", [(set f32:$dst, (fdiv_ftz f32:$a, f32:$b))]>; def FDIV32ri_prec : BasicFlagsNVPTXInst<(outs B32:$dst), - (ins B32:$a, f32imm:$b), - (ins FTZFlag:$ftz), + (ins B32:$a, f32imm:$b), (ins FTZFlag:$ftz), "div.rn$ftz.f32", [(set f32:$dst, (fdiv_ftz f32:$a, fpimm:$b))]>; @@ -1262,10 +1108,8 @@ def TANH_APPROX_f32 : // Template for three-arg bitwise operations. Takes three args, Creates .b16, // .b32, .b64, and .pred (predicate registers -- i.e., i1) versions of OpcStr. multiclass BITWISE<string OpcStr, SDNode OpNode> { - defm b1 : I3Inst<OpcStr # ".pred", OpNode, I1RT, commutative = true>; - defm b16 : I3Inst<OpcStr # ".b16", OpNode, I16RT, commutative = true>; - defm b32 : I3Inst<OpcStr # ".b32", OpNode, I32RT, commutative = true>; - defm b64 : I3Inst<OpcStr # ".b64", OpNode, I64RT, commutative = true>; + foreach t = [I1RT, I16RT, I32RT, I64RT] in + defm _ # t.PtxType : I3Inst<OpcStr # "." # t.PtxType, OpNode, t, commutative = true>; } defm OR : BITWISE<"or", or>; @@ -1273,48 +1117,40 @@ defm AND : BITWISE<"and", and>; defm XOR : BITWISE<"xor", xor>; // PTX does not support mul on predicates, convert to and instructions -def : Pat<(mul i1:$a, i1:$b), (ANDb1rr $a, $b)>; -def : Pat<(mul i1:$a, imm:$b), (ANDb1ri $a, imm:$b)>; +def : Pat<(mul i1:$a, i1:$b), (AND_predrr $a, $b)>; +def : Pat<(mul i1:$a, imm:$b), (AND_predri $a, imm:$b)>; foreach op = [add, sub] in { - def : Pat<(op i1:$a, i1:$b), (XORb1rr $a, $b)>; - def : Pat<(op i1:$a, imm:$b), (XORb1ri $a, imm:$b)>; + def : Pat<(op i1:$a, i1:$b), (XOR_predrr $a, $b)>; + def : Pat<(op i1:$a, imm:$b), (XOR_predri $a, imm:$b)>; } // These transformations were once reliably performed by instcombine, but thanks // to poison semantics they are no longer safe for LLVM IR, perform them here // instead. -def : Pat<(select i1:$a, i1:$b, 0), (ANDb1rr $a, $b)>; -def : Pat<(select i1:$a, 1, i1:$b), (ORb1rr $a, $b)>; +def : Pat<(select i1:$a, i1:$b, 0), (AND_predrr $a, $b)>; +def : Pat<(select i1:$a, 1, i1:$b), (OR_predrr $a, $b)>; // Lower logical v2i16/v4i8 ops as bitwise ops on b32. foreach vt = [v2i16, v4i8] in { - def : Pat<(or vt:$a, vt:$b), (ORb32rr $a, $b)>; - def : Pat<(xor vt:$a, vt:$b), (XORb32rr $a, $b)>; - def : Pat<(and vt:$a, vt:$b), (ANDb32rr $a, $b)>; + def : Pat<(or vt:$a, vt:$b), (OR_b32rr $a, $b)>; + def : Pat<(xor vt:$a, vt:$b), (XOR_b32rr $a, $b)>; + def : Pat<(and vt:$a, vt:$b), (AND_b32rr $a, $b)>; // The constants get legalized into a bitcast from i32, so that's what we need // to match here. def: Pat<(or vt:$a, (vt (bitconvert (i32 imm:$b)))), - (ORb32ri $a, imm:$b)>; + (OR_b32ri $a, imm:$b)>; def: Pat<(xor vt:$a, (vt (bitconvert (i32 imm:$b)))), - (XORb32ri $a, imm:$b)>; + (XOR_b32ri $a, imm:$b)>; def: Pat<(and vt:$a, (vt (bitconvert (i32 imm:$b)))), - (ANDb32ri $a, imm:$b)>; -} - -def NOT1 : BasicNVPTXInst<(outs B1:$dst), (ins B1:$src), - "not.pred", - [(set i1:$dst, (not i1:$src))]>; -def NOT16 : BasicNVPTXInst<(outs B16:$dst), (ins B16:$src), - "not.b16", - [(set i16:$dst, (not i16:$src))]>; -def NOT32 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src), - "not.b32", - [(set i32:$dst, (not i32:$src))]>; -def NOT64 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), - "not.b64", - [(set i64:$dst, (not i64:$src))]>; + (AND_b32ri $a, imm:$b)>; +} + +foreach t = [I1RT, I16RT, I32RT, I64RT] in + def NOT_ # t.PtxType : BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src), + "not." # t.PtxType, + [(set t.Ty:$dst, (not t.Ty:$src))]>; // Template for left/right shifts. 
Takes three operands, // [dest (reg), src (reg), shift (reg or imm)]. @@ -1322,34 +1158,22 @@ def NOT64 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), // // This template also defines a 32-bit shift (imm, imm) instruction. multiclass SHIFT<string OpcStr, SDNode OpNode> { - def i64rr : - BasicNVPTXInst<(outs B64:$dst), (ins B64:$a, B32:$b), - OpcStr # "64", - [(set i64:$dst, (OpNode i64:$a, i32:$b))]>; - def i64ri : - BasicNVPTXInst<(outs B64:$dst), (ins B64:$a, i32imm:$b), - OpcStr # "64", - [(set i64:$dst, (OpNode i64:$a, (i32 imm:$b)))]>; - def i32rr : - BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, B32:$b), - OpcStr # "32", - [(set i32:$dst, (OpNode i32:$a, i32:$b))]>; - def i32ri : - BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, i32imm:$b), - OpcStr # "32", - [(set i32:$dst, (OpNode i32:$a, (i32 imm:$b)))]>; - def i32ii : - BasicNVPTXInst<(outs B32:$dst), (ins i32imm:$a, i32imm:$b), - OpcStr # "32", - [(set i32:$dst, (OpNode (i32 imm:$a), (i32 imm:$b)))]>; - def i16rr : - BasicNVPTXInst<(outs B16:$dst), (ins B16:$a, B32:$b), - OpcStr # "16", - [(set i16:$dst, (OpNode i16:$a, i32:$b))]>; - def i16ri : - BasicNVPTXInst<(outs B16:$dst), (ins B16:$a, i32imm:$b), - OpcStr # "16", - [(set i16:$dst, (OpNode i16:$a, (i32 imm:$b)))]>; + let hasSideEffects = false in { + foreach t = [I64RT, I32RT, I16RT] in { + def t.Size # _rr : + BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, B32:$b), + OpcStr # t.Size, + [(set t.Ty:$dst, (OpNode t.Ty:$a, i32:$b))]>; + def t.Size # _ri : + BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, i32imm:$b), + OpcStr # t.Size, + [(set t.Ty:$dst, (OpNode t.Ty:$a, (i32 imm:$b)))]>; + def t.Size # _ii : + BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, i32imm:$b), + OpcStr # t.Size, + [(set t.Ty:$dst, (OpNode (t.Ty imm:$a), (i32 imm:$b)))]>; + } + } } defm SHL : SHIFT<"shl.b", shl>; @@ -1357,14 +1181,11 @@ defm SRA : SHIFT<"shr.s", sra>; defm SRL : SHIFT<"shr.u", srl>; // Bit-reverse -def BREV32 : - BasicNVPTXInst<(outs B32:$dst), (ins B32:$a), - "brev.b32", - [(set i32:$dst, (bitreverse i32:$a))]>; -def BREV64 : - BasicNVPTXInst<(outs B64:$dst), (ins B64:$a), - "brev.b64", - [(set i64:$dst, (bitreverse i64:$a))]>; +foreach t = [I64RT, I32RT] in + def BREV_ # t.PtxType : + BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a), + "brev." 
# t.PtxType, + [(set t.Ty:$dst, (bitreverse t.Ty:$a))]>; // @@ -1516,20 +1337,19 @@ def : Pat<(i16 (sext_inreg (trunc (prmt i32:$s, 0, byte_extract_prmt:$sel, PrmtN // Byte extraction via shift/trunc/sext -def : Pat<(i16 (sext_inreg (trunc i32:$s), i8)), - (CVT_s8_s32 $s, CvtNONE)>; -def : Pat<(i16 (sext_inreg (trunc (srl i32:$s, (i32 imm:$o))), i8)), +def : Pat<(i16 (sext_inreg (trunc i32:$s), i8)), (CVT_s8_s32 $s, CvtNONE)>; +def : Pat<(i16 (sext_inreg (trunc i64:$s), i8)), (CVT_s8_s64 $s, CvtNONE)>; + +def : Pat<(sext_inreg (srl i32:$s, (i32 imm:$o)), i8), (BFE_S32rii $s, imm:$o, 8)>; +def : Pat<(sext_inreg (srl i64:$s, (i32 imm:$o)), i8), (BFE_S64rii $s, imm:$o, 8)>; + +def : Pat<(i16 (sext_inreg (trunc (srl i32:$s, (i32 imm:$o))), i8)), (CVT_s8_s32 (BFE_S32rii $s, imm:$o, 8), CvtNONE)>; -def : Pat<(sext_inreg (srl i32:$s, (i32 imm:$o)), i8), - (BFE_S32rii $s, imm:$o, 8)>; +def : Pat<(i16 (sext_inreg (trunc (srl i64:$s, (i32 imm:$o))), i8)), + (CVT_s8_s64 (BFE_S64rii $s, imm:$o, 8), CvtNONE)>; + def : Pat<(i16 (sra (i16 (trunc i32:$s)), (i32 8))), (CVT_s8_s32 (BFE_S32rii $s, 8, 8), CvtNONE)>; -def : Pat<(sext_inreg (srl i64:$s, (i32 imm:$o)), i8), - (BFE_S64rii $s, imm:$o, 8)>; -def : Pat<(i16 (sext_inreg (trunc i64:$s), i8)), - (CVT_s8_s64 $s, CvtNONE)>; -def : Pat<(i16 (sext_inreg (trunc (srl i64:$s, (i32 imm:$o))), i8)), - (CVT_s8_s64 (BFE_S64rii $s, imm:$o, 8), CvtNONE)>; //----------------------------------- // Comparison instructions (setp, set) @@ -1619,10 +1439,7 @@ def SETP_bf16x2rr : def addr : ComplexPattern<pAny, 2, "SelectADDR">; -def ADDR_base : Operand<pAny> { - let PrintMethod = "printOperand"; -} - +def ADDR_base : Operand<pAny>; def ADDR : Operand<pAny> { let PrintMethod = "printMemOperand"; let MIOperandInfo = (ops ADDR_base, i32imm); @@ -1636,10 +1453,6 @@ def MmaCode : Operand<i32> { let PrintMethod = "printMmaCode"; } -def Offseti32imm : Operand<i32> { - let PrintMethod = "printOffseti32imm"; -} - // Get pointer to local stack. let hasSideEffects = false in { def MOV_DEPOT_ADDR : NVPTXInst<(outs B32:$d), (ins i32imm:$num), @@ -1651,33 +1464,31 @@ let hasSideEffects = false in { // copyPhysreg is hard-coded in NVPTXInstrInfo.cpp let hasSideEffects = false, isAsCheapAsAMove = true in { - // Class for register-to-register moves - class MOVr<RegisterClass RC, string OpStr> : - BasicNVPTXInst<(outs RC:$dst), (ins RC:$src), - "mov." # OpStr>; - - // Class for immediate-to-register moves - class MOVi<RegisterClass RC, string OpStr, ValueType VT, Operand IMMType, SDNode ImmNode> : - BasicNVPTXInst<(outs RC:$dst), (ins IMMType:$src), - "mov." # OpStr, - [(set VT:$dst, ImmNode:$src)]>; -} + let isMoveReg = true in + class MOVr<RegisterClass RC, string OpStr> : + BasicNVPTXInst<(outs RC:$dst), (ins RC:$src), "mov." # OpStr>; -def IMOV1r : MOVr<B1, "pred">; -def MOV16r : MOVr<B16, "b16">; -def IMOV32r : MOVr<B32, "b32">; -def IMOV64r : MOVr<B64, "b64">; -def IMOV128r : MOVr<B128, "b128">; + let isMoveImm = true in + class MOVi<RegTyInfo t, string suffix> : + BasicNVPTXInst<(outs t.RC:$dst), (ins t.Imm:$src), + "mov." 
# suffix, + [(set t.Ty:$dst, t.ImmNode:$src)]>; +} +def MOV_B1_r : MOVr<B1, "pred">; +def MOV_B16_r : MOVr<B16, "b16">; +def MOV_B32_r : MOVr<B32, "b32">; +def MOV_B64_r : MOVr<B64, "b64">; +def MOV_B128_r : MOVr<B128, "b128">; -def IMOV1i : MOVi<B1, "pred", i1, i1imm, imm>; -def IMOV16i : MOVi<B16, "b16", i16, i16imm, imm>; -def IMOV32i : MOVi<B32, "b32", i32, i32imm, imm>; -def IMOV64i : MOVi<B64, "b64", i64, i64imm, imm>; -def FMOV16i : MOVi<B16, "b16", f16, f16imm, fpimm>; -def BFMOV16i : MOVi<B16, "b16", bf16, bf16imm, fpimm>; -def FMOV32i : MOVi<B32, "b32", f32, f32imm, fpimm>; -def FMOV64i : MOVi<B64, "b64", f64, f64imm, fpimm>; +def MOV_B1_i : MOVi<I1RT, "pred">; +def MOV_B16_i : MOVi<I16RT, "b16">; +def MOV_B32_i : MOVi<I32RT, "b32">; +def MOV_B64_i : MOVi<I64RT, "b64">; +def MOV_F16_i : MOVi<F16RT, "b16">; +def MOV_BF16_i : MOVi<BF16RT, "b16">; +def MOV_F32_i : MOVi<F32RT, "b32">; +def MOV_F64_i : MOVi<F64RT, "b64">; def to_tglobaladdr : SDNodeXForm<globaladdr, [{ @@ -1695,11 +1506,11 @@ def to_tframeindex : SDNodeXForm<frameindex, [{ return CurDAG->getTargetFrameIndex(N->getIndex(), N->getValueType(0)); }]>; -def : Pat<(i32 globaladdr:$dst), (IMOV32i (to_tglobaladdr $dst))>; -def : Pat<(i64 globaladdr:$dst), (IMOV64i (to_tglobaladdr $dst))>; +def : Pat<(i32 globaladdr:$dst), (MOV_B32_i (to_tglobaladdr $dst))>; +def : Pat<(i64 globaladdr:$dst), (MOV_B64_i (to_tglobaladdr $dst))>; -def : Pat<(i32 externalsym:$dst), (IMOV32i (to_texternsym $dst))>; -def : Pat<(i64 externalsym:$dst), (IMOV64i (to_texternsym $dst))>; +def : Pat<(i32 externalsym:$dst), (MOV_B32_i (to_texternsym $dst))>; +def : Pat<(i64 externalsym:$dst), (MOV_B64_i (to_texternsym $dst))>; //---- Copy Frame Index ---- def LEA_ADDRi : NVPTXInst<(outs B32:$dst), (ins ADDR:$addr), @@ -1713,45 +1524,34 @@ def : Pat<(i64 frameindex:$fi), (LEA_ADDRi64 (to_tframeindex $fi), 0)>; //----------------------------------- // Comparison and Selection //----------------------------------- +// TODO: These patterns seem very specific and brittle. We should try to find +// a more general solution. def cond_signed : PatLeaf<(cond), [{ return isSignedIntSetCC(N->get()); }]>; -def cond_not_signed : PatLeaf<(cond), [{ - return !isSignedIntSetCC(N->get()); -}]>; +// A 16-bit signed comparison of sign-extended byte extracts can be converted +// to a 32-bit comparison if we change the PRMT to sign-extend the extracted +// bytes. +def : Pat<(setcc (i16 (sext_inreg (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE)), i8)), + (i16 (sext_inreg (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE)), i8)), + cond_signed:$cc), + (SETP_i32rr (PRMT_B32rii i32:$a, 0, (to_sign_extend_selector $sel_a), PrmtNONE), + (PRMT_B32rii i32:$b, 0, (to_sign_extend_selector $sel_b), PrmtNONE), + (cond2cc $cc))>; + +// A 16-bit comparison of truncated byte extracts can be converted to a 32-bit +// comparison because we know that the truncate is just truncating off zeros +// and that the most-significant byte is also zero, so the meaning of signed and +// unsigned comparisons will not be changed.
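A hedged CUDA-level sketch (hypothetical kernel; whether this exact source reaches these patterns depends on how the packed values are formed) of where such PRMT byte extracts arise, before the pattern below rewrites the i16 compare into a 32-bit setp:

  __global__ void cmp_bytes(bool *out, char4 a, char4 b) {
    // char4 elements are packed into one 32-bit register on NVPTX and are
    // extracted with PRMT; these patterns fold the extend-and-compare of
    // the extracted bytes into a single 32-bit setp.
    out[0] = a.x < b.x;
    out[1] = a.y < b.y;
  }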
+def : Pat<(setcc (i16 (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE))), + (i16 (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE))), + cond:$cc), + (SETP_i32rr (PRMT_B32rii i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE), + (PRMT_B32rii i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE), + (cond2cc $cc))>; def SDTDeclareArrayParam : SDTypeProfile<0, 3, [SDTCisVT<0, i32>, SDTCisVT<1, i32>, SDTCisVT<2, i32>]>; @@ -1802,8 +1602,6 @@ foreach is_convergent = [0, 1] in { } defvar call_inst = !cast<NVPTXInst>("CALL" # convergent_suffix); - def : Pat<(call is_convergent, 1, imm:$rets, imm:$params, globaladdr:$addr, imm:$proto), - (call_inst (to_tglobaladdr $addr), imm:$rets, imm:$params, imm:$proto)>; def : Pat<(call is_convergent, 1, imm:$rets, imm:$params, i32:$addr, imm:$proto), (call_inst $addr, imm:$rets, imm:$params, imm:$proto)>; def : Pat<(call is_convergent, 1, imm:$rets, imm:$params, i64:$addr, imm:$proto), @@ -1812,10 +1610,6 @@ foreach is_convergent = [0, 1] in { defvar call_uni_inst = !cast<NVPTXInst>("CALL_UNI" # convergent_suffix); def : Pat<(call is_convergent, 0, imm:$rets, imm:$params, globaladdr:$addr, 0), (call_uni_inst (to_tglobaladdr $addr), imm:$rets, imm:$params)>; - def : Pat<(call is_convergent, 0, imm:$rets, imm:$params, i32:$addr, 0), - (call_uni_inst $addr, imm:$rets, imm:$params)>; - def : Pat<(call is_convergent, 0, imm:$rets, imm:$params, i64:$addr, 0), - (call_uni_inst $addr, imm:$rets, imm:$params)>; } def DECLARE_PARAM_array : @@ -1830,6 +1624,18 @@ def : Pat<(declare_array_param externalsym:$a, imm:$align, imm:$size), def : Pat<(declare_scalar_param externalsym:$a, imm:$size), (DECLARE_PARAM_scalar (to_texternsym $a), imm:$size)>; +// Call prototype wrapper; this is a dummy instruction that just prints its +// operand, which is a string defining the prototype.
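For context, a minimal hypothetical sketch of the indirect-call source that requires such a prototype (PTX demands a .callprototype declaration for calls through a pointer):

  typedef int (*op_t)(int);

  __device__ int apply(op_t f, int x) {
    // Calling through a function pointer emits an indirect PTX call,
    // which must be accompanied by the ".callprototype" string that
    // CALL_PROTOTYPE below prints verbatim.
    return f(x);
  }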
+def SDTCallPrototype : SDTypeProfile<0, 1, [SDTCisInt<0>]>; +def CallPrototype : + SDNode<"NVPTXISD::CallPrototype", SDTCallPrototype, + [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; +def ProtoIdent : Operand<i32> { let PrintMethod = "printProtoIdent"; } +def CALL_PROTOTYPE : + NVPTXInst<(outs), (ins ProtoIdent:$ident), + "$ident", [(CallPrototype (i32 texternalsym:$ident))]>; + + foreach t = [I32RT, I64RT] in { defvar inst_name = "MOV" # t.Size # "_PARAM"; def inst_name : BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src), "mov.b" # t.Size>; @@ -1849,6 +1655,32 @@ defm ProxyRegB16 : ProxyRegInst<"b16", B16>; defm ProxyRegB32 : ProxyRegInst<"b32", B32>; defm ProxyRegB64 : ProxyRegInst<"b64", B64>; + +// Callseq start and end + +// Note: these nodes are marked as SDNPMayStore and SDNPMayLoad because +// they define the scope in which the declared params may be used. Therefore +// we add these flags to ensure ld.param and st.param are not sunk or hoisted +// out of that scope. + +def callseq_start : SDNode<"ISD::CALLSEQ_START", + SDCallSeqStart<[SDTCisVT<0, i32>, SDTCisVT<1, i32>]>, + [SDNPHasChain, SDNPOutGlue, + SDNPSideEffect, SDNPMayStore, SDNPMayLoad]>; +def callseq_end : SDNode<"ISD::CALLSEQ_END", + SDCallSeqEnd<[SDTCisVT<0, i32>, SDTCisVT<1, i32>]>, + [SDNPHasChain, SDNPOptInGlue, SDNPOutGlue, + SDNPSideEffect, SDNPMayStore, SDNPMayLoad]>; + +def Callseq_Start : + NVPTXInst<(outs), (ins i32imm:$amt1, i32imm:$amt2), + "\\{ // callseq $amt1, $amt2", + [(callseq_start timm:$amt1, timm:$amt2)]>; +def Callseq_End : + NVPTXInst<(outs), (ins i32imm:$amt1, i32imm:$amt2), + "\\} // callseq $amt1", + [(callseq_end timm:$amt1, timm:$amt2)]>; + // // Load / Store Handling // @@ -1861,7 +1693,6 @@ class LD<NVPTXRegClass regclass> "\t$dst, [$addr];", []>; let mayLoad=1, hasSideEffects=0 in { - def LD_i8 : LD<B16>; def LD_i16 : LD<B16>; def LD_i32 : LD<B32>; def LD_i64 : LD<B64>; @@ -1877,7 +1708,6 @@ class ST<DAGOperand O> " \t[$addr], $src;", []>; let mayStore=1, hasSideEffects=0 in { - def ST_i8 : ST<RI16>; def ST_i16 : ST<RI16>; def ST_i32 : ST<RI32>; def ST_i64 : ST<RI64>; @@ -1910,7 +1740,6 @@ multiclass LD_VEC<NVPTXRegClass regclass, bit support_v8 = false> { "[$addr];", []>; } let mayLoad=1, hasSideEffects=0 in { - defm LDV_i8 : LD_VEC<B16>; defm LDV_i16 : LD_VEC<B16>; defm LDV_i32 : LD_VEC<B32, support_v8 = true>; defm LDV_i64 : LD_VEC<B64>; @@ -1944,7 +1773,6 @@ multiclass ST_VEC<DAGOperand O, bit support_v8 = false> { } let mayStore=1, hasSideEffects=0 in { - defm STV_i8 : ST_VEC<RI16>; defm STV_i16 : ST_VEC<RI16>; defm STV_i32 : ST_VEC<RI32, support_v8 = true>; defm STV_i64 : ST_VEC<RI64>; @@ -2114,14 +1942,14 @@ def : Pat<(i64 (anyext i32:$a)), (CVT_u64_u32 $a, CvtNONE)>; // truncate i64 def : Pat<(i32 (trunc i64:$a)), (CVT_u32_u64 $a, CvtNONE)>; def : Pat<(i16 (trunc i64:$a)), (CVT_u16_u64 $a, CvtNONE)>; -def : Pat<(i1 (trunc i64:$a)), (SETP_i64ri (ANDb64ri $a, 1), 0, CmpNE)>; +def : Pat<(i1 (trunc i64:$a)), (SETP_i64ri (AND_b64ri $a, 1), 0, CmpNE)>; // truncate i32 def : Pat<(i16 (trunc i32:$a)), (CVT_u16_u32 $a, CvtNONE)>; -def : Pat<(i1 (trunc i32:$a)), (SETP_i32ri (ANDb32ri $a, 1), 0, CmpNE)>; +def : Pat<(i1 (trunc i32:$a)), (SETP_i32ri (AND_b32ri $a, 1), 0, CmpNE)>; // truncate i16 -def : Pat<(i1 (trunc i16:$a)), (SETP_i16ri (ANDb16ri $a, 1), 0, CmpNE)>; +def : Pat<(i1 (trunc i16:$a)), (SETP_i16ri (AND_b16ri $a, 1), 0, CmpNE)>; // sext_inreg def : Pat<(sext_inreg i16:$a, i8), (CVT_INREG_s16_s8 $a)>; @@ -2365,52 +2193,20 @@ defm : CVT_ROUND<frint, CvtRNI, CvtRNI_FTZ>; 
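In source terms, the frint mapping in the context line above corresponds to code like the following hypothetical kernel:

  __global__ void round_even(float *out, float x) {
    // rintf rounds to the nearest integer, ties to even; frint selects
    // to cvt.rni.f32.f32 via CVT_ROUND<frint, CvtRNI, CvtRNI_FTZ>.
    *out = rintf(x);
  }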
//----------------------------------- let isTerminator=1 in { - let isReturn=1, isBarrier=1 in + let isReturn=1, isBarrier=1 in def Return : BasicNVPTXInst<(outs), (ins), "ret", [(retglue)]>; - let isBranch=1 in - def CBranch : NVPTXInst<(outs), (ins B1:$a, brtarget:$target), + let isBranch=1 in { + def CBranch : NVPTXInst<(outs), (ins B1:$a, brtarget:$target), "@$a bra \t$target;", [(brcond i1:$a, bb:$target)]>; - let isBranch=1 in - def CBranchOther : NVPTXInst<(outs), (ins B1:$a, brtarget:$target), - "@!$a bra \t$target;", []>; - let isBranch=1, isBarrier=1 in + let isBarrier=1 in def GOTO : BasicNVPTXInst<(outs), (ins brtarget:$target), - "bra.uni", [(br bb:$target)]>; + "bra.uni", [(br bb:$target)]>; + } } -def : Pat<(brcond i32:$a, bb:$target), - (CBranch (SETP_i32ri $a, 0, CmpNE), bb:$target)>; - -// SelectionDAGBuilder::visitSWitchCase() will invert the condition of a -// conditional branch if the target block is the next block so that the code -// can fall through to the target block. The inversion is done by 'xor -// condition, 1', which will be translated to (setne condition, -1). Since ptx -// supports '@!pred bra target', we should use it. -def : Pat<(brcond (i1 (setne i1:$a, -1)), bb:$target), - (CBranchOther $a, bb:$target)>; - -// Call -def SDT_NVPTXCallSeqStart : SDCallSeqStart<[SDTCisVT<0, i32>, - SDTCisVT<1, i32>]>; -def SDT_NVPTXCallSeqEnd : SDCallSeqEnd<[SDTCisVT<0, i32>, SDTCisVT<1, i32>]>; - -def callseq_start : SDNode<"ISD::CALLSEQ_START", SDT_NVPTXCallSeqStart, - [SDNPHasChain, SDNPOutGlue, SDNPSideEffect]>; -def callseq_end : SDNode<"ISD::CALLSEQ_END", SDT_NVPTXCallSeqEnd, - [SDNPHasChain, SDNPOptInGlue, SDNPOutGlue, - SDNPSideEffect]>; - -def Callseq_Start : - NVPTXInst<(outs), (ins i32imm:$amt1, i32imm:$amt2), - "\\{ // callseq $amt1, $amt2", - [(callseq_start timm:$amt1, timm:$amt2)]>; -def Callseq_End : - NVPTXInst<(outs), (ins i32imm:$amt1, i32imm:$amt2), - "\\} // callseq $amt1", - [(callseq_end timm:$amt1, timm:$amt2)]>; // trap instruction def trapinst : BasicNVPTXInst<(outs), (ins), "trap", [(trap)]>, Requires<[noPTXASUnreachableBug]>; @@ -2420,18 +2216,6 @@ def trapexitinst : NVPTXInst<(outs), (ins), "trap; exit;", [(trap)]>, Requires<[ // brkpt instruction def debugtrapinst : BasicNVPTXInst<(outs), (ins), "brkpt", [(debugtrap)]>; -// Call prototype wrapper -def SDTCallPrototype : SDTypeProfile<0, 1, [SDTCisInt<0>]>; -def CallPrototype : - SDNode<"NVPTXISD::CallPrototype", SDTCallPrototype, - [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; -def ProtoIdent : Operand<i32> { - let PrintMethod = "printProtoIdent"; -} -def CALL_PROTOTYPE : - NVPTXInst<(outs), (ins ProtoIdent:$ident), - "$ident", [(CallPrototype (i32 texternalsym:$ident))]>; - def SDTDynAllocaOp : SDTypeProfile<1, 2, [SDTCisSameAs<0, 1>, SDTCisInt<1>, SDTCisVT<2, i32>]>; diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index 0a00220..d337192 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -243,63 +243,82 @@ foreach sync = [false, true] in { } // vote.{all,any,uni,ballot} -multiclass VOTE<NVPTXRegClass regclass, string mode, Intrinsic IntOp> { - def : BasicNVPTXInst<(outs regclass:$dest), (ins B1:$pred), - "vote." # mode, - [(set regclass:$dest, (IntOp i1:$pred))]>, - Requires<[hasPTX<60>, hasSM<30>]>; -} +let Predicates = [hasPTX<60>, hasSM<30>] in { + multiclass VOTE<string mode, RegTyInfo t, Intrinsic op> { + def : BasicNVPTXInst<(outs t.RC:$dest), (ins B1:$pred), + "vote." 
# mode # "." # t.PtxType, + [(set t.Ty:$dest, (op i1:$pred))]>; + } -defm VOTE_ALL : VOTE<B1, "all.pred", int_nvvm_vote_all>; -defm VOTE_ANY : VOTE<B1, "any.pred", int_nvvm_vote_any>; -defm VOTE_UNI : VOTE<B1, "uni.pred", int_nvvm_vote_uni>; -defm VOTE_BALLOT : VOTE<B32, "ballot.b32", int_nvvm_vote_ballot>; + defm VOTE_ALL : VOTE<"all", I1RT, int_nvvm_vote_all>; + defm VOTE_ANY : VOTE<"any", I1RT, int_nvvm_vote_any>; + defm VOTE_UNI : VOTE<"uni", I1RT, int_nvvm_vote_uni>; + defm VOTE_BALLOT : VOTE<"ballot", I32RT, int_nvvm_vote_ballot>; + + // vote.sync.{all,any,uni,ballot} + multiclass VOTE_SYNC<string mode, RegTyInfo t, Intrinsic op> { + def i : BasicNVPTXInst<(outs t.RC:$dest), (ins B1:$pred, i32imm:$mask), + "vote.sync." # mode # "." # t.PtxType, + [(set t.Ty:$dest, (op imm:$mask, i1:$pred))]>; + def r : BasicNVPTXInst<(outs t.RC:$dest), (ins B1:$pred, B32:$mask), + "vote.sync." # mode # "." # t.PtxType, + [(set t.Ty:$dest, (op i32:$mask, i1:$pred))]>; + } -// vote.sync.{all,any,uni,ballot} -multiclass VOTE_SYNC<NVPTXRegClass regclass, string mode, Intrinsic IntOp> { - def i : BasicNVPTXInst<(outs regclass:$dest), (ins B1:$pred, i32imm:$mask), - "vote.sync." # mode, - [(set regclass:$dest, (IntOp imm:$mask, i1:$pred))]>, - Requires<[hasPTX<60>, hasSM<30>]>; - def r : BasicNVPTXInst<(outs regclass:$dest), (ins B1:$pred, B32:$mask), - "vote.sync." # mode, - [(set regclass:$dest, (IntOp i32:$mask, i1:$pred))]>, - Requires<[hasPTX<60>, hasSM<30>]>; + defm VOTE_SYNC_ALL : VOTE_SYNC<"all", I1RT, int_nvvm_vote_all_sync>; + defm VOTE_SYNC_ANY : VOTE_SYNC<"any", I1RT, int_nvvm_vote_any_sync>; + defm VOTE_SYNC_UNI : VOTE_SYNC<"uni", I1RT, int_nvvm_vote_uni_sync>; + defm VOTE_SYNC_BALLOT : VOTE_SYNC<"ballot", I32RT, int_nvvm_vote_ballot_sync>; } - -defm VOTE_SYNC_ALL : VOTE_SYNC<B1, "all.pred", int_nvvm_vote_all_sync>; -defm VOTE_SYNC_ANY : VOTE_SYNC<B1, "any.pred", int_nvvm_vote_any_sync>; -defm VOTE_SYNC_UNI : VOTE_SYNC<B1, "uni.pred", int_nvvm_vote_uni_sync>; -defm VOTE_SYNC_BALLOT : VOTE_SYNC<B32, "ballot.b32", int_nvvm_vote_ballot_sync>; - // elect.sync +let Predicates = [hasPTX<80>, hasSM<90>] in { def INT_ELECT_SYNC_I : BasicNVPTXInst<(outs B32:$dest, B1:$pred), (ins i32imm:$mask), "elect.sync", - [(set i32:$dest, i1:$pred, (int_nvvm_elect_sync imm:$mask))]>, - Requires<[hasPTX<80>, hasSM<90>]>; + [(set i32:$dest, i1:$pred, (int_nvvm_elect_sync imm:$mask))]>; def INT_ELECT_SYNC_R : BasicNVPTXInst<(outs B32:$dest, B1:$pred), (ins B32:$mask), "elect.sync", - [(set i32:$dest, i1:$pred, (int_nvvm_elect_sync i32:$mask))]>, - Requires<[hasPTX<80>, hasSM<90>]>; + [(set i32:$dest, i1:$pred, (int_nvvm_elect_sync i32:$mask))]>; +} + +let Predicates = [hasPTX<60>, hasSM<70>] in { + multiclass MATCH_ANY_SYNC<Intrinsic op, RegTyInfo t> { + def ii : BasicNVPTXInst<(outs B32:$dest), (ins t.Imm:$value, i32imm:$mask), + "match.any.sync." # t.PtxType, + [(set i32:$dest, (op imm:$mask, imm:$value))]>; + def ir : BasicNVPTXInst<(outs B32:$dest), (ins t.Imm:$value, B32:$mask), + "match.any.sync." # t.PtxType, + [(set i32:$dest, (op i32:$mask, imm:$value))]>; + def ri : BasicNVPTXInst<(outs B32:$dest), (ins t.RC:$value, i32imm:$mask), + "match.any.sync." # t.PtxType, + [(set i32:$dest, (op imm:$mask, t.Ty:$value))]>; + def rr : BasicNVPTXInst<(outs B32:$dest), (ins t.RC:$value, B32:$mask), + "match.any.sync." 
# t.PtxType, + [(set i32:$dest, (op i32:$mask, t.Ty:$value))]>; + } -multiclass MATCH_ANY_SYNC<NVPTXRegClass regclass, string ptxtype, Intrinsic IntOp, - Operand ImmOp> { - def ii : BasicNVPTXInst<(outs B32:$dest), (ins ImmOp:$value, i32imm:$mask), - "match.any.sync." # ptxtype, - [(set i32:$dest, (IntOp imm:$mask, imm:$value))]>, - Requires<[hasPTX<60>, hasSM<70>]>; - def ir : BasicNVPTXInst<(outs B32:$dest), (ins ImmOp:$value, B32:$mask), - "match.any.sync." # ptxtype, - [(set i32:$dest, (IntOp i32:$mask, imm:$value))]>, - Requires<[hasPTX<60>, hasSM<70>]>; - def ri : BasicNVPTXInst<(outs B32:$dest), (ins regclass:$value, i32imm:$mask), - "match.any.sync." # ptxtype, - [(set i32:$dest, (IntOp imm:$mask, regclass:$value))]>, - Requires<[hasPTX<60>, hasSM<70>]>; - def rr : BasicNVPTXInst<(outs B32:$dest), (ins regclass:$value, B32:$mask), - "match.any.sync." # ptxtype, - [(set i32:$dest, (IntOp i32:$mask, regclass:$value))]>, - Requires<[hasPTX<60>, hasSM<70>]>; + defm MATCH_ANY_SYNC_32 : MATCH_ANY_SYNC<int_nvvm_match_any_sync_i32, I32RT>; + defm MATCH_ANY_SYNC_64 : MATCH_ANY_SYNC<int_nvvm_match_any_sync_i64, I64RT>; + + multiclass MATCH_ALLP_SYNC<RegTyInfo t, Intrinsic op> { + def ii : BasicNVPTXInst<(outs B32:$dest, B1:$pred), + (ins t.Imm:$value, i32imm:$mask), + "match.all.sync." # t.PtxType, + [(set i32:$dest, i1:$pred, (op imm:$mask, imm:$value))]>; + def ir : BasicNVPTXInst<(outs B32:$dest, B1:$pred), + (ins t.Imm:$value, B32:$mask), + "match.all.sync." # t.PtxType, + [(set i32:$dest, i1:$pred, (op i32:$mask, imm:$value))]>; + def ri : BasicNVPTXInst<(outs B32:$dest, B1:$pred), + (ins t.RC:$value, i32imm:$mask), + "match.all.sync." # t.PtxType, + [(set i32:$dest, i1:$pred, (op imm:$mask, t.Ty:$value))]>; + def rr : BasicNVPTXInst<(outs B32:$dest, B1:$pred), + (ins t.RC:$value, B32:$mask), + "match.all.sync." # t.PtxType, + [(set i32:$dest, i1:$pred, (op i32:$mask, t.Ty:$value))]>; + } + defm MATCH_ALLP_SYNC_32 : MATCH_ALLP_SYNC<I32RT, int_nvvm_match_all_sync_i32p>; + defm MATCH_ALLP_SYNC_64 : MATCH_ALLP_SYNC<I64RT, int_nvvm_match_all_sync_i64p>; } // activemask.b32 @@ -308,39 +327,6 @@ def ACTIVEMASK : BasicNVPTXInst<(outs B32:$dest), (ins), [(set i32:$dest, (int_nvvm_activemask))]>, Requires<[hasPTX<62>, hasSM<30>]>; -defm MATCH_ANY_SYNC_32 : MATCH_ANY_SYNC<B32, "b32", int_nvvm_match_any_sync_i32, - i32imm>; -defm MATCH_ANY_SYNC_64 : MATCH_ANY_SYNC<B64, "b64", int_nvvm_match_any_sync_i64, - i64imm>; - -multiclass MATCH_ALLP_SYNC<NVPTXRegClass regclass, string ptxtype, Intrinsic IntOp, - Operand ImmOp> { - def ii : BasicNVPTXInst<(outs B32:$dest, B1:$pred), - (ins ImmOp:$value, i32imm:$mask), - "match.all.sync." # ptxtype, - [(set i32:$dest, i1:$pred, (IntOp imm:$mask, imm:$value))]>, - Requires<[hasPTX<60>, hasSM<70>]>; - def ir : BasicNVPTXInst<(outs B32:$dest, B1:$pred), - (ins ImmOp:$value, B32:$mask), - "match.all.sync." # ptxtype, - [(set i32:$dest, i1:$pred, (IntOp i32:$mask, imm:$value))]>, - Requires<[hasPTX<60>, hasSM<70>]>; - def ri : BasicNVPTXInst<(outs B32:$dest, B1:$pred), - (ins regclass:$value, i32imm:$mask), - "match.all.sync." # ptxtype, - [(set i32:$dest, i1:$pred, (IntOp imm:$mask, regclass:$value))]>, - Requires<[hasPTX<60>, hasSM<70>]>; - def rr : BasicNVPTXInst<(outs B32:$dest, B1:$pred), - (ins regclass:$value, B32:$mask), - "match.all.sync." 
# ptxtype, - [(set i32:$dest, i1:$pred, (IntOp i32:$mask, regclass:$value))]>, - Requires<[hasPTX<60>, hasSM<70>]>; -} -defm MATCH_ALLP_SYNC_32 : MATCH_ALLP_SYNC<B32, "b32", int_nvvm_match_all_sync_i32p, - i32imm>; -defm MATCH_ALLP_SYNC_64 : MATCH_ALLP_SYNC<B64, "b64", int_nvvm_match_all_sync_i64p, - i64imm>; - multiclass REDUX_SYNC<string BinOp, string PTXType, Intrinsic Intrin> { def : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src, B32:$mask), "redux.sync." # BinOp # "." # PTXType, @@ -381,24 +367,20 @@ defm REDUX_SYNC_FMAX_ABS_NAN: REDUX_SYNC_F<"max", ".abs", ".NaN">; //----------------------------------- // Explicit Memory Fence Functions //----------------------------------- -class MEMBAR<string StrOp, Intrinsic IntOP> : - BasicNVPTXInst<(outs), (ins), - StrOp, [(IntOP)]>; +class NullaryInst<string StrOp, Intrinsic IntOP> : + BasicNVPTXInst<(outs), (ins), StrOp, [(IntOP)]>; -def INT_MEMBAR_CTA : MEMBAR<"membar.cta", int_nvvm_membar_cta>; -def INT_MEMBAR_GL : MEMBAR<"membar.gl", int_nvvm_membar_gl>; -def INT_MEMBAR_SYS : MEMBAR<"membar.sys", int_nvvm_membar_sys>; +def INT_MEMBAR_CTA : NullaryInst<"membar.cta", int_nvvm_membar_cta>; +def INT_MEMBAR_GL : NullaryInst<"membar.gl", int_nvvm_membar_gl>; +def INT_MEMBAR_SYS : NullaryInst<"membar.sys", int_nvvm_membar_sys>; def INT_FENCE_SC_CLUSTER: - MEMBAR<"fence.sc.cluster", int_nvvm_fence_sc_cluster>, + NullaryInst<"fence.sc.cluster", int_nvvm_fence_sc_cluster>, Requires<[hasPTX<78>, hasSM<90>]>; // Proxy fence (uni-directional) -// fence.proxy.tensormap.release variants - class FENCE_PROXY_TENSORMAP_GENERIC_RELEASE<string Scope, Intrinsic Intr> : - BasicNVPTXInst<(outs), (ins), - "fence.proxy.tensormap::generic.release." # Scope, [(Intr)]>, + NullaryInst<"fence.proxy.tensormap::generic.release." # Scope, Intr>, Requires<[hasPTX<83>, hasSM<90>]>; def INT_FENCE_PROXY_TENSORMAP_GENERIC_RELEASE_CTA: @@ -488,35 +470,31 @@ defm CP_ASYNC_CG_SHARED_GLOBAL_16 : CP_ASYNC_SHARED_GLOBAL_I<"cg", "16", int_nvvm_cp_async_cg_shared_global_16, int_nvvm_cp_async_cg_shared_global_16_s>; -def CP_ASYNC_COMMIT_GROUP : - BasicNVPTXInst<(outs), (ins), "cp.async.commit_group", [(int_nvvm_cp_async_commit_group)]>, - Requires<[hasPTX<70>, hasSM<80>]>; +let Predicates = [hasPTX<70>, hasSM<80>] in { + def CP_ASYNC_COMMIT_GROUP : + NullaryInst<"cp.async.commit_group", int_nvvm_cp_async_commit_group>; -def CP_ASYNC_WAIT_GROUP : - BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.wait_group", - [(int_nvvm_cp_async_wait_group timm:$n)]>, - Requires<[hasPTX<70>, hasSM<80>]>; + def CP_ASYNC_WAIT_GROUP : + BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.wait_group", + [(int_nvvm_cp_async_wait_group timm:$n)]>; -def CP_ASYNC_WAIT_ALL : - BasicNVPTXInst<(outs), (ins), "cp.async.wait_all", - [(int_nvvm_cp_async_wait_all)]>, - Requires<[hasPTX<70>, hasSM<80>]>; + def CP_ASYNC_WAIT_ALL : + NullaryInst<"cp.async.wait_all", int_nvvm_cp_async_wait_all>; +} -// cp.async.bulk variants of the commit/wait group -def CP_ASYNC_BULK_COMMIT_GROUP : - BasicNVPTXInst<(outs), (ins), "cp.async.bulk.commit_group", - [(int_nvvm_cp_async_bulk_commit_group)]>, - Requires<[hasPTX<80>, hasSM<90>]>; +let Predicates = [hasPTX<80>, hasSM<90>] in { + // cp.async.bulk variants of the commit/wait group + def CP_ASYNC_BULK_COMMIT_GROUP : + NullaryInst<"cp.async.bulk.commit_group", int_nvvm_cp_async_bulk_commit_group>; -def CP_ASYNC_BULK_WAIT_GROUP : - BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.bulk.wait_group", - [(int_nvvm_cp_async_bulk_wait_group timm:$n)]>, - Requires<[hasPTX<80>, hasSM<90>]>; 
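A usage-level sketch of the commit/wait pair being regrouped here, using the CUDA pipeline primitives (illustrative only; the shared-memory size and indexing are assumptions):

  #include <cuda_pipeline.h>

  __global__ void stage(const float *gmem) {
    __shared__ float smem[256];
    // __pipeline_memcpy_async issues cp.async; the calls below map onto
    // CP_ASYNC_COMMIT_GROUP and CP_ASYNC_WAIT_GROUP.
    __pipeline_memcpy_async(&smem[threadIdx.x], &gmem[threadIdx.x],
                            sizeof(float));
    __pipeline_commit();      // cp.async.commit_group
    __pipeline_wait_prior(0); // cp.async.wait_group 0
  }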
+ def CP_ASYNC_BULK_WAIT_GROUP : + BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.bulk.wait_group", + [(int_nvvm_cp_async_bulk_wait_group timm:$n)]>; -def CP_ASYNC_BULK_WAIT_GROUP_READ : - BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.bulk.wait_group.read", - [(int_nvvm_cp_async_bulk_wait_group_read timm:$n)]>, - Requires<[hasPTX<80>, hasSM<90>]>; + def CP_ASYNC_BULK_WAIT_GROUP_READ : + BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.bulk.wait_group.read", + [(int_nvvm_cp_async_bulk_wait_group_read timm:$n)]>; +} //------------------------------ // TMA Async Bulk Copy Functions @@ -974,33 +952,30 @@ defm TMA_TENSOR_PF_TILE_GATHER4_2D : TMA_TENSOR_PREFETCH_INTR<5, "tile_gather4", //Prefetch and Prefetchu -class PREFETCH_INTRS<string InstName> : - BasicNVPTXInst<(outs), (ins ADDR:$addr), - InstName, - [(!cast<Intrinsic>(!strconcat("int_nvvm_", - !subst(".", "_", InstName))) addr:$addr)]>, - Requires<[hasPTX<80>, hasSM<90>]>; - +let Predicates = [hasPTX<80>, hasSM<90>] in { + class PREFETCH_INTRS<string InstName> : + BasicNVPTXInst<(outs), (ins ADDR:$addr), + InstName, + [(!cast<Intrinsic>(!strconcat("int_nvvm_", + !subst(".", "_", InstName))) addr:$addr)]>; -def PREFETCH_L1 : PREFETCH_INTRS<"prefetch.L1">; -def PREFETCH_L2 : PREFETCH_INTRS<"prefetch.L2">; -def PREFETCH_GLOBAL_L1 : PREFETCH_INTRS<"prefetch.global.L1">; -def PREFETCH_LOCAL_L1 : PREFETCH_INTRS<"prefetch.local.L1">; -def PREFETCH_GLOBAL_L2 : PREFETCH_INTRS<"prefetch.global.L2">; -def PREFETCH_LOCAL_L2 : PREFETCH_INTRS<"prefetch.local.L2">; + def PREFETCH_L1 : PREFETCH_INTRS<"prefetch.L1">; + def PREFETCH_L2 : PREFETCH_INTRS<"prefetch.L2">; + def PREFETCH_GLOBAL_L1 : PREFETCH_INTRS<"prefetch.global.L1">; + def PREFETCH_LOCAL_L1 : PREFETCH_INTRS<"prefetch.local.L1">; + def PREFETCH_GLOBAL_L2 : PREFETCH_INTRS<"prefetch.global.L2">; + def PREFETCH_LOCAL_L2 : PREFETCH_INTRS<"prefetch.local.L2">; -def PREFETCH_GLOBAL_L2_EVICT_NORMAL : BasicNVPTXInst<(outs), (ins ADDR:$addr), - "prefetch.global.L2::evict_normal", - [(int_nvvm_prefetch_global_L2_evict_normal addr:$addr)]>, - Requires<[hasPTX<80>, hasSM<90>]>; + def PREFETCH_GLOBAL_L2_EVICT_NORMAL : BasicNVPTXInst<(outs), (ins ADDR:$addr), + "prefetch.global.L2::evict_normal", + [(int_nvvm_prefetch_global_L2_evict_normal addr:$addr)]>; -def PREFETCH_GLOBAL_L2_EVICT_LAST : BasicNVPTXInst<(outs), (ins ADDR:$addr), - "prefetch.global.L2::evict_last", - [(int_nvvm_prefetch_global_L2_evict_last addr:$addr)]>, - Requires<[hasPTX<80>, hasSM<90>]>; + def PREFETCH_GLOBAL_L2_EVICT_LAST : BasicNVPTXInst<(outs), (ins ADDR:$addr), + "prefetch.global.L2::evict_last", + [(int_nvvm_prefetch_global_L2_evict_last addr:$addr)]>; - -def PREFETCHU_L1 : PREFETCH_INTRS<"prefetchu.L1">; + def PREFETCHU_L1 : PREFETCH_INTRS<"prefetchu.L1">; +} //Applypriority intrinsics class APPLYPRIORITY_L2_INTRS<string addrspace> : @@ -1031,99 +1006,82 @@ def DISCARD_GLOBAL_L2 : DISCARD_L2_INTRS<"global">; // MBarrier Functions //----------------------------------- -multiclass MBARRIER_INIT<string AddrSpace, Intrinsic Intrin> { - def "" : BasicNVPTXInst<(outs), (ins ADDR:$addr, B32:$count), - "mbarrier.init" # AddrSpace # ".b64", - [(Intrin addr:$addr, i32:$count)]>, - Requires<[hasPTX<70>, hasSM<80>]>; -} - -defm MBARRIER_INIT : MBARRIER_INIT<"", int_nvvm_mbarrier_init>; -defm MBARRIER_INIT_SHARED : MBARRIER_INIT<".shared", - int_nvvm_mbarrier_init_shared>; - -multiclass MBARRIER_INVAL<string AddrSpace, Intrinsic Intrin> { - def "" : BasicNVPTXInst<(outs), (ins ADDR:$addr), - "mbarrier.inval" # AddrSpace # ".b64", - 
[(Intrin addr:$addr)]>, - Requires<[hasPTX<70>, hasSM<80>]>; -} - -defm MBARRIER_INVAL : MBARRIER_INVAL<"", int_nvvm_mbarrier_inval>; -defm MBARRIER_INVAL_SHARED : MBARRIER_INVAL<".shared", - int_nvvm_mbarrier_inval_shared>; - -multiclass MBARRIER_ARRIVE<string AddrSpace, Intrinsic Intrin> { - def "" : BasicNVPTXInst<(outs B64:$state), (ins ADDR:$addr), - "mbarrier.arrive" # AddrSpace # ".b64", - [(set i64:$state, (Intrin addr:$addr))]>, - Requires<[hasPTX<70>, hasSM<80>]>; -} - -defm MBARRIER_ARRIVE : MBARRIER_ARRIVE<"", int_nvvm_mbarrier_arrive>; -defm MBARRIER_ARRIVE_SHARED : - MBARRIER_ARRIVE<".shared", int_nvvm_mbarrier_arrive_shared>; - -multiclass MBARRIER_ARRIVE_NOCOMPLETE<string AddrSpace, Intrinsic Intrin> { - def "" : BasicNVPTXInst<(outs B64:$state), - (ins ADDR:$addr, B32:$count), - "mbarrier.arrive.noComplete" # AddrSpace # ".b64", - [(set i64:$state, (Intrin addr:$addr, i32:$count))]>, - Requires<[hasPTX<70>, hasSM<80>]>; -} - -defm MBARRIER_ARRIVE_NOCOMPLETE : - MBARRIER_ARRIVE_NOCOMPLETE<"", int_nvvm_mbarrier_arrive_noComplete>; -defm MBARRIER_ARRIVE_NOCOMPLETE_SHARED : - MBARRIER_ARRIVE_NOCOMPLETE<".shared", int_nvvm_mbarrier_arrive_noComplete_shared>; - -multiclass MBARRIER_ARRIVE_DROP<string AddrSpace, Intrinsic Intrin> { - def "" : BasicNVPTXInst<(outs B64:$state), (ins ADDR:$addr), - "mbarrier.arrive_drop" # AddrSpace # ".b64", - [(set i64:$state, (Intrin addr:$addr))]>, - Requires<[hasPTX<70>, hasSM<80>]>; -} - -defm MBARRIER_ARRIVE_DROP : - MBARRIER_ARRIVE_DROP<"", int_nvvm_mbarrier_arrive_drop>; -defm MBARRIER_ARRIVE_DROP_SHARED : - MBARRIER_ARRIVE_DROP<".shared", int_nvvm_mbarrier_arrive_drop_shared>; - -multiclass MBARRIER_ARRIVE_DROP_NOCOMPLETE<string AddrSpace, Intrinsic Intrin> { - def "" : BasicNVPTXInst<(outs B64:$state), - (ins ADDR:$addr, B32:$count), - "mbarrier.arrive_drop.noComplete" # AddrSpace # ".b64", - [(set i64:$state, (Intrin addr:$addr, i32:$count))]>, - Requires<[hasPTX<70>, hasSM<80>]>; -} - -defm MBARRIER_ARRIVE_DROP_NOCOMPLETE : - MBARRIER_ARRIVE_DROP_NOCOMPLETE<"", int_nvvm_mbarrier_arrive_drop_noComplete>; -defm MBARRIER_ARRIVE_DROP_NOCOMPLETE_SHARED : - MBARRIER_ARRIVE_DROP_NOCOMPLETE<".shared", - int_nvvm_mbarrier_arrive_drop_noComplete_shared>; - -multiclass MBARRIER_TEST_WAIT<string AddrSpace, Intrinsic Intrin> { - def "" : BasicNVPTXInst<(outs B1:$res), (ins ADDR:$addr, B64:$state), - "mbarrier.test_wait" # AddrSpace # ".b64", - [(set i1:$res, (Intrin addr:$addr, i64:$state))]>, - Requires<[hasPTX<70>, hasSM<80>]>; +let Predicates = [hasPTX<70>, hasSM<80>] in { + class MBARRIER_INIT<string AddrSpace, Intrinsic Intrin> : + BasicNVPTXInst<(outs), (ins ADDR:$addr, B32:$count), + "mbarrier.init" # AddrSpace # ".b64", + [(Intrin addr:$addr, i32:$count)]>; + + def MBARRIER_INIT : MBARRIER_INIT<"", int_nvvm_mbarrier_init>; + def MBARRIER_INIT_SHARED : MBARRIER_INIT<".shared", + int_nvvm_mbarrier_init_shared>; + + class MBARRIER_INVAL<string AddrSpace, Intrinsic Intrin> : + BasicNVPTXInst<(outs), (ins ADDR:$addr), + "mbarrier.inval" # AddrSpace # ".b64", + [(Intrin addr:$addr)]>; + + def MBARRIER_INVAL : MBARRIER_INVAL<"", int_nvvm_mbarrier_inval>; + def MBARRIER_INVAL_SHARED : MBARRIER_INVAL<".shared", + int_nvvm_mbarrier_inval_shared>; + + class MBARRIER_ARRIVE<string AddrSpace, Intrinsic Intrin> : + BasicNVPTXInst<(outs B64:$state), (ins ADDR:$addr), + "mbarrier.arrive" # AddrSpace # ".b64", + [(set i64:$state, (Intrin addr:$addr))]>; + + def MBARRIER_ARRIVE : MBARRIER_ARRIVE<"", int_nvvm_mbarrier_arrive>; + def MBARRIER_ARRIVE_SHARED : + 
MBARRIER_ARRIVE<".shared", int_nvvm_mbarrier_arrive_shared>; + + class MBARRIER_ARRIVE_NOCOMPLETE<string AddrSpace, Intrinsic Intrin> : + BasicNVPTXInst<(outs B64:$state), + (ins ADDR:$addr, B32:$count), + "mbarrier.arrive.noComplete" # AddrSpace # ".b64", + [(set i64:$state, (Intrin addr:$addr, i32:$count))]>; + + def MBARRIER_ARRIVE_NOCOMPLETE : + MBARRIER_ARRIVE_NOCOMPLETE<"", int_nvvm_mbarrier_arrive_noComplete>; + def MBARRIER_ARRIVE_NOCOMPLETE_SHARED : + MBARRIER_ARRIVE_NOCOMPLETE<".shared", int_nvvm_mbarrier_arrive_noComplete_shared>; + + class MBARRIER_ARRIVE_DROP<string AddrSpace, Intrinsic Intrin> : + BasicNVPTXInst<(outs B64:$state), (ins ADDR:$addr), + "mbarrier.arrive_drop" # AddrSpace # ".b64", + [(set i64:$state, (Intrin addr:$addr))]>; + + def MBARRIER_ARRIVE_DROP : + MBARRIER_ARRIVE_DROP<"", int_nvvm_mbarrier_arrive_drop>; + def MBARRIER_ARRIVE_DROP_SHARED : + MBARRIER_ARRIVE_DROP<".shared", int_nvvm_mbarrier_arrive_drop_shared>; + + class MBARRIER_ARRIVE_DROP_NOCOMPLETE<string AddrSpace, Intrinsic Intrin> : + BasicNVPTXInst<(outs B64:$state), + (ins ADDR:$addr, B32:$count), + "mbarrier.arrive_drop.noComplete" # AddrSpace # ".b64", + [(set i64:$state, (Intrin addr:$addr, i32:$count))]>; + + def MBARRIER_ARRIVE_DROP_NOCOMPLETE : + MBARRIER_ARRIVE_DROP_NOCOMPLETE<"", int_nvvm_mbarrier_arrive_drop_noComplete>; + def MBARRIER_ARRIVE_DROP_NOCOMPLETE_SHARED : + MBARRIER_ARRIVE_DROP_NOCOMPLETE<".shared", + int_nvvm_mbarrier_arrive_drop_noComplete_shared>; + + class MBARRIER_TEST_WAIT<string AddrSpace, Intrinsic Intrin> : + BasicNVPTXInst<(outs B1:$res), (ins ADDR:$addr, B64:$state), + "mbarrier.test_wait" # AddrSpace # ".b64", + [(set i1:$res, (Intrin addr:$addr, i64:$state))]>; + + def MBARRIER_TEST_WAIT : + MBARRIER_TEST_WAIT<"", int_nvvm_mbarrier_test_wait>; + def MBARRIER_TEST_WAIT_SHARED : + MBARRIER_TEST_WAIT<".shared", int_nvvm_mbarrier_test_wait_shared>; + + def MBARRIER_PENDING_COUNT : + BasicNVPTXInst<(outs B32:$res), (ins B64:$state), + "mbarrier.pending_count.b64", + [(set i32:$res, (int_nvvm_mbarrier_pending_count i64:$state))]>; } - -defm MBARRIER_TEST_WAIT : - MBARRIER_TEST_WAIT<"", int_nvvm_mbarrier_test_wait>; -defm MBARRIER_TEST_WAIT_SHARED : - MBARRIER_TEST_WAIT<".shared", int_nvvm_mbarrier_test_wait_shared>; - -class MBARRIER_PENDING_COUNT<Intrinsic Intrin> : - BasicNVPTXInst<(outs B32:$res), (ins B64:$state), - "mbarrier.pending_count.b64", - [(set i32:$res, (Intrin i64:$state))]>, - Requires<[hasPTX<70>, hasSM<80>]>; - -def MBARRIER_PENDING_COUNT : - MBARRIER_PENDING_COUNT<int_nvvm_mbarrier_pending_count>; - //----------------------------------- // Math Functions //----------------------------------- @@ -1449,15 +1407,11 @@ defm ABS_F64 : F_ABS<"f64", F64RT, support_ftz = false>; def fcopysign_nvptx : SDNode<"NVPTXISD::FCOPYSIGN", SDTFPBinOp>; -def COPYSIGN_F : - BasicNVPTXInst<(outs B32:$dst), (ins B32:$src0, B32:$src1), - "copysign.f32", - [(set f32:$dst, (fcopysign_nvptx f32:$src1, f32:$src0))]>; - -def COPYSIGN_D : - BasicNVPTXInst<(outs B64:$dst), (ins B64:$src0, B64:$src1), - "copysign.f64", - [(set f64:$dst, (fcopysign_nvptx f64:$src1, f64:$src0))]>; +foreach t = [F32RT, F64RT] in + def COPYSIGN_ # t : + BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src0, t.RC:$src1), + "copysign." 
# t.PtxType, + [(set t.Ty:$dst, (fcopysign_nvptx t.Ty:$src1, t.Ty:$src0))]>; // // Neg bf16, bf16x2 @@ -2255,38 +2209,35 @@ defm INT_PTX_SATOM_XOR : ATOM2_bitwise_impl<"xor">; // Scalar -class LDU_G<string TyStr, NVPTXRegClass regclass> - : NVPTXInst<(outs regclass:$result), (ins ADDR:$src), - "ldu.global." # TyStr # " \t$result, [$src];", []>; +class LDU_G<NVPTXRegClass regclass> + : NVPTXInst<(outs regclass:$result), (ins i32imm:$fromWidth, ADDR:$src), + "ldu.global.b$fromWidth \t$result, [$src];", []>; -def LDU_GLOBAL_i8 : LDU_G<"b8", B16>; -def LDU_GLOBAL_i16 : LDU_G<"b16", B16>; -def LDU_GLOBAL_i32 : LDU_G<"b32", B32>; -def LDU_GLOBAL_i64 : LDU_G<"b64", B64>; +def LDU_GLOBAL_i16 : LDU_G<B16>; +def LDU_GLOBAL_i32 : LDU_G<B32>; +def LDU_GLOBAL_i64 : LDU_G<B64>; // vector // Elementized vector ldu -class VLDU_G_ELE_V2<string TyStr, NVPTXRegClass regclass> +class VLDU_G_ELE_V2<NVPTXRegClass regclass> : NVPTXInst<(outs regclass:$dst1, regclass:$dst2), - (ins ADDR:$src), - "ldu.global.v2." # TyStr # " \t{{$dst1, $dst2}}, [$src];", []>; + (ins i32imm:$fromWidth, ADDR:$src), + "ldu.global.v2.b$fromWidth \t{{$dst1, $dst2}}, [$src];", []>; -class VLDU_G_ELE_V4<string TyStr, NVPTXRegClass regclass> - : NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, - regclass:$dst4), (ins ADDR:$src), - "ldu.global.v4." # TyStr # " \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", []>; +class VLDU_G_ELE_V4<NVPTXRegClass regclass> + : NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4), + (ins i32imm:$fromWidth, ADDR:$src), + "ldu.global.v4.b$fromWidth \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", []>; -def LDU_GLOBAL_v2i8 : VLDU_G_ELE_V2<"b8", B16>; -def LDU_GLOBAL_v2i16 : VLDU_G_ELE_V2<"b16", B16>; -def LDU_GLOBAL_v2i32 : VLDU_G_ELE_V2<"b32", B32>; -def LDU_GLOBAL_v2i64 : VLDU_G_ELE_V2<"b64", B64>; +def LDU_GLOBAL_v2i16 : VLDU_G_ELE_V2<B16>; +def LDU_GLOBAL_v2i32 : VLDU_G_ELE_V2<B32>; +def LDU_GLOBAL_v2i64 : VLDU_G_ELE_V2<B64>; -def LDU_GLOBAL_v4i8 : VLDU_G_ELE_V4<"b8", B16>; -def LDU_GLOBAL_v4i16 : VLDU_G_ELE_V4<"b16", B16>; -def LDU_GLOBAL_v4i32 : VLDU_G_ELE_V4<"b32", B32>; +def LDU_GLOBAL_v4i16 : VLDU_G_ELE_V4<B16>; +def LDU_GLOBAL_v4i32 : VLDU_G_ELE_V4<B32>; //----------------------------------- @@ -2327,12 +2278,10 @@ class VLDG_G_ELE_V8<NVPTXRegClass regclass> : "ld.global.nc.v8.${Sign:sign}$fromWidth \t{{$dst1, $dst2, $dst3, $dst4, $dst5, $dst6, $dst7, $dst8}}, [$src];", []>; // FIXME: 8-bit LDG should be fixed once LDG/LDU nodes are made into proper loads. -def LD_GLOBAL_NC_v2i8 : VLDG_G_ELE_V2<B16>; def LD_GLOBAL_NC_v2i16 : VLDG_G_ELE_V2<B16>; def LD_GLOBAL_NC_v2i32 : VLDG_G_ELE_V2<B32>; def LD_GLOBAL_NC_v2i64 : VLDG_G_ELE_V2<B64>; -def LD_GLOBAL_NC_v4i8 : VLDG_G_ELE_V4<B16>; def LD_GLOBAL_NC_v4i16 : VLDG_G_ELE_V4<B16>; def LD_GLOBAL_NC_v4i32 : VLDG_G_ELE_V4<B32>; @@ -2342,19 +2291,19 @@ def LD_GLOBAL_NC_v8i32 : VLDG_G_ELE_V8<B32>; multiclass NG_TO_G<string Str, bit Supports32 = 1, list<Predicate> Preds = []> { if Supports32 then def "" : BasicNVPTXInst<(outs B32:$result), (ins B32:$src), - "cvta." # Str # ".u32", []>, Requires<Preds>; + "cvta." # Str # ".u32">, Requires<Preds>; def _64 : BasicNVPTXInst<(outs B64:$result), (ins B64:$src), - "cvta." # Str # ".u64", []>, Requires<Preds>; + "cvta." # Str # ".u64">, Requires<Preds>; } multiclass G_TO_NG<string Str, bit Supports32 = 1, list<Predicate> Preds = []> { if Supports32 then def "" : BasicNVPTXInst<(outs B32:$result), (ins B32:$src), - "cvta.to." # Str # ".u32", []>, Requires<Preds>; + "cvta.to." 
# Str # ".u32">, Requires<Preds>; def _64 : BasicNVPTXInst<(outs B64:$result), (ins B64:$src), - "cvta.to." # Str # ".u64", []>, Requires<Preds>; + "cvta.to." # Str # ".u64">, Requires<Preds>; } foreach space = ["local", "shared", "global", "const", "param"] in { @@ -4614,9 +4563,9 @@ def INT_PTX_SREG_LANEMASK_GT : PTX_READ_SREG_R32<"lanemask_gt", int_nvvm_read_ptx_sreg_lanemask_gt>; let hasSideEffects = 1 in { -def SREG_CLOCK : PTX_READ_SREG_R32<"clock", int_nvvm_read_ptx_sreg_clock>; -def SREG_CLOCK64 : PTX_READ_SREG_R64<"clock64", int_nvvm_read_ptx_sreg_clock64>; -def SREG_GLOBALTIMER : PTX_READ_SREG_R64<"globaltimer", int_nvvm_read_ptx_sreg_globaltimer>; + def SREG_CLOCK : PTX_READ_SREG_R32<"clock", int_nvvm_read_ptx_sreg_clock>; + def SREG_CLOCK64 : PTX_READ_SREG_R64<"clock64", int_nvvm_read_ptx_sreg_clock64>; + def SREG_GLOBALTIMER : PTX_READ_SREG_R64<"globaltimer", int_nvvm_read_ptx_sreg_globaltimer>; } def: Pat <(i64 (readcyclecounter)), (SREG_CLOCK64)>; @@ -5096,37 +5045,36 @@ foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs) in def : MMA_PAT<mma>; multiclass MAPA<string suffix, Intrinsic Intr> { - def _32: BasicNVPTXInst<(outs B32:$d), (ins B32:$a, B32:$b), - "mapa" # suffix # ".u32", - [(set i32:$d, (Intr i32:$a, i32:$b))]>, - Requires<[hasSM<90>, hasPTX<78>]>; - def _32i: BasicNVPTXInst<(outs B32:$d), (ins B32:$a, i32imm:$b), - "mapa" # suffix # ".u32", - [(set i32:$d, (Intr i32:$a, imm:$b))]>, - Requires<[hasSM<90>, hasPTX<78>]>; - def _64: BasicNVPTXInst<(outs B64:$d), (ins B64:$a, B32:$b), - "mapa" # suffix # ".u64", - [(set i64:$d, (Intr i64:$a, i32:$b))]>, - Requires<[hasSM<90>, hasPTX<78>]>; - def _64i: BasicNVPTXInst<(outs B64:$d), (ins B64:$a, i32imm:$b), - "mapa" # suffix # ".u64", - [(set i64:$d, (Intr i64:$a, imm:$b))]>, - Requires<[hasSM<90>, hasPTX<78>]>; + let Predicates = [hasSM<90>, hasPTX<78>] in { + def _32: BasicNVPTXInst<(outs B32:$d), (ins B32:$a, B32:$b), + "mapa" # suffix # ".u32", + [(set i32:$d, (Intr i32:$a, i32:$b))]>; + def _32i: BasicNVPTXInst<(outs B32:$d), (ins B32:$a, i32imm:$b), + "mapa" # suffix # ".u32", + [(set i32:$d, (Intr i32:$a, imm:$b))]>; + def _64: BasicNVPTXInst<(outs B64:$d), (ins B64:$a, B32:$b), + "mapa" # suffix # ".u64", + [(set i64:$d, (Intr i64:$a, i32:$b))]>; + def _64i: BasicNVPTXInst<(outs B64:$d), (ins B64:$a, i32imm:$b), + "mapa" # suffix # ".u64", + [(set i64:$d, (Intr i64:$a, imm:$b))]>; + } } + defm mapa : MAPA<"", int_nvvm_mapa>; defm mapa_shared_cluster : MAPA<".shared::cluster", int_nvvm_mapa_shared_cluster>; multiclass GETCTARANK<string suffix, Intrinsic Intr> { - def _32: BasicNVPTXInst<(outs B32:$d), (ins B32:$a), - "getctarank" # suffix # ".u32", - [(set i32:$d, (Intr i32:$a))]>, - Requires<[hasSM<90>, hasPTX<78>]>; - def _64: BasicNVPTXInst<(outs B32:$d), (ins B64:$a), - "getctarank" # suffix # ".u64", - [(set i32:$d, (Intr i64:$a))]>, - Requires<[hasSM<90>, hasPTX<78>]>; + let Predicates = [hasSM<90>, hasPTX<78>] in { + def _32: BasicNVPTXInst<(outs B32:$d), (ins B32:$a), + "getctarank" # suffix # ".u32", + [(set i32:$d, (Intr i32:$a))]>; + def _64: BasicNVPTXInst<(outs B32:$d), (ins B64:$a), + "getctarank" # suffix # ".u64", + [(set i32:$d, (Intr i64:$a))]>; + } } defm getctarank : GETCTARANK<"", int_nvvm_getctarank>; @@ -5165,29 +5113,25 @@ def INT_NVVM_WGMMA_WAIT_GROUP_SYNC_ALIGNED : BasicNVPTXInst<(outs), (ins i64imm: [(int_nvvm_wgmma_wait_group_sync_aligned timm:$n)]>, Requires<[hasSM90a, hasPTX<80>]>; } // isConvergent = true -def GRIDDEPCONTROL_LAUNCH_DEPENDENTS : - 
BasicNVPTXInst<(outs), (ins), - "griddepcontrol.launch_dependents", - [(int_nvvm_griddepcontrol_launch_dependents)]>, - Requires<[hasSM<90>, hasPTX<78>]>; - -def GRIDDEPCONTROL_WAIT : - BasicNVPTXInst<(outs), (ins), - "griddepcontrol.wait", - [(int_nvvm_griddepcontrol_wait)]>, - Requires<[hasSM<90>, hasPTX<78>]>; +let Predicates = [hasSM<90>, hasPTX<78>] in { + def GRIDDEPCONTROL_LAUNCH_DEPENDENTS : + BasicNVPTXInst<(outs), (ins), "griddepcontrol.launch_dependents", + [(int_nvvm_griddepcontrol_launch_dependents)]>; + def GRIDDEPCONTROL_WAIT : + BasicNVPTXInst<(outs), (ins), "griddepcontrol.wait", + [(int_nvvm_griddepcontrol_wait)]>; +} def INT_EXIT : BasicNVPTXInst<(outs), (ins), "exit", [(int_nvvm_exit)]>; // Tcgen05 intrinsics -let isConvergent = true in { +let isConvergent = true, Predicates = [hasTcgen05Instructions] in { multiclass TCGEN05_ALLOC_INTR<string AS, string num, Intrinsic Intr> { def "" : BasicNVPTXInst<(outs), (ins ADDR:$dst, B32:$ncols), "tcgen05.alloc.cta_group::" # num # ".sync.aligned" # AS # ".b32", - [(Intr addr:$dst, B32:$ncols)]>, - Requires<[hasTcgen05Instructions]>; + [(Intr addr:$dst, B32:$ncols)]>; } defm TCGEN05_ALLOC_CG1 : TCGEN05_ALLOC_INTR<"", "1", int_nvvm_tcgen05_alloc_cg1>; @@ -5200,8 +5144,7 @@ multiclass TCGEN05_DEALLOC_INTR<string num, Intrinsic Intr> { def "" : BasicNVPTXInst<(outs), (ins B32:$tmem_addr, B32:$ncols), "tcgen05.dealloc.cta_group::" # num # ".sync.aligned.b32", - [(Intr B32:$tmem_addr, B32:$ncols)]>, - Requires<[hasTcgen05Instructions]>; + [(Intr B32:$tmem_addr, B32:$ncols)]>; } defm TCGEN05_DEALLOC_CG1: TCGEN05_DEALLOC_INTR<"1", int_nvvm_tcgen05_dealloc_cg1>; defm TCGEN05_DEALLOC_CG2: TCGEN05_DEALLOC_INTR<"2", int_nvvm_tcgen05_dealloc_cg2>; @@ -5209,19 +5152,13 @@ defm TCGEN05_DEALLOC_CG2: TCGEN05_DEALLOC_INTR<"2", int_nvvm_tcgen05_dealloc_cg2 multiclass TCGEN05_RELINQ_PERMIT_INTR<string num, Intrinsic Intr> { def "" : BasicNVPTXInst<(outs), (ins), "tcgen05.relinquish_alloc_permit.cta_group::" # num # ".sync.aligned", - [(Intr)]>, - Requires<[hasTcgen05Instructions]>; + [(Intr)]>; } defm TCGEN05_RELINQ_CG1: TCGEN05_RELINQ_PERMIT_INTR<"1", int_nvvm_tcgen05_relinq_alloc_permit_cg1>; defm TCGEN05_RELINQ_CG2: TCGEN05_RELINQ_PERMIT_INTR<"2", int_nvvm_tcgen05_relinq_alloc_permit_cg2>; -def tcgen05_wait_ld: BasicNVPTXInst<(outs), (ins), "tcgen05.wait::ld.sync.aligned", - [(int_nvvm_tcgen05_wait_ld)]>, - Requires<[hasTcgen05Instructions]>; - -def tcgen05_wait_st: BasicNVPTXInst<(outs), (ins), "tcgen05.wait::st.sync.aligned", - [(int_nvvm_tcgen05_wait_st)]>, - Requires<[hasTcgen05Instructions]>; +def tcgen05_wait_ld: NullaryInst<"tcgen05.wait::ld.sync.aligned", int_nvvm_tcgen05_wait_ld>; +def tcgen05_wait_st: NullaryInst<"tcgen05.wait::st.sync.aligned", int_nvvm_tcgen05_wait_st>; multiclass TCGEN05_COMMIT_INTR<string AS, string num> { defvar prefix = "tcgen05.commit.cta_group::" # num #".mbarrier::arrive::one.shared::cluster"; @@ -5232,12 +5169,10 @@ multiclass TCGEN05_COMMIT_INTR<string AS, string num> { def "" : BasicNVPTXInst<(outs), (ins ADDR:$mbar), prefix # ".b64", - [(Intr addr:$mbar)]>, - Requires<[hasTcgen05Instructions]>; + [(Intr addr:$mbar)]>; def _MC : BasicNVPTXInst<(outs), (ins ADDR:$mbar, B16:$mc), prefix # ".multicast::cluster.b64", - [(IntrMC addr:$mbar, B16:$mc)]>, - Requires<[hasTcgen05Instructions]>; + [(IntrMC addr:$mbar, B16:$mc)]>; } defm TCGEN05_COMMIT_CG1 : TCGEN05_COMMIT_INTR<"", "1">; @@ -5249,8 +5184,7 @@ multiclass TCGEN05_SHIFT_INTR<string num, Intrinsic Intr> { def "" : BasicNVPTXInst<(outs), (ins 
ADDR:$tmem_addr), "tcgen05.shift.cta_group::" # num # ".down", - [(Intr addr:$tmem_addr)]>, - Requires<[hasTcgen05Instructions]>; + [(Intr addr:$tmem_addr)]>; } defm TCGEN05_SHIFT_CG1: TCGEN05_SHIFT_INTR<"1", int_nvvm_tcgen05_shift_down_cg1>; defm TCGEN05_SHIFT_CG2: TCGEN05_SHIFT_INTR<"2", int_nvvm_tcgen05_shift_down_cg2>; @@ -5270,13 +5204,11 @@ multiclass TCGEN05_CP_INTR<string shape, string src_fmt, string mc = ""> { def _cg1 : BasicNVPTXInst<(outs), (ins ADDR:$tmem_addr, B64:$sdesc), "tcgen05.cp.cta_group::1." # shape_mc_asm # fmt_asm, - [(IntrCG1 addr:$tmem_addr, B64:$sdesc)]>, - Requires<[hasTcgen05Instructions]>; + [(IntrCG1 addr:$tmem_addr, B64:$sdesc)]>; def _cg2 : BasicNVPTXInst<(outs), (ins ADDR:$tmem_addr, B64:$sdesc), "tcgen05.cp.cta_group::2." # shape_mc_asm # fmt_asm, - [(IntrCG2 addr:$tmem_addr, B64:$sdesc)]>, - Requires<[hasTcgen05Instructions]>; + [(IntrCG2 addr:$tmem_addr, B64:$sdesc)]>; } foreach src_fmt = ["", "b6x16_p32", "b4x16_p64"] in { @@ -5289,17 +5221,13 @@ foreach src_fmt = ["", "b6x16_p32", "b4x16_p64"] in { } } // isConvergent -let hasSideEffects = 1 in { +let hasSideEffects = 1, Predicates = [hasTcgen05Instructions] in { -def tcgen05_fence_before_thread_sync: BasicNVPTXInst<(outs), (ins), - "tcgen05.fence::before_thread_sync", - [(int_nvvm_tcgen05_fence_before_thread_sync)]>, - Requires<[hasTcgen05Instructions]>; + def tcgen05_fence_before_thread_sync: NullaryInst< + "tcgen05.fence::before_thread_sync", int_nvvm_tcgen05_fence_before_thread_sync>; -def tcgen05_fence_after_thread_sync: BasicNVPTXInst<(outs), (ins), - "tcgen05.fence::after_thread_sync", - [(int_nvvm_tcgen05_fence_after_thread_sync)]>, - Requires<[hasTcgen05Instructions]>; + def tcgen05_fence_after_thread_sync: NullaryInst< + "tcgen05.fence::after_thread_sync", int_nvvm_tcgen05_fence_after_thread_sync>; } // hasSideEffects @@ -5392,17 +5320,17 @@ foreach shape = ["16x64b", "16x128b", "16x256b", "32x32b", "16x32bx2"] in { // Bulk store instructions def st_bulk_imm : TImmLeaf<i64, [{ return Imm == 0; }]>; -def INT_NVVM_ST_BULK_GENERIC : - BasicNVPTXInst<(outs), (ins ADDR:$dest_addr, B64:$size, i64imm:$value), - "st.bulk", - [(int_nvvm_st_bulk addr:$dest_addr, i64:$size, st_bulk_imm:$value)]>, - Requires<[hasSM<100>, hasPTX<86>]>; +let Predicates = [hasSM<100>, hasPTX<86>] in { + def INT_NVVM_ST_BULK_GENERIC : + BasicNVPTXInst<(outs), (ins ADDR:$dest_addr, B64:$size, i64imm:$value), + "st.bulk", + [(int_nvvm_st_bulk addr:$dest_addr, i64:$size, st_bulk_imm:$value)]>; -def INT_NVVM_ST_BULK_SHARED_CTA: - BasicNVPTXInst<(outs), (ins ADDR:$dest_addr, B64:$size, i64imm:$value), - "st.bulk.shared::cta", - [(int_nvvm_st_bulk_shared_cta addr:$dest_addr, i64:$size, st_bulk_imm:$value)]>, - Requires<[hasSM<100>, hasPTX<86>]>; + def INT_NVVM_ST_BULK_SHARED_CTA: + BasicNVPTXInst<(outs), (ins ADDR:$dest_addr, B64:$size, i64imm:$value), + "st.bulk.shared::cta", + [(int_nvvm_st_bulk_shared_cta addr:$dest_addr, i64:$size, st_bulk_imm:$value)]>; +} // // clusterlaunchcontorl Instructions diff --git a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td index d40886a..2e81ab1 100644 --- a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td @@ -38,14 +38,6 @@ foreach i = 0...4 in { def R#i : NVPTXReg<"%r"#i>; // 32-bit def RL#i : NVPTXReg<"%rd"#i>; // 64-bit def RQ#i : NVPTXReg<"%rq"#i>; // 128-bit - def H#i : NVPTXReg<"%h"#i>; // 16-bit float - def HH#i : NVPTXReg<"%hh"#i>; // 2x16-bit float - - // Arguments - def ia#i : NVPTXReg<"%ia"#i>; - 
def la#i : NVPTXReg<"%la"#i>;
-  def fa#i : NVPTXReg<"%fa"#i>;
-  def da#i : NVPTXReg<"%da"#i>;
 }
 
 foreach i = 0...31 in {
diff --git a/llvm/lib/Target/PowerPC/MCTargetDesc/PPCAsmBackend.cpp b/llvm/lib/Target/PowerPC/MCTargetDesc/PPCAsmBackend.cpp
index 0e8828f..ec97e2e 100644
--- a/llvm/lib/Target/PowerPC/MCTargetDesc/PPCAsmBackend.cpp
+++ b/llvm/lib/Target/PowerPC/MCTargetDesc/PPCAsmBackend.cpp
@@ -93,8 +93,8 @@ public:
   MCFixupKindInfo getFixupKindInfo(MCFixupKind Kind) const override;
 
   void applyFixup(const MCFragment &, const MCFixup &Fixup,
-                  const MCValue &Target, MutableArrayRef<char> Data,
-                  uint64_t Value, bool IsResolved) override;
+                  const MCValue &Target, uint8_t *Data, uint64_t Value,
+                  bool IsResolved) override;
 
   bool shouldForceRelocation(const MCFixup &Fixup, const MCValue &Target) {
     // If there is a @ specifier, unless it is optimized out (e.g. constant @l),
@@ -185,9 +185,8 @@ MCFixupKindInfo PPCAsmBackend::getFixupKindInfo(MCFixupKind Kind) const {
 }
 
 void PPCAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup,
-                               const MCValue &TargetVal,
-                               MutableArrayRef<char> Data, uint64_t Value,
-                               bool IsResolved) {
+                               const MCValue &TargetVal, uint8_t *Data,
+                               uint64_t Value, bool IsResolved) {
   // In PPC64 ELFv1, .quad .TOC.@tocbase in the .opd section is expected to
   // reference the null symbol.
   auto Target = TargetVal;
@@ -205,7 +204,6 @@ void PPCAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup,
   if (!Value)
     return; // Doesn't change encoding.
 
-  unsigned Offset = Fixup.getOffset();
   unsigned NumBytes = getFixupKindNumBytes(Kind);
 
   // For each byte of the fragment that the fixup touches, mask in the bits
@@ -213,7 +211,7 @@ void PPCAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup,
   // bitfields above.
   for (unsigned i = 0; i != NumBytes; ++i) {
     unsigned Idx = Endian == llvm::endianness::little ? i : (NumBytes - 1 - i);
-    Data[Offset + i] |= uint8_t((Value >> (Idx * 8)) & 0xff);
+    Data[i] |= uint8_t((Value >> (Idx * 8)) & 0xff);
   }
 }
 
diff --git a/llvm/lib/Target/PowerPC/PPCISelLowering.cpp b/llvm/lib/Target/PowerPC/PPCISelLowering.cpp
index 459525e..f179873 100644
--- a/llvm/lib/Target/PowerPC/PPCISelLowering.cpp
+++ b/llvm/lib/Target/PowerPC/PPCISelLowering.cpp
@@ -7296,9 +7296,17 @@ SDValue PPCTargetLowering::LowerFormalArguments_AIX(
     if (!ArgVT.isVector() && !ValVT.isVector() && ArgVT.isInteger() &&
         ValVT.isInteger() &&
         ArgVT.getScalarSizeInBits() < ValVT.getScalarSizeInBits()) {
-      SDValue ArgValueTrunc = DAG.getNode(
-          ISD::TRUNCATE, dl, ArgVT.getSimpleVT() == MVT::i1 ? MVT::i8 : ArgVT,
-          ArgValue);
+      // The incoming value is either a true integer, or an integer that was
+      // not originally an integer (e.g. one that came from a struct). In the
+      // latter case the parameter carries no extension attribute, and since
+      // no extension was specified in the first place, the kind of extension
+      // we apply here should not matter.
+      EVT TruncatedArgVT = ArgVT.isSimple() && ArgVT.getSimpleVT() == MVT::i1
+                               ? MVT::i8
+                               : ArgVT;
+      SDValue ArgValueTrunc =
+          DAG.getNode(ISD::TRUNCATE, dl, TruncatedArgVT, ArgValue);
       SDValue ArgValueExt =
          ArgSignExt ?
DAG.getSExtOrTrunc(ArgValueTrunc, dl, ValVT) : DAG.getZExtOrTrunc(ArgValueTrunc, dl, ValVT); diff --git a/llvm/lib/Target/PowerPC/PPCMachineScheduler.cpp b/llvm/lib/Target/PowerPC/PPCMachineScheduler.cpp index 5eb1f01..b7e2263 100644 --- a/llvm/lib/Target/PowerPC/PPCMachineScheduler.cpp +++ b/llvm/lib/Target/PowerPC/PPCMachineScheduler.cpp @@ -100,10 +100,14 @@ bool PPCPreRASchedStrategy::tryCandidate(SchedCandidate &Cand, // This is a best effort to set things up for a post-RA pass. Optimizations // like generating loads of multiple registers should ideally be done within // the scheduler pass by combining the loads during DAG postprocessing. - const ClusterInfo *CandCluster = Cand.AtTop ? TopCluster : BotCluster; - const ClusterInfo *TryCandCluster = TryCand.AtTop ? TopCluster : BotCluster; - if (tryGreater(TryCandCluster && TryCandCluster->contains(TryCand.SU), - CandCluster && CandCluster->contains(Cand.SU), TryCand, Cand, + unsigned CandZoneCluster = Cand.AtTop ? TopClusterID : BotClusterID; + unsigned TryCandZoneCluster = TryCand.AtTop ? TopClusterID : BotClusterID; + bool CandIsClusterSucc = + isTheSameCluster(CandZoneCluster, Cand.SU->ParentClusterIdx); + bool TryCandIsClusterSucc = + isTheSameCluster(TryCandZoneCluster, TryCand.SU->ParentClusterIdx); + + if (tryGreater(TryCandIsClusterSucc, CandIsClusterSucc, TryCand, Cand, Cluster)) return TryCand.Reason != NoCand; @@ -189,10 +193,14 @@ bool PPCPostRASchedStrategy::tryCandidate(SchedCandidate &Cand, return TryCand.Reason != NoCand; // Keep clustered nodes together. - const ClusterInfo *CandCluster = Cand.AtTop ? TopCluster : BotCluster; - const ClusterInfo *TryCandCluster = TryCand.AtTop ? TopCluster : BotCluster; - if (tryGreater(TryCandCluster && TryCandCluster->contains(TryCand.SU), - CandCluster && CandCluster->contains(Cand.SU), TryCand, Cand, + unsigned CandZoneCluster = Cand.AtTop ? TopClusterID : BotClusterID; + unsigned TryCandZoneCluster = TryCand.AtTop ? TopClusterID : BotClusterID; + bool CandIsClusterSucc = + isTheSameCluster(CandZoneCluster, Cand.SU->ParentClusterIdx); + bool TryCandIsClusterSucc = + isTheSameCluster(TryCandZoneCluster, TryCand.SU->ParentClusterIdx); + + if (tryGreater(TryCandIsClusterSucc, CandIsClusterSucc, TryCand, Cand, Cluster)) return TryCand.Reason != NoCand; diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVAsmBackend.cpp b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVAsmBackend.cpp index 82e3b5c..eb7460e 100644 --- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVAsmBackend.cpp +++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVAsmBackend.cpp @@ -327,19 +327,19 @@ bool RISCVAsmBackend::relaxAlign(MCFragment &F, unsigned &Size) { bool RISCVAsmBackend::relaxDwarfLineAddr(MCFragment &F, bool &WasRelaxed) const { - MCContext &C = getContext(); - int64_t LineDelta = F.getDwarfLineDelta(); const MCExpr &AddrDelta = F.getDwarfAddrDelta(); - SmallVector<MCFixup, 1> Fixups; size_t OldSize = F.getVarSize(); int64_t Value; + // If the label difference can be resolved, use the default handling, which + // utilizes a shorter special opcode. 
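When the delta cannot be resolved early, the relaxed form emitted below picks one of two DWARF line-table encodings: a 2-byte `DW_LNS_fixed_advance_pc` operand while the delta stays under the conservative 60000 bound, or a pointer-sized `DW_LNE_set_address` otherwise. A standalone sketch of that choice (the opcode values are DWARF's; the byte-vector interface and function name are illustrative, not LLVM's MC API):

```cpp
#include <cstdint>
#include <vector>

namespace {
constexpr uint8_t DW_LNS_extended_op = 0x00;
constexpr uint8_t DW_LNE_set_address = 0x02;
constexpr uint8_t DW_LNS_fixed_advance_pc = 0x09;
} // namespace

void encodeAdvancePC(std::vector<uint8_t> &Out, uint64_t Delta,
                     unsigned PtrSize) {
  if (Delta > 60000) {
    // Too large for fixed_advance_pc's 16-bit operand: set the address
    // outright. The operand bytes are zero placeholders to be patched by a
    // fixup, as in the fragment below.
    Out.push_back(DW_LNS_extended_op);
    Out.push_back(uint8_t(PtrSize + 1)); // ULEB128 payload length
    Out.push_back(DW_LNE_set_address);
    Out.insert(Out.end(), PtrSize, 0);
  } else {
    // Fits in an unsigned half: 2-byte little-endian placeholder.
    Out.push_back(DW_LNS_fixed_advance_pc);
    Out.push_back(0);
    Out.push_back(0);
  }
}
```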
+ if (AddrDelta.evaluateAsAbsolute(Value, *Asm)) + return false; [[maybe_unused]] bool IsAbsolute = AddrDelta.evaluateKnownAbsolute(Value, *Asm); assert(IsAbsolute && "CFA with invalid expression"); - Fixups.clear(); SmallVector<char> Data; raw_svector_ostream OS(Data); @@ -349,33 +349,21 @@ bool RISCVAsmBackend::relaxDwarfLineAddr(MCFragment &F, encodeSLEB128(LineDelta, OS); } - unsigned Offset; - std::pair<MCFixupKind, MCFixupKind> Fixup; - // According to the DWARF specification, the `DW_LNS_fixed_advance_pc` opcode // takes a single unsigned half (unencoded) operand. The maximum encodable // value is therefore 65535. Set a conservative upper bound for relaxation. + unsigned PCBytes; if (Value > 60000) { - unsigned PtrSize = C.getAsmInfo()->getCodePointerSize(); - - OS << uint8_t(dwarf::DW_LNS_extended_op); - encodeULEB128(PtrSize + 1, OS); - - OS << uint8_t(dwarf::DW_LNE_set_address); - Offset = OS.tell(); - assert((PtrSize == 4 || PtrSize == 8) && "Unexpected pointer size"); - Fixup = RISCV::getRelocPairForSize(PtrSize); - OS.write_zeros(PtrSize); + PCBytes = getContext().getAsmInfo()->getCodePointerSize(); + OS << uint8_t(dwarf::DW_LNS_extended_op) << uint8_t(PCBytes + 1) + << uint8_t(dwarf::DW_LNE_set_address); + OS.write_zeros(PCBytes); } else { + PCBytes = 2; OS << uint8_t(dwarf::DW_LNS_fixed_advance_pc); - Offset = OS.tell(); - Fixup = RISCV::getRelocPairForSize(2); support::endian::write<uint16_t>(OS, 0, llvm::endianness::little); } - - const MCBinaryExpr &MBE = cast<MCBinaryExpr>(AddrDelta); - Fixups.push_back(MCFixup::create(Offset, MBE.getLHS(), std::get<0>(Fixup))); - Fixups.push_back(MCFixup::create(Offset, MBE.getRHS(), std::get<1>(Fixup))); + auto Offset = OS.tell() - PCBytes; if (LineDelta == INT64_MAX) { OS << uint8_t(dwarf::DW_LNS_extended_op); @@ -386,7 +374,8 @@ bool RISCVAsmBackend::relaxDwarfLineAddr(MCFragment &F, } F.setVarContents(Data); - F.setVarFixups(Fixups); + F.setVarFixups({MCFixup::create(Offset, &AddrDelta, + MCFixup::getDataKindForSize(PCBytes))}); WasRelaxed = OldSize != Data.size(); return true; } @@ -881,9 +870,8 @@ bool RISCVAsmBackend::addReloc(const MCFragment &F, const MCFixup &Fixup, } void RISCVAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, - const MCValue &Target, - MutableArrayRef<char> Data, uint64_t Value, - bool IsResolved) { + const MCValue &Target, uint8_t *Data, + uint64_t Value, bool IsResolved) { IsResolved = addReloc(F, Fixup, Target, Value, IsResolved); MCFixupKind Kind = Fixup.getKind(); if (mc::isRelocation(Kind)) @@ -898,15 +886,14 @@ void RISCVAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, // Shift the value into position. Value <<= Info.TargetOffset; - unsigned Offset = Fixup.getOffset(); unsigned NumBytes = alignTo(Info.TargetSize + Info.TargetOffset, 8) / 8; - - assert(Offset + NumBytes <= Data.size() && "Invalid fixup offset!"); + assert(Fixup.getOffset() + NumBytes <= F.getSize() && + "Invalid fixup offset!"); // For each byte of the fragment that the fixup touches, mask in the // bits from the fixup value. 
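The loop that follows ORs the shifted fixup value into the fragment bytes little-endian. A minimal standalone rendition of that pattern, with illustrative names (the real code reads `TargetSize`/`TargetOffset` from the fixup-kind info table):

```cpp
#include <cstdint>

void orFixupLE(uint8_t *Data, uint64_t Value, unsigned TargetOffset,
               unsigned TargetSize) {
  // Shift the value into its bit position within the instruction word,
  // then mask it in byte by byte.
  Value <<= TargetOffset;
  unsigned NumBytes = (TargetSize + TargetOffset + 7) / 8;
  for (unsigned I = 0; I != NumBytes; ++I)
    Data[I] |= uint8_t((Value >> (I * 8)) & 0xff);
}
```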
for (unsigned i = 0; i != NumBytes; ++i) { - Data[Offset + i] |= uint8_t((Value >> (i * 8)) & 0xff); + Data[i] |= uint8_t((Value >> (i * 8)) & 0xff); } } diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVAsmBackend.h b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVAsmBackend.h index d97d632..adec1ec 100644 --- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVAsmBackend.h +++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVAsmBackend.h @@ -46,8 +46,7 @@ public: void maybeAddVendorReloc(const MCFragment &, const MCFixup &); void applyFixup(const MCFragment &, const MCFixup &, const MCValue &Target, - MutableArrayRef<char> Data, uint64_t Value, - bool IsResolved) override; + uint8_t *Data, uint64_t Value, bool IsResolved) override; std::unique_ptr<MCObjectTargetWriter> createObjectTargetWriter() const override; diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVFixupKinds.h b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVFixupKinds.h index f816561c..98c8738 100644 --- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVFixupKinds.h +++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVFixupKinds.h @@ -68,27 +68,6 @@ enum Fixups { fixup_riscv_invalid, NumTargetFixupKinds = fixup_riscv_invalid - FirstTargetFixupKind }; - -static inline std::pair<MCFixupKind, MCFixupKind> -getRelocPairForSize(unsigned Size) { - switch (Size) { - default: - llvm_unreachable("unsupported fixup size"); - case 1: - return std::make_pair(FirstLiteralRelocationKind + ELF::R_RISCV_ADD8, - FirstLiteralRelocationKind + ELF::R_RISCV_SUB8); - case 2: - return std::make_pair(FirstLiteralRelocationKind + ELF::R_RISCV_ADD16, - FirstLiteralRelocationKind + ELF::R_RISCV_SUB16); - case 4: - return std::make_pair(FirstLiteralRelocationKind + ELF::R_RISCV_ADD32, - FirstLiteralRelocationKind + ELF::R_RISCV_SUB32); - case 8: - return std::make_pair(FirstLiteralRelocationKind + ELF::R_RISCV_ADD64, - FirstLiteralRelocationKind + ELF::R_RISCV_SUB64); - } -} - } // end namespace llvm::RISCV #endif diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp index 7b0fe5f..5998653 100644 --- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp @@ -2827,6 +2827,8 @@ static bool selectConstantAddr(SelectionDAG *CurDAG, const SDLoc &DL, static bool isWorthFoldingAdd(SDValue Add) { for (auto *User : Add->users()) { if (User->getOpcode() != ISD::LOAD && User->getOpcode() != ISD::STORE && + User->getOpcode() != RISCVISD::LD_RV32 && + User->getOpcode() != RISCVISD::SD_RV32 && User->getOpcode() != ISD::ATOMIC_LOAD && User->getOpcode() != ISD::ATOMIC_STORE) return false; @@ -2841,6 +2843,9 @@ static bool isWorthFoldingAdd(SDValue Add) { if (User->getOpcode() == ISD::ATOMIC_STORE && cast<AtomicSDNode>(User)->getVal() == Add) return false; + if (User->getOpcode() == RISCVISD::SD_RV32 && + (User->getOperand(0) == Add || User->getOperand(1) == Add)) + return false; if (isStrongerThanMonotonic(cast<MemSDNode>(User)->getSuccessOrdering())) return false; } @@ -2942,8 +2947,8 @@ bool RISCVDAGToDAGISel::SelectAddrRegImm(SDValue Addr, SDValue &Base, /// Similar to SelectAddrRegImm, except that the offset is restricted to uimm9. bool RISCVDAGToDAGISel::SelectAddrRegImm9(SDValue Addr, SDValue &Base, SDValue &Offset) { - // FIXME: Support FrameIndex. Need to teach eliminateFrameIndex that only - // a 9-bit immediate can be folded. 
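`SelectAddrRegImm9` only folds offsets that fit an unsigned 9-bit immediate, and the `eliminateFrameIndex` hunk further down applies the same limit to `MIPS_PREFETCH`. A hedged sketch of the bit-width predicates involved — `llvm::isUInt<N>`/`llvm::isInt<N>` play this role in the real tree; these are standalone stand-ins:

```cpp
#include <cstdint>

template <unsigned N> constexpr bool isUIntN(uint64_t X) {
  return N >= 64 || X < (uint64_t(1) << N);
}
template <unsigned N> constexpr bool isIntN(int64_t X) {
  return N >= 64 ||
         (X >= -(int64_t(1) << (N - 1)) && X < (int64_t(1) << (N - 1)));
}

static_assert(isUIntN<9>(511) && !isUIntN<9>(512),
              "uimm9 admits [0, 511] only");
static_assert(isIntN<12>(-2048) && !isIntN<12>(2048),
              "simm12 admits [-2048, 2047]");
```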
+ if (SelectAddrFrameIndex(Addr, Base, Offset)) + return true; SDLoc DL(Addr); MVT VT = Addr.getSimpleValueType(); @@ -2953,8 +2958,8 @@ bool RISCVDAGToDAGISel::SelectAddrRegImm9(SDValue Addr, SDValue &Base, if (isUInt<9>(CVal)) { Base = Addr.getOperand(0); - // FIXME: Support FrameIndex. Need to teach eliminateFrameIndex that only - // a 9-bit immediate can be folded. + if (auto *FIN = dyn_cast<FrameIndexSDNode>(Base)) + Base = CurDAG->getTargetFrameIndex(FIN->getIndex(), VT); Offset = CurDAG->getSignedTargetConstant(CVal, DL, VT); return true; } diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index b47d89b..adbfbeb 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -1819,6 +1819,13 @@ bool RISCVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info, case Intrinsic::riscv_seg6_load_mask: case Intrinsic::riscv_seg7_load_mask: case Intrinsic::riscv_seg8_load_mask: + case Intrinsic::riscv_sseg2_load_mask: + case Intrinsic::riscv_sseg3_load_mask: + case Intrinsic::riscv_sseg4_load_mask: + case Intrinsic::riscv_sseg5_load_mask: + case Intrinsic::riscv_sseg6_load_mask: + case Intrinsic::riscv_sseg7_load_mask: + case Intrinsic::riscv_sseg8_load_mask: return SetRVVLoadStoreInfo(/*PtrOp*/ 0, /*IsStore*/ false, /*IsUnitStrided*/ false, /*UsePtrVal*/ true); case Intrinsic::riscv_seg2_store_mask: @@ -10938,6 +10945,97 @@ static inline SDValue getVCIXISDNodeVOID(SDValue &Op, SelectionDAG &DAG, return DAG.getNode(Type, SDLoc(Op), Op.getValueType(), Operands); } +static SDValue +lowerFixedVectorSegLoadIntrinsics(unsigned IntNo, SDValue Op, + const RISCVSubtarget &Subtarget, + SelectionDAG &DAG) { + bool IsStrided; + switch (IntNo) { + case Intrinsic::riscv_seg2_load_mask: + case Intrinsic::riscv_seg3_load_mask: + case Intrinsic::riscv_seg4_load_mask: + case Intrinsic::riscv_seg5_load_mask: + case Intrinsic::riscv_seg6_load_mask: + case Intrinsic::riscv_seg7_load_mask: + case Intrinsic::riscv_seg8_load_mask: + IsStrided = false; + break; + case Intrinsic::riscv_sseg2_load_mask: + case Intrinsic::riscv_sseg3_load_mask: + case Intrinsic::riscv_sseg4_load_mask: + case Intrinsic::riscv_sseg5_load_mask: + case Intrinsic::riscv_sseg6_load_mask: + case Intrinsic::riscv_sseg7_load_mask: + case Intrinsic::riscv_sseg8_load_mask: + IsStrided = true; + break; + default: + llvm_unreachable("unexpected intrinsic ID"); + }; + + static const Intrinsic::ID VlsegInts[7] = { + Intrinsic::riscv_vlseg2_mask, Intrinsic::riscv_vlseg3_mask, + Intrinsic::riscv_vlseg4_mask, Intrinsic::riscv_vlseg5_mask, + Intrinsic::riscv_vlseg6_mask, Intrinsic::riscv_vlseg7_mask, + Intrinsic::riscv_vlseg8_mask}; + static const Intrinsic::ID VlssegInts[7] = { + Intrinsic::riscv_vlsseg2_mask, Intrinsic::riscv_vlsseg3_mask, + Intrinsic::riscv_vlsseg4_mask, Intrinsic::riscv_vlsseg5_mask, + Intrinsic::riscv_vlsseg6_mask, Intrinsic::riscv_vlsseg7_mask, + Intrinsic::riscv_vlsseg8_mask}; + + SDLoc DL(Op); + unsigned NF = Op->getNumValues() - 1; + assert(NF >= 2 && NF <= 8 && "Unexpected seg number"); + MVT XLenVT = Subtarget.getXLenVT(); + MVT VT = Op->getSimpleValueType(0); + MVT ContainerVT = ::getContainerForFixedLengthVector(DAG, VT, Subtarget); + unsigned Sz = NF * ContainerVT.getVectorMinNumElements() * + ContainerVT.getScalarSizeInBits(); + EVT VecTupTy = MVT::getRISCVVectorTupleVT(Sz, NF); + + // Operands: (chain, int_id, pointer, mask, vl) or + // (chain, int_id, pointer, offset, mask, vl) + SDValue VL = 
Op.getOperand(Op.getNumOperands() - 1); + SDValue Mask = Op.getOperand(Op.getNumOperands() - 2); + MVT MaskVT = Mask.getSimpleValueType(); + MVT MaskContainerVT = + ::getContainerForFixedLengthVector(DAG, MaskVT, Subtarget); + Mask = convertToScalableVector(MaskContainerVT, Mask, DAG, Subtarget); + + SDValue IntID = DAG.getTargetConstant( + IsStrided ? VlssegInts[NF - 2] : VlsegInts[NF - 2], DL, XLenVT); + auto *Load = cast<MemIntrinsicSDNode>(Op); + + SDVTList VTs = DAG.getVTList({VecTupTy, MVT::Other}); + SmallVector<SDValue, 9> Ops = { + Load->getChain(), + IntID, + DAG.getUNDEF(VecTupTy), + Op.getOperand(2), + Mask, + VL, + DAG.getTargetConstant( + RISCVVType::TAIL_AGNOSTIC | RISCVVType::MASK_AGNOSTIC, DL, XLenVT), + DAG.getTargetConstant(Log2_64(VT.getScalarSizeInBits()), DL, XLenVT)}; + // Insert the stride operand. + if (IsStrided) + Ops.insert(std::next(Ops.begin(), 4), Op.getOperand(3)); + + SDValue Result = + DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops, + Load->getMemoryVT(), Load->getMemOperand()); + SmallVector<SDValue, 9> Results; + for (unsigned int RetIdx = 0; RetIdx < NF; RetIdx++) { + SDValue SubVec = DAG.getNode(RISCVISD::TUPLE_EXTRACT, DL, ContainerVT, + Result.getValue(0), + DAG.getTargetConstant(RetIdx, DL, MVT::i32)); + Results.push_back(convertFromScalableVector(VT, SubVec, DAG, Subtarget)); + } + Results.push_back(Result.getValue(1)); + return DAG.getMergeValues(Results, DL); +} + SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op, SelectionDAG &DAG) const { unsigned IntNo = Op.getConstantOperandVal(1); @@ -10950,57 +11048,16 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op, case Intrinsic::riscv_seg5_load_mask: case Intrinsic::riscv_seg6_load_mask: case Intrinsic::riscv_seg7_load_mask: - case Intrinsic::riscv_seg8_load_mask: { - SDLoc DL(Op); - static const Intrinsic::ID VlsegInts[7] = { - Intrinsic::riscv_vlseg2_mask, Intrinsic::riscv_vlseg3_mask, - Intrinsic::riscv_vlseg4_mask, Intrinsic::riscv_vlseg5_mask, - Intrinsic::riscv_vlseg6_mask, Intrinsic::riscv_vlseg7_mask, - Intrinsic::riscv_vlseg8_mask}; - unsigned NF = Op->getNumValues() - 1; - assert(NF >= 2 && NF <= 8 && "Unexpected seg number"); - MVT XLenVT = Subtarget.getXLenVT(); - MVT VT = Op->getSimpleValueType(0); - MVT ContainerVT = getContainerForFixedLengthVector(VT); - unsigned Sz = NF * ContainerVT.getVectorMinNumElements() * - ContainerVT.getScalarSizeInBits(); - EVT VecTupTy = MVT::getRISCVVectorTupleVT(Sz, NF); - - // Operands: (chain, int_id, pointer, mask, vl) - SDValue VL = Op.getOperand(Op.getNumOperands() - 1); - SDValue Mask = Op.getOperand(3); - MVT MaskVT = Mask.getSimpleValueType(); - MVT MaskContainerVT = - ::getContainerForFixedLengthVector(DAG, MaskVT, Subtarget); - Mask = convertToScalableVector(MaskContainerVT, Mask, DAG, Subtarget); - - SDValue IntID = DAG.getTargetConstant(VlsegInts[NF - 2], DL, XLenVT); - auto *Load = cast<MemIntrinsicSDNode>(Op); + case Intrinsic::riscv_seg8_load_mask: + case Intrinsic::riscv_sseg2_load_mask: + case Intrinsic::riscv_sseg3_load_mask: + case Intrinsic::riscv_sseg4_load_mask: + case Intrinsic::riscv_sseg5_load_mask: + case Intrinsic::riscv_sseg6_load_mask: + case Intrinsic::riscv_sseg7_load_mask: + case Intrinsic::riscv_sseg8_load_mask: + return lowerFixedVectorSegLoadIntrinsics(IntNo, Op, Subtarget, DAG); - SDVTList VTs = DAG.getVTList({VecTupTy, MVT::Other}); - SDValue Ops[] = { - Load->getChain(), - IntID, - DAG.getUNDEF(VecTupTy), - Op.getOperand(2), - Mask, - VL, - DAG.getTargetConstant( - 
RISCVVType::TAIL_AGNOSTIC | RISCVVType::MASK_AGNOSTIC, DL, XLenVT), - DAG.getTargetConstant(Log2_64(VT.getScalarSizeInBits()), DL, XLenVT)}; - SDValue Result = - DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops, - Load->getMemoryVT(), Load->getMemOperand()); - SmallVector<SDValue, 9> Results; - for (unsigned int RetIdx = 0; RetIdx < NF; RetIdx++) { - SDValue SubVec = DAG.getNode(RISCVISD::TUPLE_EXTRACT, DL, ContainerVT, - Result.getValue(0), - DAG.getTargetConstant(RetIdx, DL, MVT::i32)); - Results.push_back(convertFromScalableVector(VT, SubVec, DAG, Subtarget)); - } - Results.push_back(Result.getValue(1)); - return DAG.getMergeValues(Results, DL); - } case Intrinsic::riscv_sf_vc_v_x_se: return getVCIXISDNodeWCHAIN(Op, DAG, RISCVISD::SF_VC_V_X_SE); case Intrinsic::riscv_sf_vc_v_i_se: @@ -22725,8 +22782,14 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI, bool IsVarArg = CLI.IsVarArg; EVT PtrVT = getPointerTy(DAG.getDataLayout()); MVT XLenVT = Subtarget.getXLenVT(); + const CallBase *CB = CLI.CB; MachineFunction &MF = DAG.getMachineFunction(); + MachineFunction::CallSiteInfo CSInfo; + + // Set type id for call site info. + if (MF.getTarget().Options.EmitCallGraphSection && CB && CB->isIndirectCall()) + CSInfo = MachineFunction::CallSiteInfo(*CB); // Analyze the operands of the call, assigning locations to each operand. SmallVector<CCValAssign, 16> ArgLocs; @@ -22984,6 +23047,9 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI, if (CLI.CFIType) Ret.getNode()->setCFIType(CLI.CFIType->getZExtValue()); DAG.addNoMergeSiteInfo(Ret.getNode(), CLI.NoMerge); + if (MF.getTarget().Options.EmitCallGraphSection && CB && + CB->isIndirectCall()) + DAG.addCallSiteInfo(Ret.getNode(), std::move(CSInfo)); return Ret; } @@ -22991,6 +23057,10 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI, Chain = DAG.getNode(CallOpc, DL, NodeTys, Ops); if (CLI.CFIType) Chain.getNode()->setCFIType(CLI.CFIType->getZExtValue()); + + if (MF.getTarget().Options.EmitCallGraphSection && CB && CB->isIndirectCall()) + DAG.addCallSiteInfo(Chain.getNode(), std::move(CSInfo)); + DAG.addNoMergeSiteInfo(Chain.getNode(), CLI.NoMerge); Glue = Chain.getValue(1); diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td index 6afc942d..03e6f43 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td @@ -1510,21 +1510,6 @@ class VPseudoTiedBinaryCarryIn<VReg RetClass, let VLMul = MInfo.value; } -class VPseudoTernaryNoMask<VReg RetClass, - RegisterClass Op1Class, - DAGOperand Op2Class, - string Constraint> : - RISCVVPseudo<(outs RetClass:$rd), - (ins RetClass:$rs3, Op1Class:$rs1, Op2Class:$rs2, - AVL:$vl, sew:$sew)> { - let mayLoad = 0; - let mayStore = 0; - let hasSideEffects = 0; - let Constraints = !interleave([Constraint, "$rd = $rs3"], ","); - let HasVLOp = 1; - let HasSEWOp = 1; -} - class VPseudoTernaryNoMaskWithPolicy<VReg RetClass, RegisterClass Op1Class, DAGOperand Op2Class, diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td index 31ea2de..cc2977c 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td @@ -910,7 +910,7 @@ foreach vti = AllIntegerVectors in { foreach vti = I64IntegerVectors in { let Predicates = [HasVInstructionsI64] in { def : Pat<(add (vti.Vector vti.RegClass:$rs1), - (vti.Vector (SplatPat_imm64_neg i64:$rs2))), + 
(vti.Vector (SplatPat_imm64_neg (i64 GPR:$rs2)))), (!cast<Instruction>("PseudoVSUB_VX_"#vti.LMul.MX) (vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs1, diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td index 695223b..acbccdd 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td @@ -2123,7 +2123,7 @@ foreach vti = AllIntegerVectors in { foreach vti = I64IntegerVectors in { let Predicates = [HasVInstructionsI64] in { def : Pat<(riscv_add_vl (vti.Vector vti.RegClass:$rs1), - (vti.Vector (SplatPat_imm64_neg i64:$rs2)), + (vti.Vector (SplatPat_imm64_neg (i64 GPR:$rs2))), vti.RegClass:$passthru, (vti.Mask VMV0:$vm), VLOpFrag), (!cast<Instruction>("PseudoVSUB_VX_"#vti.LMul.MX#"_MASK") vti.RegClass:$passthru, vti.RegClass:$rs1, diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td b/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td index c0f7ab1..4c31ce4 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td @@ -590,12 +590,12 @@ let Predicates = [HasVendorXTHeadBb, IsRV64] in { def : PatGprImm<riscv_rorw, TH_SRRIW, uimm5>; def : Pat<(riscv_rolw GPR:$rs1, uimm5:$rs2), (TH_SRRIW GPR:$rs1, (ImmSubFrom32 uimm5:$rs2))>; -def : Pat<(sra (bswap i64:$rs1), (i64 32)), - (TH_REVW i64:$rs1)>; -def : Pat<(binop_allwusers<srl> (bswap i64:$rs1), (i64 32)), - (TH_REVW i64:$rs1)>; -def : Pat<(riscv_clzw i64:$rs1), - (TH_FF0 (i64 (SLLI (i64 (XORI i64:$rs1, -1)), 32)))>; +def : Pat<(i64 (sra (bswap GPR:$rs1), (i64 32))), + (TH_REVW GPR:$rs1)>; +def : Pat<(binop_allwusers<srl> (bswap GPR:$rs1), (i64 32)), + (TH_REVW GPR:$rs1)>; +def : Pat<(riscv_clzw GPR:$rs1), + (TH_FF0 (i64 (SLLI (i64 (XORI GPR:$rs1, -1)), 32)))>; } // Predicates = [HasVendorXTHeadBb, IsRV64] let Predicates = [HasVendorXTHeadBs] in { @@ -697,11 +697,13 @@ def uimm2_4 : Operand<XLenVT>, ImmLeaf<XLenVT, [{ }], uimm2_4_XFORM>; let Predicates = [HasVendorXTHeadMemPair, IsRV64] in { -def : Pat<(th_lwud i64:$rs1, uimm2_3:$uimm2_3), (TH_LWUD i64:$rs1, uimm2_3:$uimm2_3, 3)>; -def : Pat<(th_ldd i64:$rs1, uimm2_4:$uimm2_4), (TH_LDD i64:$rs1, uimm2_4:$uimm2_4, 4)>; +def : Pat<(th_lwud GPR:$rs1, (i64 uimm2_3:$uimm2_3)), + (TH_LWUD GPR:$rs1, uimm2_3:$uimm2_3, 3)>; +def : Pat<(th_ldd GPR:$rs1, (i64 uimm2_4:$uimm2_4)), + (TH_LDD GPR:$rs1, uimm2_4:$uimm2_4, 4)>; -def : Pat<(th_sdd i64:$rd1, i64:$rd2, i64:$rs1, uimm2_4:$uimm2_4), - (TH_SDD i64:$rd1, i64:$rd2, i64:$rs1, uimm2_4:$uimm2_4, 4)>; +def : Pat<(th_sdd (i64 GPR:$rd1), GPR:$rd2, GPR:$rs1, uimm2_4:$uimm2_4), + (TH_SDD GPR:$rd1, GPR:$rd2, GPR:$rs1, uimm2_4:$uimm2_4, 4)>; } let Predicates = [HasVendorXTHeadMemPair] in { diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZvk.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZvk.td index a250ac8..1efe616 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoZvk.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZvk.td @@ -1128,13 +1128,13 @@ let Predicates = [HasStdExtZvkned] in { let Predicates = [HasStdExtZvknha] in { defm : VPatBinaryV_VV_NoMask<"int_riscv_vsha2ch", "PseudoVSHA2CH", I32IntegerVectors>; - defm : VPatBinaryV_VV_NoMask<"int_riscv_vsha2cl", "PseudoVSHA2CH", I32IntegerVectors>; + defm : VPatBinaryV_VV_NoMask<"int_riscv_vsha2cl", "PseudoVSHA2CL", I32IntegerVectors>; defm : VPatBinaryV_VV_NoMask<"int_riscv_vsha2ms", "PseudoVSHA2MS", I32IntegerVectors, isSEWAware=true>; } // Predicates = [HasStdExtZvknha] let Predicates = [HasStdExtZvknhb] in { defm : 
VPatBinaryV_VV_NoMask<"int_riscv_vsha2ch", "PseudoVSHA2CH", I32I64IntegerVectors>;
-  defm : VPatBinaryV_VV_NoMask<"int_riscv_vsha2cl", "PseudoVSHA2CH", I32I64IntegerVectors>;
+  defm : VPatBinaryV_VV_NoMask<"int_riscv_vsha2cl", "PseudoVSHA2CL", I32I64IntegerVectors>;
   defm : VPatBinaryV_VV_NoMask<"int_riscv_vsha2ms", "PseudoVSHA2MS", I32I64IntegerVectors,
                                isSEWAware=true>;
 } // Predicates = [HasStdExtZvknhb]
diff --git a/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp b/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp
index 214536d..7e58b6f 100644
--- a/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp
@@ -576,6 +576,7 @@ bool RISCVRegisterInfo::eliminateFrameIndex(MachineBasicBlock::iterator II,
     int64_t Val = Offset.getFixed();
     int64_t Lo12 = SignExtend64<12>(Val);
     unsigned Opc = MI.getOpcode();
+
     if (Opc == RISCV::ADDI && !isInt<12>(Val)) {
       // We chose to emit the canonical immediate sequence rather than folding
       // the offset into the using add under the theory that doing so doesn't
@@ -588,6 +589,9 @@ bool RISCVRegisterInfo::eliminateFrameIndex(MachineBasicBlock::iterator II,
                (Lo12 & 0b11111) != 0) {
       // Prefetch instructions require the offset to be 32 byte aligned.
       MI.getOperand(FIOperandNum + 1).ChangeToImmediate(0);
+    } else if (Opc == RISCV::MIPS_PREFETCH && !isUInt<9>(Val)) {
+      // MIPS prefetch instructions require the offset to be encodable in 9
+      // bits.
+      MI.getOperand(FIOperandNum + 1).ChangeToImmediate(0);
     } else if ((Opc == RISCV::PseudoRV32ZdinxLD ||
                 Opc == RISCV::PseudoRV32ZdinxSD) &&
                Lo12 >= 2044) {
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index 4eef8de..0d5eb86 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -1280,8 +1280,13 @@ RISCVTTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
       } else {
         Ops = {RISCV::VFCVT_X_F_V};
       }
-      return std::max(SrcLT.first, LT.first) *
-             getRISCVInstructionCost(Ops, LT.second, CostKind);
+
+      // We need to use the source LMUL in the case of a narrowing op, and the
+      // destination LMUL otherwise.
+      if (SrcEltSz > DstEltSz)
+        return SrcLT.first *
+               getRISCVInstructionCost(Ops, SrcLT.second, CostKind);
+      return LT.first * getRISCVInstructionCost(Ops, LT.second, CostKind);
     }
     break;
   }
@@ -2622,18 +2627,17 @@ void RISCVTTIImpl::getUnrollingPreferences(
   if (L->getNumBlocks() > 4)
     return;
 
-  // Don't unroll vectorized loops, including the remainder loop
-  if (getBooleanLoopAttribute(L, "llvm.loop.isvectorized"))
-    return;
-
   // Scan the loop: don't unroll loops with calls as this could prevent
-  // inlining.
+  // inlining. Don't unroll auto-vectorized loops either, though do allow
+  // unrolling of the scalar remainder.
+  bool IsVectorized = getBooleanLoopAttribute(L, "llvm.loop.isvectorized");
   InstructionCost Cost = 0;
   for (auto *BB : L->getBlocks()) {
     for (auto &I : *BB) {
-      // Initial setting - Don't unroll loops containing vectorized
-      // instructions.
-      if (I.getType()->isVectorTy())
+      // Both auto-vectorized loops and the scalar remainder have the
+      // isvectorized attribute, so differentiate between them by the presence
+      // of vector instructions.
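The distinction drawn in the comment above reads naturally as a single predicate: a loop tagged `llvm.loop.isvectorized` is skipped only while it still contains vector instructions; the scalar remainder carries the same attribute but none. A sketch under that reading, reusing only calls visible in the hunk (header paths approximate; the real function also accumulates a size cost before enabling runtime unrolling):

```cpp
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/Instructions.h"
using namespace llvm;

static bool eligibleForRuntimeUnroll(const Loop *L) {
  // The vector body and the scalar remainder both carry this attribute...
  bool IsVectorized = getBooleanLoopAttribute(L, "llvm.loop.isvectorized");
  for (auto *BB : L->getBlocks())
    for (auto &I : *BB) {
      // ...but only the body still contains vector instructions.
      if (IsVectorized && I.getType()->isVectorTy())
        return false;
      // Calls could otherwise be inlined; don't lock them into a loop.
      if (isa<CallInst>(I) || isa<InvokeInst>(I))
        return false;
    }
  return true;
}
```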
+ if (IsVectorized && I.getType()->isVectorTy()) return; if (isa<CallInst>(I) || isa<InvokeInst>(I)) { diff --git a/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp b/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp index c1cc19b..050de3d 100644 --- a/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp +++ b/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp @@ -646,8 +646,7 @@ bool RISCVVectorPeephole::foldVMV_V_V(MachineInstr &MI) { if (!Src || Src->hasUnmodeledSideEffects() || Src->getParent() != MI.getParent() || !RISCVII::isFirstDefTiedToFirstUse(Src->getDesc()) || - !RISCVII::hasVLOp(Src->getDesc().TSFlags) || - !RISCVII::hasVecPolicyOp(Src->getDesc().TSFlags)) + !RISCVII::hasVLOp(Src->getDesc().TSFlags)) return false; // Src's dest needs to have the same EEW as MI's input. @@ -681,12 +680,14 @@ bool RISCVVectorPeephole::foldVMV_V_V(MachineInstr &MI) { *Src->getParent()->getParent())); } - // If MI was tail agnostic and the VL didn't increase, preserve it. - int64_t Policy = RISCVVType::TAIL_UNDISTURBED_MASK_UNDISTURBED; - if ((MI.getOperand(5).getImm() & RISCVVType::TAIL_AGNOSTIC) && - RISCV::isVLKnownLE(MI.getOperand(3), SrcVL)) - Policy |= RISCVVType::TAIL_AGNOSTIC; - Src->getOperand(RISCVII::getVecPolicyOpNum(Src->getDesc())).setImm(Policy); + if (RISCVII::hasVecPolicyOp(Src->getDesc().TSFlags)) { + // If MI was tail agnostic and the VL didn't increase, preserve it. + int64_t Policy = RISCVVType::TAIL_UNDISTURBED_MASK_UNDISTURBED; + if ((MI.getOperand(5).getImm() & RISCVVType::TAIL_AGNOSTIC) && + RISCV::isVLKnownLE(MI.getOperand(3), SrcVL)) + Policy |= RISCVVType::TAIL_AGNOSTIC; + Src->getOperand(RISCVII::getVecPolicyOpNum(Src->getDesc())).setImm(Policy); + } MRI->constrainRegClass(Src->getOperand(0).getReg(), MRI->getRegClass(MI.getOperand(0).getReg())); diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVAsmBackend.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVAsmBackend.cpp index ef84d43..5710cf2 100644 --- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVAsmBackend.cpp +++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVAsmBackend.cpp @@ -21,8 +21,7 @@ public: SPIRVAsmBackend(llvm::endianness Endian) : MCAsmBackend(Endian) {} void applyFixup(const MCFragment &, const MCFixup &, const MCValue &Target, - MutableArrayRef<char> Data, uint64_t Value, - bool IsResolved) override {} + uint8_t *Data, uint64_t Value, bool IsResolved) override {} std::unique_ptr<MCObjectTargetWriter> createObjectTargetWriter() const override { diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp index 6ec7544..25cdf72 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp @@ -148,6 +148,7 @@ struct ConvertBuiltin { bool IsSaturated; bool IsRounded; bool IsBfloat16; + bool IsTF32; FPRoundingMode::FPRoundingMode RoundingMode; }; @@ -230,6 +231,7 @@ std::string lookupBuiltinNameHelper(StringRef DemangledCall, // - "__spirv_SubgroupImageMediaBlockReadINTEL" // - "__spirv_SubgroupImageMediaBlockWriteINTEL" // - "__spirv_Convert" + // - "__spirv_Round" // - "__spirv_UConvert" // - "__spirv_SConvert" // - "__spirv_FConvert" @@ -242,7 +244,7 @@ std::string lookupBuiltinNameHelper(StringRef DemangledCall, "SDotKHR|SUDotKHR|SDotAccSatKHR|UDotAccSatKHR|SUDotAccSatKHR|" "ReadClockKHR|SubgroupBlockReadINTEL|SubgroupImageBlockReadINTEL|" "SubgroupImageMediaBlockReadINTEL|SubgroupImageMediaBlockWriteINTEL|" - "Convert|" + "Convert|Round|" "UConvert|SConvert|FConvert|SatConvert)[^_]*)(_R[^_]*_?(\\w+)?.*)?"); std::smatch Match; if (std::regex_match(BuiltinName, 
Match, SpvWithR) && Match.size() > 1) { @@ -697,7 +699,8 @@ static bool buildAtomicStoreInst(const SPIRV::IncomingCall *Call, MachineIRBuilder &MIRBuilder, SPIRVGlobalRegistry *GR) { if (Call->isSpirvOp()) - return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call, Register(0)); + return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call, + Register(0)); Register ScopeRegister = buildConstantIntReg32(SPIRV::Scope::Device, MIRBuilder, GR); @@ -2677,8 +2680,20 @@ static bool generateConvertInst(const StringRef DemangledCall, } } else if (GR->isScalarOrVectorOfType(Call->ReturnRegister, SPIRV::OpTypeFloat)) { - // Float -> Float - Opcode = SPIRV::OpFConvert; + if (Builtin->IsTF32) { + const auto *ST = static_cast<const SPIRVSubtarget *>( + &MIRBuilder.getMF().getSubtarget()); + if (!ST->canUseExtension( + SPIRV::Extension::SPV_INTEL_tensor_float32_conversion)) + NeedExtMsg = "SPV_INTEL_tensor_float32_conversion"; + IsRightComponentsNumber = + GR->getScalarOrVectorComponentCount(Call->Arguments[0]) == + GR->getScalarOrVectorComponentCount(Call->ReturnRegister); + Opcode = SPIRV::OpRoundFToTF32INTEL; + } else { + // Float -> Float + Opcode = SPIRV::OpFConvert; + } } } diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td index ea78dcd..d08560b 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td @@ -1461,6 +1461,8 @@ class ConvertBuiltin<string name, InstructionSet set> { bit IsRounded = !not(!eq(!find(name, "_rt"), -1)); bit IsBfloat16 = !or(!not(!eq(!find(name, "BF16"), -1)), !not(!eq(!find(name, "bfloat16"), -1))); + bit IsTF32 = !or(!not(!eq(!find(name, "TF32"), -1)), + !not(!eq(!find(name, "tensor_float32"), -1))); FPRoundingMode RoundingMode = !cond(!not(!eq(!find(name, "_rte"), -1)) : RTE, !not(!eq(!find(name, "_rtz"), -1)) : RTZ, !not(!eq(!find(name, "_rtp"), -1)) : RTP, @@ -1472,7 +1474,7 @@ class ConvertBuiltin<string name, InstructionSet set> { def ConvertBuiltins : GenericTable { let FilterClass = "ConvertBuiltin"; let Fields = ["Name", "Set", "IsDestinationSigned", "IsSaturated", - "IsRounded", "IsBfloat16", "RoundingMode"]; + "IsRounded", "IsBfloat16", "IsTF32", "RoundingMode"]; string TypeOf_Set = "InstructionSet"; string TypeOf_RoundingMode = "FPRoundingMode"; } @@ -1556,6 +1558,25 @@ foreach conv = ["FToBF16INTEL", "BF16ToFINTEL"] in { def : ConvertBuiltin<!strconcat("__spirv_Convert", conv), OpenCL_std>; } +// cl_intel_tensor_float32_conversions / SPV_INTEL_tensor_float32_conversion +// Multiclass used to define at the same time both a demangled builtin record +// and a corresponding convert builtin record. +multiclass DemangledTF32RoundBuiltin<string name1, string name2> { + // Create records for scalar and vector conversions. + foreach i = ["", "2", "3", "4", "8", "16"] in { + def : DemangledBuiltin<!strconcat("intel_round_", name1, i, name2, i), OpenCL_std, Convert, 1, 1>; + def : ConvertBuiltin<!strconcat("intel_round_", name1, i, name2, i), OpenCL_std>; + } +} + +defm : DemangledTF32RoundBuiltin<"tensor_float32", "_as_float">; +defm : DemangledTF32RoundBuiltin<"as_tensor_float32", "_float">; + +foreach conv = ["FToTF32INTEL"] in { + def : DemangledBuiltin<!strconcat("__spirv_Round", conv), OpenCL_std, Convert, 1, 1>; + def : ConvertBuiltin<!strconcat("__spirv_Round", conv), OpenCL_std>; +} + //===----------------------------------------------------------------------===// // Class defining a vector data load/store builtin record used for lowering // into OpExtInst instruction. 
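Numerically, the conversion the TF32 records above name (`OpRoundFToTF32INTEL`) narrows a binary32 mantissa from 23 to 10 explicit bits while keeping the 8-bit exponent. A standalone sketch of one plausible rounding — round-to-nearest-even on the dropped bits; NaN/Inf handling is omitted and the extension's exact semantics may differ, so this is illustration rather than the lowering:

```cpp
#include <cstdint>
#include <cstring>

float roundToTF32(float F) {
  uint32_t Bits;
  std::memcpy(&Bits, &F, sizeof(Bits));
  constexpr unsigned Dropped = 23 - 10;    // mantissa bits removed
  uint32_t Lsb = (Bits >> Dropped) & 1;    // ties go to even
  Bits += 0xFFF + Lsb;                     // 0xFFF = 2^(Dropped-1) - 1
  Bits &= ~uint32_t((1u << Dropped) - 1);  // clear the dropped bits
  float R;
  std::memcpy(&R, &Bits, sizeof(R));
  return R;
}
```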
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index 2726203..d9265f4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -102,7 +102,9 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
          SPIRV::Extension::Extension::SPV_INTEL_2d_block_io},
         {"SPV_INTEL_int4", SPIRV::Extension::Extension::SPV_INTEL_int4},
         {"SPV_KHR_float_controls2",
-         SPIRV::Extension::Extension::SPV_KHR_float_controls2}};
+         SPIRV::Extension::Extension::SPV_KHR_float_controls2},
+        {"SPV_INTEL_tensor_float32_conversion",
+         SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion}};
 
 bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
                                   StringRef ArgValue,
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 3c631ce..947b574 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -194,6 +194,42 @@ class SPIRVEmitIntrinsics
 
   void useRoundingMode(ConstrainedFPIntrinsic *FPI, IRBuilder<> &B);
 
+  // Tries to walk the type accessed by the given GEP instruction.
+  // For each nested type access, one of two callbacks is called:
+  // - OnLiteralIndexing when the index is a known constant value.
+  //     Parameters:
+  //      PointedType: the pointed type resulting from this indexing.
+  //      Index: the index of the element in the parent type.
+  //        If the parent type is an array, this is the index in the array.
+  //        If the parent type is a struct, this is the field index.
+  // - OnDynamicIndexing when the index is a non-constant value.
+  //   This callback is only called when indexing into an array.
+  //     Parameters:
+  //      ElementType: the type of the elements stored in the parent array.
+  //      Offset: the Value* containing the byte offset into the array.
+  // Returns true if an error occurred during the walk, false otherwise.
+  bool walkLogicalAccessChain(
+      GetElementPtrInst &GEP,
+      const std::function<void(Type *PointedType, uint64_t Index)>
+          &OnLiteralIndexing,
+      const std::function<void(Type *ElementType, Value *Offset)>
+          &OnDynamicIndexing);
+
+  // Returns the type accessed using the given GEP instruction by relying
+  // on the GEP type.
+  // FIXME: GEP types are not supposed to be used to retrieve the pointed
+  // type. This must be fixed.
+  Type *getGEPType(GetElementPtrInst *GEP);
+
+  // Returns the type accessed using the given GEP instruction by walking
+  // the source type using the GEP indices.
+  // FIXME: without help from the frontend, this method cannot reliably
+  // retrieve the stored type, nor can it robustly determine the depth of
+  // the type we are accessing.
+  Type *getGEPTypeLogical(GetElementPtrInst *GEP);
+
+  Instruction *buildLogicalAccessChainFromGEP(GetElementPtrInst &GEP);
+
 public:
   static char ID;
   SPIRVEmitIntrinsics(SPIRVTargetMachine *TM = nullptr)
@@ -246,6 +282,17 @@ bool expectIgnoredInIRTranslation(const Instruction *I) {
   }
 }
 
+// Returns the source pointer from `I`, ignoring intermediate ptrcast.
+Value *getPointerRoot(Value *I) { + if (auto *II = dyn_cast<IntrinsicInst>(I)) { + if (II->getIntrinsicID() == Intrinsic::spv_ptrcast) { + Value *V = II->getArgOperand(0); + return getPointerRoot(V); + } + } + return I; +} + } // namespace char SPIRVEmitIntrinsics::ID = 0; @@ -555,7 +602,112 @@ void SPIRVEmitIntrinsics::maybeAssignPtrType(Type *&Ty, Value *Op, Type *RefTy, Ty = RefTy; } -Type *getGEPType(GetElementPtrInst *Ref) { +bool SPIRVEmitIntrinsics::walkLogicalAccessChain( + GetElementPtrInst &GEP, + const std::function<void(Type *, uint64_t)> &OnLiteralIndexing, + const std::function<void(Type *, Value *)> &OnDynamicIndexing) { + // We only rewrite i8* GEPs. Others should be left as-is. + // A valid i8* GEP must always have a single index. + assert(GEP.getSourceElementType() == + IntegerType::getInt8Ty(CurrF->getContext())); + assert(GEP.getNumIndices() == 1); + + auto &DL = CurrF->getDataLayout(); + Value *Src = getPointerRoot(GEP.getPointerOperand()); + Type *CurType = deduceElementType(Src, true); + + Value *Operand = *GEP.idx_begin(); + ConstantInt *CI = dyn_cast<ConstantInt>(Operand); + if (!CI) { + ArrayType *AT = dyn_cast<ArrayType>(CurType); + // Operand is not constant. Either we have an array and accept it, or we + // give up. + if (AT) + OnDynamicIndexing(AT->getElementType(), Operand); + return AT == nullptr; + } + + assert(CI); + uint64_t Offset = CI->getZExtValue(); + + do { + if (ArrayType *AT = dyn_cast<ArrayType>(CurType)) { + uint32_t EltTypeSize = DL.getTypeSizeInBits(AT->getElementType()) / 8; + assert(Offset < AT->getNumElements() * EltTypeSize); + uint64_t Index = Offset / EltTypeSize; + Offset = Offset - (Index * EltTypeSize); + CurType = AT->getElementType(); + OnLiteralIndexing(CurType, Index); + } else if (StructType *ST = dyn_cast<StructType>(CurType)) { + uint32_t StructSize = DL.getTypeSizeInBits(ST) / 8; + assert(Offset < StructSize); + (void)StructSize; + const auto &STL = DL.getStructLayout(ST); + unsigned Element = STL->getElementContainingOffset(Offset); + Offset -= STL->getElementOffset(Element); + CurType = ST->getElementType(Element); + OnLiteralIndexing(CurType, Element); + } else { + // Vector type indexing should not use GEP, so if we have an index left, + // something is wrong. Give up.
+      return true; + } + } while (Offset > 0); + + return false; +} + +Instruction * +SPIRVEmitIntrinsics::buildLogicalAccessChainFromGEP(GetElementPtrInst &GEP) { + auto &DL = CurrF->getDataLayout(); + IRBuilder<> B(GEP.getParent()); + B.SetInsertPoint(&GEP); + + std::vector<Value *> Indices; + Indices.push_back(ConstantInt::get( + IntegerType::getInt32Ty(CurrF->getContext()), 0, /* Signed= */ false)); + walkLogicalAccessChain( + GEP, + [&Indices, &B](Type *EltType, uint64_t Index) { + Indices.push_back( + ConstantInt::get(B.getInt64Ty(), Index, /* Signed= */ false)); + }, + [&Indices, &B, &DL](Type *EltType, Value *Offset) { + uint32_t EltTypeSize = DL.getTypeSizeInBits(EltType) / 8; + Value *Index = B.CreateUDiv( + Offset, ConstantInt::get(Offset->getType(), EltTypeSize, + /* Signed= */ false)); + Indices.push_back(Index); + }); + + SmallVector<Type *, 2> Types = {GEP.getType(), GEP.getOperand(0)->getType()}; + SmallVector<Value *, 4> Args; + Args.push_back(B.getInt1(GEP.isInBounds())); + Args.push_back(GEP.getOperand(0)); + llvm::append_range(Args, Indices); + auto *NewI = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args}); + replaceAllUsesWithAndErase(B, &GEP, NewI); + return NewI; +} + +Type *SPIRVEmitIntrinsics::getGEPTypeLogical(GetElementPtrInst *GEP) { + Type *CurType = GEP->getResultElementType(); + + bool Interrupted = walkLogicalAccessChain( + *GEP, [&CurType](Type *EltType, uint64_t Index) { CurType = EltType; }, + [&CurType](Type *EltType, Value *Index) { CurType = EltType; }); + + return Interrupted ? GEP->getResultElementType() : CurType; +} + +Type *SPIRVEmitIntrinsics::getGEPType(GetElementPtrInst *Ref) { + if (Ref->getSourceElementType() == + IntegerType::getInt8Ty(CurrF->getContext()) && + TM->getSubtargetImpl()->isLogicalSPIRV()) { + return getGEPTypeLogical(Ref); + } + Type *Ty = nullptr; // TODO: not sure if GetElementPtrInst::getTypeAtIndex() does anything // useful here @@ -1395,6 +1547,13 @@ Instruction *SPIRVEmitIntrinsics::visitSwitchInst(SwitchInst &I) { } Instruction *SPIRVEmitIntrinsics::visitGetElementPtrInst(GetElementPtrInst &I) { + if (I.getSourceElementType() == IntegerType::getInt8Ty(CurrF->getContext()) && + TM->getSubtargetImpl()->isLogicalSPIRV()) { + Instruction *Result = buildLogicalAccessChainFromGEP(I); + if (Result) + return Result; + } + IRBuilder<> B(I.getParent()); B.SetInsertPoint(&I); SmallVector<Type *, 2> Types = {I.getType(), I.getOperand(0)->getType()}; @@ -1588,7 +1747,24 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I, } if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(I)) { Value *Pointer = GEPI->getPointerOperand(); - Type *OpTy = GEPI->getSourceElementType(); + Type *OpTy = nullptr; + + // Knowing the accessed type is mandatory for logical SPIR-V. Sadly, + // the GEP source element type should not be used for this purpose, and + // the alternative type-scavenging method does not work reliably. + // Physical SPIR-V can work around this, but logical SPIR-V cannot, hence + // we still try to rely on the broken type scavenging for logical. + bool IsRewrittenGEP = + GEPI->getSourceElementType() == IntegerType::getInt8Ty(I->getContext()); + if (IsRewrittenGEP && TM->getSubtargetImpl()->isLogicalSPIRV()) { + Value *Src = getPointerRoot(Pointer); + OpTy = GR->findDeducedElementType(Src); + } + + // In all cases, fall back to the GEP type if type scavenging failed.
+  if (!OpTy) + OpTy = GEPI->getSourceElementType(); + replacePointerOperandWithPtrCast(I, Pointer, OpTy, 0, B); if (isNestedPointer(OpTy)) insertTodoType(Pointer); diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td index 049ba02..f0b938d 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td +++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td @@ -445,6 +445,9 @@ def OpCrossWorkgroupCastToPtrINTEL : UnOp<"OpCrossWorkgroupCastToPtrINTEL", 5938 def OpConvertFToBF16INTEL : UnOp<"OpConvertFToBF16INTEL", 6116>; def OpConvertBF16ToFINTEL : UnOp<"OpConvertBF16ToFINTEL", 6117>; +// SPV_INTEL_tensor_float32_conversion +def OpRoundFToTF32INTEL : UnOp<"OpRoundFToTF32INTEL", 6426>; + // 3.42.12 Composite Instructions def OpVectorExtractDynamic: Op<77, (outs ID:$res), (ins TYPE:$type, vID:$vec, ID:$idx), diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp index 5cda6a0..7505507 100644 --- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp @@ -74,17 +74,20 @@ class SPIRVLegalizePointerCast : public FunctionPass { // Returns the loaded value. Value *loadVectorFromVector(IRBuilder<> &B, FixedVectorType *SourceType, FixedVectorType *TargetType, Value *Source) { - // We expect the codegen to avoid doing implicit bitcast from a load. - assert(TargetType->getElementType() == SourceType->getElementType()); - assert(TargetType->getNumElements() < SourceType->getNumElements()); - + assert(TargetType->getNumElements() <= SourceType->getNumElements()); LoadInst *NewLoad = B.CreateLoad(SourceType, Source); buildAssignType(B, SourceType, NewLoad); + Value *AssignValue = NewLoad; + if (TargetType->getElementType() != SourceType->getElementType()) { + AssignValue = B.CreateIntrinsic(Intrinsic::spv_bitcast, + {TargetType, SourceType}, {NewLoad}); + buildAssignType(B, TargetType, AssignValue); + } SmallVector<int> Mask(/* Size= */ TargetType->getNumElements()); for (unsigned I = 0; I < TargetType->getNumElements(); ++I) Mask[I] = I; - Value *Output = B.CreateShuffleVector(NewLoad, NewLoad, Mask); + Value *Output = B.CreateShuffleVector(AssignValue, AssignValue, Mask); buildAssignType(B, TargetType, Output); return Output; } @@ -135,8 +138,9 @@ class SPIRVLegalizePointerCast : public FunctionPass { Output = loadFirstValueFromAggregate(B, SVT->getElementType(), OriginalOperand, LI); } - // Destination is a smaller vector than source. + // Destination is a smaller vector than source or a different vector type. // - float3 v3 = vector4; + // - float4 v4 = int4; else if (SVT && DVT) Output = loadVectorFromVector(B, SVT, DVT, OriginalOperand); // Destination is the scalar type stored at the start of an aggregate.
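The logical-SPIR-V GEP handling above peels a flat i8 byte offset into nested type indices, one type level per loop iteration, and buildLogicalAccessChainFromGEP prepends a leading 0 index before emitting spv_gep. A standalone C++ sketch of the same arithmetic for an assumed source type [4 x { i32, i32 }] (sizes hard-coded here for illustration; the pass queries the DataLayout and StructLayout instead):

#include <cstdint>
#include <cstdio>

int main() {
  // Assumed type [4 x { i32, i32 }]: byte offset 20 should peel into
  // array index 2 followed by struct field index 1.
  uint64_t Offset = 20;
  const uint64_t EltSize = 8;           // size of { i32, i32 } in bytes
  uint64_t ArrayIdx = Offset / EltSize; // 20 / 8 == 2
  Offset -= ArrayIdx * EltSize;         // 20 - 16 == 4
  // Struct layout: field 0 at byte 0, field 1 at byte 4.
  unsigned FieldIdx = Offset >= 4 ? 1 : 0;
  Offset -= FieldIdx ? 4u : 0u;         // reaches 0, so the walk stops
  std::printf("spv_gep indices: 0, %llu, %u\n",
              (unsigned long long)ArrayIdx, FieldIdx);
  return 0;
}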
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp index 721f64a..1995e0f 100644 --- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp @@ -335,6 +335,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { getActionDefinitionsBuilder({G_SMULH, G_UMULH}).alwaysLegal(); } + getActionDefinitionsBuilder(G_IS_FPCLASS).custom(); + getLegacyLegalizerInfo().computeTables(); verify(*ST.getInstrInfo()); } @@ -355,9 +357,14 @@ static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpvType, bool SPIRVLegalizerInfo::legalizeCustom( LegalizerHelper &Helper, MachineInstr &MI, LostDebugLocObserver &LocObserver) const { - auto Opc = MI.getOpcode(); MachineRegisterInfo &MRI = MI.getMF()->getRegInfo(); - if (Opc == TargetOpcode::G_ICMP) { + switch (MI.getOpcode()) { + default: + // TODO: implement legalization for other opcodes. + return true; + case TargetOpcode::G_IS_FPCLASS: + return legalizeIsFPClass(Helper, MI, LocObserver); + case TargetOpcode::G_ICMP: { assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg())); auto &Op0 = MI.getOperand(2); auto &Op1 = MI.getOperand(3); @@ -378,6 +385,238 @@ bool SPIRVLegalizerInfo::legalizeCustom( } return true; } - // TODO: implement legalization for other opcodes. + } +} + +// Note this code was copied from LegalizerHelper::lowerISFPCLASS and adjusted +// to ensure that all instructions created during the lowering have SPIR-V types +// assigned to them. +bool SPIRVLegalizerInfo::legalizeIsFPClass( + LegalizerHelper &Helper, MachineInstr &MI, + LostDebugLocObserver &LocObserver) const { + auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs(); + FPClassTest Mask = static_cast<FPClassTest>(MI.getOperand(2).getImm()); + + auto &MIRBuilder = Helper.MIRBuilder; + auto &MF = MIRBuilder.getMF(); + MachineRegisterInfo &MRI = MF.getRegInfo(); + + Type *LLVMDstTy = + IntegerType::get(MIRBuilder.getContext(), DstTy.getScalarSizeInBits()); + if (DstTy.isVector()) + LLVMDstTy = VectorType::get(LLVMDstTy, DstTy.getElementCount()); + SPIRVType *SPIRVDstTy = GR->getOrCreateSPIRVType( + LLVMDstTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, + /*EmitIR*/ true); + + unsigned BitSize = SrcTy.getScalarSizeInBits(); + const fltSemantics &Semantics = getFltSemanticForLLT(SrcTy.getScalarType()); + + LLT IntTy = LLT::scalar(BitSize); + Type *LLVMIntTy = IntegerType::get(MIRBuilder.getContext(), BitSize); + if (SrcTy.isVector()) { + IntTy = LLT::vector(SrcTy.getElementCount(), IntTy); + LLVMIntTy = VectorType::get(LLVMIntTy, SrcTy.getElementCount()); + } + SPIRVType *SPIRVIntTy = GR->getOrCreateSPIRVType( + LLVMIntTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, + /*EmitIR*/ true); + + // Clang doesn't support capture of structured bindings: + LLT DstTyCopy = DstTy; + const auto assignSPIRVTy = [&](MachineInstrBuilder &&MI) { + // Assign this MI's (assumed only) destination to one of the two types we + // expect: either the G_IS_FPCLASS's destination type, or the integer type + // bitcast from the source type. + LLT MITy = MRI.getType(MI.getReg(0)); + assert((MITy == IntTy || MITy == DstTyCopy) && + "Unexpected LLT type while lowering G_IS_FPCLASS"); + auto *SPVTy = MITy == IntTy ? 
SPIRVIntTy : SPIRVDstTy; + GR->assignSPIRVTypeToVReg(SPVTy, MI.getReg(0), MF); + return MI; + }; + + // Helper to build and assign a constant in one go + const auto buildSPIRVConstant = [&](LLT Ty, auto &&C) -> MachineInstrBuilder { + if (!Ty.isFixedVector()) + return assignSPIRVTy(MIRBuilder.buildConstant(Ty, C)); + auto ScalarC = MIRBuilder.buildConstant(Ty.getScalarType(), C); + assert((Ty == IntTy || Ty == DstTyCopy) && + "Unexpected LLT type while lowering constant for G_IS_FPCLASS"); + SPIRVType *VecEltTy = GR->getOrCreateSPIRVType( + (Ty == IntTy ? LLVMIntTy : LLVMDstTy)->getScalarType(), MIRBuilder, + SPIRV::AccessQualifier::ReadWrite, + /*EmitIR*/ true); + GR->assignSPIRVTypeToVReg(VecEltTy, ScalarC.getReg(0), MF); + return assignSPIRVTy(MIRBuilder.buildSplatBuildVector(Ty, ScalarC)); + }; + + if (Mask == fcNone) { + MIRBuilder.buildCopy(DstReg, buildSPIRVConstant(DstTy, 0)); + MI.eraseFromParent(); + return true; + } + if (Mask == fcAllFlags) { + MIRBuilder.buildCopy(DstReg, buildSPIRVConstant(DstTy, 1)); + MI.eraseFromParent(); + return true; + } + + // Note that rather than creating a COPY here (between a floating-point and + // integer type of the same size) we create a SPIR-V bitcast immediately. We + // can't create a G_BITCAST because the LLTs are the same, and we can't seem + // to correctly lower COPYs to SPIR-V bitcasts at this moment. + Register ResVReg = MRI.createGenericVirtualRegister(IntTy); + MRI.setRegClass(ResVReg, GR->getRegClass(SPIRVIntTy)); + GR->assignSPIRVTypeToVReg(SPIRVIntTy, ResVReg, Helper.MIRBuilder.getMF()); + auto AsInt = MIRBuilder.buildInstr(SPIRV::OpBitcast) + .addDef(ResVReg) + .addUse(GR->getSPIRVTypeID(SPIRVIntTy)) + .addUse(SrcReg); + AsInt = assignSPIRVTy(std::move(AsInt)); + + // Various masks. + APInt SignBit = APInt::getSignMask(BitSize); + APInt ValueMask = APInt::getSignedMaxValue(BitSize); // All bits but sign. + APInt Inf = APFloat::getInf(Semantics).bitcastToAPInt(); // Exp and int bit. + APInt ExpMask = Inf; + APInt AllOneMantissa = APFloat::getLargest(Semantics).bitcastToAPInt() & ~Inf; + APInt QNaNBitMask = + APInt::getOneBitSet(BitSize, AllOneMantissa.getActiveBits() - 1); + APInt InversionMask = APInt::getAllOnes(DstTy.getScalarSizeInBits()); + + auto SignBitC = buildSPIRVConstant(IntTy, SignBit); + auto ValueMaskC = buildSPIRVConstant(IntTy, ValueMask); + auto InfC = buildSPIRVConstant(IntTy, Inf); + auto ExpMaskC = buildSPIRVConstant(IntTy, ExpMask); + auto ZeroC = buildSPIRVConstant(IntTy, 0); + + auto Abs = assignSPIRVTy(MIRBuilder.buildAnd(IntTy, AsInt, ValueMaskC)); + auto Sign = assignSPIRVTy( + MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_NE, DstTy, AsInt, Abs)); + + auto Res = buildSPIRVConstant(DstTy, 0); + + const auto appendToRes = [&](MachineInstrBuilder &&ToAppend) { + Res = assignSPIRVTy( + MIRBuilder.buildOr(DstTyCopy, Res, assignSPIRVTy(std::move(ToAppend)))); + }; + + // Tests that involve more than one class should be processed first. 
+ if ((Mask & fcFinite) == fcFinite) { + // finite(V) ==> abs(V) u< exp_mask + appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, Abs, + ExpMaskC)); + Mask &= ~fcFinite; + } else if ((Mask & fcFinite) == fcPosFinite) { + // finite(V) && V > 0 ==> V u< exp_mask + appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, AsInt, + ExpMaskC)); + Mask &= ~fcPosFinite; + } else if ((Mask & fcFinite) == fcNegFinite) { + // finite(V) && V < 0 ==> abs(V) u< exp_mask && signbit == 1 + auto Cmp = assignSPIRVTy(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, + DstTy, Abs, ExpMaskC)); + appendToRes(MIRBuilder.buildAnd(DstTy, Cmp, Sign)); + Mask &= ~fcNegFinite; + } + + if (FPClassTest PartialCheck = Mask & (fcZero | fcSubnormal)) { + // fcZero | fcSubnormal => test all exponent bits are 0 + // TODO: Handle sign bit specific cases + // TODO: Handle inverted case + if (PartialCheck == (fcZero | fcSubnormal)) { + auto ExpBits = assignSPIRVTy(MIRBuilder.buildAnd(IntTy, AsInt, ExpMaskC)); + appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy, + ExpBits, ZeroC)); + Mask &= ~PartialCheck; + } + } + + // Check for individual classes. + if (FPClassTest PartialCheck = Mask & fcZero) { + if (PartialCheck == fcPosZero) + appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy, + AsInt, ZeroC)); + else if (PartialCheck == fcZero) + appendToRes( + MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy, Abs, ZeroC)); + else // fcNegZero + appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy, + AsInt, SignBitC)); + } + + if (FPClassTest PartialCheck = Mask & fcSubnormal) { + // issubnormal(V) ==> unsigned(abs(V) - 1) u< (all mantissa bits set) + // issubnormal(V) && V>0 ==> unsigned(V - 1) u< (all mantissa bits set) + auto V = (PartialCheck == fcPosSubnormal) ? 
AsInt : Abs; + auto OneC = buildSPIRVConstant(IntTy, 1); + auto VMinusOne = MIRBuilder.buildSub(IntTy, V, OneC); + auto SubnormalRes = assignSPIRVTy( + MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, VMinusOne, + buildSPIRVConstant(IntTy, AllOneMantissa))); + if (PartialCheck == fcNegSubnormal) + SubnormalRes = MIRBuilder.buildAnd(DstTy, SubnormalRes, Sign); + appendToRes(std::move(SubnormalRes)); + } + + if (FPClassTest PartialCheck = Mask & fcInf) { + if (PartialCheck == fcPosInf) + appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy, + AsInt, InfC)); + else if (PartialCheck == fcInf) + appendToRes( + MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy, Abs, InfC)); + else { // fcNegInf + APInt NegInf = APFloat::getInf(Semantics, true).bitcastToAPInt(); + auto NegInfC = buildSPIRVConstant(IntTy, NegInf); + appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy, + AsInt, NegInfC)); + } + } + + if (FPClassTest PartialCheck = Mask & fcNan) { + auto InfWithQnanBitC = buildSPIRVConstant(IntTy, Inf | QNaNBitMask); + if (PartialCheck == fcNan) { + // isnan(V) ==> abs(V) u> int(inf) + appendToRes( + MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGT, DstTy, Abs, InfC)); + } else if (PartialCheck == fcQNan) { + // isquiet(V) ==> abs(V) u>= (unsigned(Inf) | quiet_bit) + appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGE, DstTy, Abs, + InfWithQnanBitC)); + } else { // fcSNan + // issignaling(V) ==> abs(V) u> unsigned(Inf) && + // abs(V) u< (unsigned(Inf) | quiet_bit) + auto IsNan = assignSPIRVTy( + MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGT, DstTy, Abs, InfC)); + auto IsNotQnan = assignSPIRVTy(MIRBuilder.buildICmp( + CmpInst::Predicate::ICMP_ULT, DstTy, Abs, InfWithQnanBitC)); + appendToRes(MIRBuilder.buildAnd(DstTy, IsNan, IsNotQnan)); + } + } + + if (FPClassTest PartialCheck = Mask & fcNormal) { + // isnormal(V) ==> (0 u< exp u< max_exp) ==> (unsigned(exp-1) u< + // (max_exp-1)) + APInt ExpLSB = ExpMask & ~(ExpMask.shl(1)); + auto ExpMinusOne = assignSPIRVTy( + MIRBuilder.buildSub(IntTy, Abs, buildSPIRVConstant(IntTy, ExpLSB))); + APInt MaxExpMinusOne = ExpMask - ExpLSB; + auto NormalRes = assignSPIRVTy( + MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, ExpMinusOne, + buildSPIRVConstant(IntTy, MaxExpMinusOne))); + if (PartialCheck == fcNegNormal) + NormalRes = MIRBuilder.buildAnd(DstTy, NormalRes, Sign); + else if (PartialCheck == fcPosNormal) { + auto PosSign = assignSPIRVTy(MIRBuilder.buildXor( + DstTy, Sign, buildSPIRVConstant(DstTy, InversionMask))); + NormalRes = MIRBuilder.buildAnd(DstTy, NormalRes, PosSign); + } + appendToRes(std::move(NormalRes)); + } + + MIRBuilder.buildCopy(DstReg, Res); + MI.eraseFromParent(); return true; } diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h index 6335f21..eeefa42 100644 --- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h @@ -30,6 +30,10 @@ public: bool legalizeCustom(LegalizerHelper &Helper, MachineInstr &MI, LostDebugLocObserver &LocObserver) const override; SPIRVLegalizerInfo(const SPIRVSubtarget &ST); + +private: + bool legalizeIsFPClass(LegalizerHelper &Helper, MachineInstr &MI, + LostDebugLocObserver &LocObserver) const; }; } // namespace llvm #endif // LLVM_LIB_TARGET_SPIRV_SPIRVMACHINELEGALIZER_H diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp index ad976e5..0cd9d78 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp 
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -1564,6 +1564,13 @@ void addInstrRequirements(const MachineInstr &MI, Reqs.addCapability(SPIRV::Capability::BFloat16ConversionINTEL); } break; + case SPIRV::OpRoundFToTF32INTEL: + if (ST.canUseExtension( + SPIRV::Extension::SPV_INTEL_tensor_float32_conversion)) { + Reqs.addExtension(SPIRV::Extension::SPV_INTEL_tensor_float32_conversion); + Reqs.addCapability(SPIRV::Capability::TensorFloat32RoundingINTEL); + } + break; case SPIRV::OpVariableLengthArrayINTEL: case SPIRV::OpSaveMemoryINTEL: case SPIRV::OpRestoreMemoryINTEL: diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td index 548e9b7..614e83a 100644 --- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td +++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td @@ -320,6 +320,7 @@ defm SPV_INTEL_subgroup_matrix_multiply_accumulate : ExtensionOperand<121>; defm SPV_INTEL_2d_block_io : ExtensionOperand<122>; defm SPV_INTEL_int4 : ExtensionOperand<123>; defm SPV_KHR_float_controls2 : ExtensionOperand<124>; +defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125>; //===----------------------------------------------------------------------===// // Multiclass used to define Capabilities enum values and at the same time @@ -529,6 +530,7 @@ defm Subgroup2DBlockTransformINTEL : CapabilityOperand<6229, 0, 0, [SPV_INTEL_2d defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d_block_io], [Subgroup2DBlockIOINTEL]>; defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>; defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>; +defm TensorFloat32RoundingINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>; //===----------------------------------------------------------------------===// // Multiclass used to define SourceLanguage enum values and at the same time diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetTransformInfo.h b/llvm/lib/Target/SPIRV/SPIRVTargetTransformInfo.h index 43bf6e9..60c4e2d 100644 --- a/llvm/lib/Target/SPIRV/SPIRVTargetTransformInfo.h +++ b/llvm/lib/Target/SPIRV/SPIRVTargetTransformInfo.h @@ -59,6 +59,8 @@ public: Intrinsic::ID IID) const override; Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV, Value *NewV) const override; + + bool allowVectorElementIndexingUsingGEP() const override { return false; } }; } // namespace llvm diff --git a/llvm/lib/Target/Sparc/MCTargetDesc/SparcAsmBackend.cpp b/llvm/lib/Target/Sparc/MCTargetDesc/SparcAsmBackend.cpp index ba023af..bc60842 100644 --- a/llvm/lib/Target/Sparc/MCTargetDesc/SparcAsmBackend.cpp +++ b/llvm/lib/Target/Sparc/MCTargetDesc/SparcAsmBackend.cpp @@ -127,8 +127,7 @@ public: std::optional<MCFixupKind> getFixupKind(StringRef Name) const override; MCFixupKindInfo getFixupKindInfo(MCFixupKind Kind) const override; void applyFixup(const MCFragment &, const MCFixup &, const MCValue &Target, - MutableArrayRef<char> Data, uint64_t Value, - bool IsResolved) override; + uint8_t *Data, uint64_t Value, bool IsResolved) override; bool writeNopData(raw_ostream &OS, uint64_t Count, const MCSubtargetInfo *STI) const override { @@ -253,21 +252,19 @@ MCFixupKindInfo SparcAsmBackend::getFixupKindInfo(MCFixupKind Kind) const { } void SparcAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, - const MCValue &Target, - MutableArrayRef<char> Data, uint64_t Value, - bool IsResolved) { + const MCValue &Target, uint8_t 
*Data, + uint64_t Value, bool IsResolved) { maybeAddReloc(F, Fixup, Target, Value, IsResolved); if (!IsResolved) return; Value = adjustFixupValue(Fixup.getKind(), Value); unsigned NumBytes = getFixupKindNumBytes(Fixup.getKind()); - unsigned Offset = Fixup.getOffset(); // For each byte of the fragment that the fixup touches, mask in the // bits from the fixup value. for (unsigned i = 0; i != NumBytes; ++i) { unsigned Idx = Endian == llvm::endianness::little ? i : (NumBytes - 1) - i; - Data[Offset + Idx] |= uint8_t((Value >> (i * 8)) & 0xff); + Data[Idx] |= uint8_t((Value >> (i * 8)) & 0xff); } } diff --git a/llvm/lib/Target/SystemZ/MCTargetDesc/SystemZMCAsmBackend.cpp b/llvm/lib/Target/SystemZ/MCTargetDesc/SystemZMCAsmBackend.cpp index d5f8492..d692cbe 100644 --- a/llvm/lib/Target/SystemZ/MCTargetDesc/SystemZMCAsmBackend.cpp +++ b/llvm/lib/Target/SystemZ/MCTargetDesc/SystemZMCAsmBackend.cpp @@ -113,8 +113,7 @@ public: std::optional<MCFixupKind> getFixupKind(StringRef Name) const override; MCFixupKindInfo getFixupKindInfo(MCFixupKind Kind) const override; void applyFixup(const MCFragment &, const MCFixup &, const MCValue &Target, - MutableArrayRef<char> Data, uint64_t Value, - bool IsResolved) override; + uint8_t *Data, uint64_t Value, bool IsResolved) override; bool writeNopData(raw_ostream &OS, uint64_t Count, const MCSubtargetInfo *STI) const override; }; @@ -152,20 +151,18 @@ MCFixupKindInfo SystemZMCAsmBackend::getFixupKindInfo(MCFixupKind Kind) const { } void SystemZMCAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, - const MCValue &Target, - MutableArrayRef<char> Data, uint64_t Value, - bool IsResolved) { + const MCValue &Target, uint8_t *Data, + uint64_t Value, bool IsResolved) { if (Target.getSpecifier()) IsResolved = false; maybeAddReloc(F, Fixup, Target, Value, IsResolved); MCFixupKind Kind = Fixup.getKind(); if (mc::isRelocation(Kind)) return; - unsigned Offset = Fixup.getOffset(); unsigned BitSize = getFixupKindInfo(Kind).TargetSize; unsigned Size = (BitSize + 7) / 8; - assert(Offset + Size <= Data.size() && "Invalid fixup offset!"); + assert(Fixup.getOffset() + Size <= F.getSize() && "Invalid fixup offset!"); // Big-endian insertion of Size bytes. 
Value = extractBitsForFixup(Kind, Value, Fixup, getContext()); @@ -173,7 +170,7 @@ void SystemZMCAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, Value &= ((uint64_t)1 << BitSize) - 1; unsigned ShiftValue = (Size * 8) - 8; for (unsigned I = 0; I != Size; ++I) { - Data[Offset + I] |= uint8_t(Value >> ShiftValue); + Data[I] |= uint8_t(Value >> ShiftValue); ShiftValue -= 8; } } diff --git a/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp b/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp index e30d723..fb0a47d 100644 --- a/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp +++ b/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp @@ -9044,7 +9044,7 @@ static unsigned detectEvenOddMultiplyOperand(const SelectionDAG &DAG, if (unsigned(ShuffleMask[Elt]) != 2 * Elt) CanUseEven = false; if (unsigned(ShuffleMask[Elt]) != 2 * Elt + 1) - CanUseEven = true; + CanUseOdd = false; } Op = Op.getOperand(0); if (CanUseEven) diff --git a/llvm/lib/Target/VE/MCTargetDesc/VEAsmBackend.cpp b/llvm/lib/Target/VE/MCTargetDesc/VEAsmBackend.cpp index f987621..c1b9d9f 100644 --- a/llvm/lib/Target/VE/MCTargetDesc/VEAsmBackend.cpp +++ b/llvm/lib/Target/VE/MCTargetDesc/VEAsmBackend.cpp @@ -112,8 +112,7 @@ public: } void applyFixup(const MCFragment &, const MCFixup &, const MCValue &, - MutableArrayRef<char>, uint64_t Value, - bool IsResolved) override; + uint8_t *, uint64_t Value, bool IsResolved) override; bool mayNeedRelaxation(unsigned Opcode, ArrayRef<MCOperand> Operands, const MCSubtargetInfo &STI) const override { @@ -152,7 +151,7 @@ public: } // end anonymous namespace void VEAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, - const MCValue &Target, MutableArrayRef<char> Data, + const MCValue &Target, uint8_t *Data, uint64_t Value, bool IsResolved) { switch (Fixup.getKind()) { case VE::fixup_ve_tls_gd_hi32: @@ -173,14 +172,14 @@ void VEAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, Value <<= Info.TargetOffset; unsigned NumBytes = getFixupKindNumBytes(Fixup.getKind()); - unsigned Offset = Fixup.getOffset(); - assert(Offset + NumBytes <= Data.size() && "Invalid fixup offset!"); + assert(Fixup.getOffset() + NumBytes <= F.getSize() && + "Invalid fixup offset!"); // For each byte of the fragment that the fixup touches, mask in the bits // from the fixup value. The Value has been "split up" into the // appropriate bitfields above. for (unsigned i = 0; i != NumBytes; ++i) { unsigned Idx = Endian == llvm::endianness::little ? 
i : (NumBytes - 1) - i; - Data[Offset + Idx] |= static_cast<uint8_t>((Value >> (i * 8)) & 0xff); + Data[Idx] |= static_cast<uint8_t>((Value >> (i * 8)) & 0xff); } } diff --git a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyAsmBackend.cpp b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyAsmBackend.cpp index 837fd8e..eecef31 100644 --- a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyAsmBackend.cpp +++ b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyAsmBackend.cpp @@ -39,7 +39,7 @@ public: MCFixupKindInfo getFixupKindInfo(MCFixupKind Kind) const override; void applyFixup(const MCFragment &, const MCFixup &, const MCValue &Target, - MutableArrayRef<char> Data, uint64_t Value, bool) override; + uint8_t *Data, uint64_t Value, bool) override; std::unique_ptr<MCObjectTargetWriter> createObjectTargetWriter() const override; @@ -80,8 +80,7 @@ bool WebAssemblyAsmBackend::writeNopData(raw_ostream &OS, uint64_t Count, void WebAssemblyAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, - const MCValue &Target, - MutableArrayRef<char> Data, + const MCValue &Target, uint8_t *Data, uint64_t Value, bool IsResolved) { if (!IsResolved) Asm->getWriter().recordRelocation(F, Fixup, Target, Value); @@ -96,13 +95,13 @@ void WebAssemblyAsmBackend::applyFixup(const MCFragment &F, // Shift the value into position. Value <<= Info.TargetOffset; - unsigned Offset = Fixup.getOffset(); - assert(Offset + NumBytes <= Data.size() && "Invalid fixup offset!"); + assert(Fixup.getOffset() + NumBytes <= F.getSize() && + "Invalid fixup offset!"); // For each byte of the fragment that the fixup touches, mask in the // bits from the fixup value. for (unsigned I = 0; I != NumBytes; ++I) - Data[Offset + I] |= uint8_t((Value >> (I * 8)) & 0xff); + Data[I] |= uint8_t((Value >> (I * 8)) & 0xff); } std::unique_ptr<MCObjectTargetWriter> diff --git a/llvm/lib/Target/WebAssembly/WebAssembly.td b/llvm/lib/Target/WebAssembly/WebAssembly.td index a606209..089be5f 100644 --- a/llvm/lib/Target/WebAssembly/WebAssembly.td +++ b/llvm/lib/Target/WebAssembly/WebAssembly.td @@ -49,6 +49,8 @@ def FeatureFP16 : SubtargetFeature<"fp16", "HasFP16", "true", "Enable FP16 instructions">; +def FeatureGC : SubtargetFeature<"gc", "HasGC", "true", "Enable wasm gc">; + def FeatureMultiMemory : SubtargetFeature<"multimemory", "HasMultiMemory", "true", "Enable multiple memories">; @@ -71,7 +73,6 @@ def FeatureReferenceTypes : SubtargetFeature<"reference-types", "HasReferenceTypes", "true", "Enable reference types">; -def FeatureGC : SubtargetFeature<"gc", "HasGC", "true", "Enable wasm gc">; def FeatureRelaxedSIMD : SubtargetFeature<"relaxed-simd", "SIMDLevel", "RelaxedSIMD", "Enable relaxed-simd instructions">; @@ -139,10 +140,10 @@ def : ProcessorModel<"lime1", NoSchedModel, def : ProcessorModel<"bleeding-edge", NoSchedModel, [FeatureAtomics, FeatureBulkMemory, FeatureBulkMemoryOpt, FeatureCallIndirectOverlong, FeatureExceptionHandling, - FeatureExtendedConst, FeatureFP16, FeatureMultiMemory, - FeatureMultivalue, FeatureMutableGlobals, + FeatureExtendedConst, FeatureFP16, FeatureGC, + FeatureMultiMemory, FeatureMultivalue, FeatureMutableGlobals, FeatureNontrappingFPToInt, FeatureRelaxedSIMD, - FeatureReferenceTypes, FeatureGC, FeatureSIMD128, + FeatureReferenceTypes, FeatureSIMD128, FeatureSignExt, FeatureTailCall]>; //===----------------------------------------------------------------------===// diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrInfo.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrInfo.td index 
2b632fd..13d048a 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrInfo.td +++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrInfo.td @@ -50,6 +50,9 @@ def HasFP16 : Predicate<"Subtarget->hasFP16()">, AssemblerPredicate<(all_of FeatureFP16), "fp16">; +def HasGC : Predicate<"Subtarget->hasGC()">, + AssemblerPredicate<(all_of FeatureGC), "gc">; + def HasMultiMemory : Predicate<"Subtarget->hasMultiMemory()">, AssemblerPredicate<(all_of FeatureMultiMemory), "multimemory">; @@ -76,9 +79,6 @@ def HasReferenceTypes : Predicate<"Subtarget->hasReferenceTypes()">, AssemblerPredicate<(all_of FeatureReferenceTypes), "reference-types">; -def HasGC : Predicate<"Subtarget->hasGC()">, - AssemblerPredicate<(all_of FeatureGC), "gc">; - def HasRelaxedSIMD : Predicate<"Subtarget->hasRelaxedSIMD()">, AssemblerPredicate<(all_of FeatureRelaxedSIMD), "relaxed-simd">; diff --git a/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.h b/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.h index f814274..2f88bbb 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.h +++ b/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.h @@ -46,12 +46,12 @@ class WebAssemblySubtarget final : public WebAssemblyGenSubtargetInfo { bool HasExceptionHandling = false; bool HasExtendedConst = false; bool HasFP16 = false; + bool HasGC = false; bool HasMultiMemory = false; bool HasMultivalue = false; bool HasMutableGlobals = false; bool HasNontrappingFPToInt = false; bool HasReferenceTypes = false; - bool HasGC = false; bool HasSignExt = false; bool HasTailCall = false; bool HasWideArithmetic = false; diff --git a/llvm/lib/Target/X86/MCTargetDesc/X86AsmBackend.cpp b/llvm/lib/Target/X86/MCTargetDesc/X86AsmBackend.cpp index 7f9d474..1f02e56 100644 --- a/llvm/lib/Target/X86/MCTargetDesc/X86AsmBackend.cpp +++ b/llvm/lib/Target/X86/MCTargetDesc/X86AsmBackend.cpp @@ -174,8 +174,7 @@ public: std::optional<bool> evaluateFixup(const MCFragment &, MCFixup &, MCValue &, uint64_t &) override; void applyFixup(const MCFragment &, const MCFixup &, const MCValue &Target, - MutableArrayRef<char> Data, uint64_t Value, - bool IsResolved) override; + uint8_t *Data, uint64_t Value, bool IsResolved) override; bool mayNeedRelaxation(unsigned Opcode, ArrayRef<MCOperand> Operands, const MCSubtargetInfo &STI) const override; @@ -676,9 +675,8 @@ std::optional<bool> X86AsmBackend::evaluateFixup(const MCFragment &, } void X86AsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, - const MCValue &Target, - MutableArrayRef<char> Data, uint64_t Value, - bool IsResolved) { + const MCValue &Target, uint8_t *Data, + uint64_t Value, bool IsResolved) { // Force relocation when there is a specifier. This might be too conservative // - GAS doesn't emit a relocation for call local@plt; local:. 
if (Target.getSpecifier()) @@ -690,7 +688,7 @@ void X86AsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, return; unsigned Size = getFixupKindSize(Kind); - assert(Fixup.getOffset() + Size <= Data.size() && "Invalid fixup offset!"); + assert(Fixup.getOffset() + Size <= F.getSize() && "Invalid fixup offset!"); int64_t SignedValue = static_cast<int64_t>(Value); if (IsResolved && Fixup.isPCRel()) { @@ -710,7 +708,7 @@ void X86AsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, } for (unsigned i = 0; i != Size; ++i) - Data[Fixup.getOffset() + i] = uint8_t(Value >> (i * 8)); + Data[i] = uint8_t(Value >> (i * 8)); } bool X86AsmBackend::mayNeedRelaxation(unsigned Opcode, diff --git a/llvm/lib/Target/X86/X86FastISel.cpp b/llvm/lib/Target/X86/X86FastISel.cpp index 0ff7f23..067bd43 100644 --- a/llvm/lib/Target/X86/X86FastISel.cpp +++ b/llvm/lib/Target/X86/X86FastISel.cpp @@ -3673,6 +3673,12 @@ bool X86FastISel::fastLowerCall(CallLoweringInfo &CLI) { CLI.NumResultRegs = RVLocs.size(); CLI.Call = MIB; + // Add call site info for call graph section. + if (TM.Options.EmitCallGraphSection && CB && CB->isIndirectCall()) { + MachineFunction::CallSiteInfo CSInfo(*CB); + MF->addCallSiteInfo(CLI.Call, std::move(CSInfo)); + } + return true; } @@ -4042,6 +4048,8 @@ bool X86FastISel::tryToFoldLoadIntoMI(MachineInstr *MI, unsigned OpNo, MO.setReg(IndexReg); } + if (MI->isCall()) + FuncInfo.MF->moveAdditionalCallInfo(MI, Result); Result->addMemOperand(*FuncInfo.MF, createMachineMemOperandFor(LI)); Result->cloneInstrSymbols(*FuncInfo.MF, *MI); MachineBasicBlock::iterator I(MI); diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 7244a6d..ce4c061 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -23486,7 +23486,6 @@ static SDValue EmitCmp(SDValue Op0, SDValue Op1, X86::CondCode X86CC, } // Try to shrink i64 compares if the input has enough zero bits. - // TODO: Add sign-bits equivalent for isX86CCSigned(X86CC)? if (CmpVT == MVT::i64 && !isX86CCSigned(X86CC) && Op0.hasOneUse() && // Hacky way to not break CSE opportunities with sub. DAG.MaskedValueIsZero(Op1, APInt::getHighBitsSet(64, 32)) && @@ -23496,6 +23495,16 @@ static SDValue EmitCmp(SDValue Op0, SDValue Op1, X86::CondCode X86CC, Op1 = DAG.getNode(ISD::TRUNCATE, dl, CmpVT, Op1); } + // Try to shrink all i64 compares if the inputs are representable as signed + // i32. + if (CmpVT == MVT::i64 && + Op0.hasOneUse() && // Hacky way to not break CSE opportunities with sub. 
+ DAG.ComputeNumSignBits(Op1) > 32 && DAG.ComputeNumSignBits(Op0) > 32) { + CmpVT = MVT::i32; + Op0 = DAG.getNode(ISD::TRUNCATE, dl, CmpVT, Op0); + Op1 = DAG.getNode(ISD::TRUNCATE, dl, CmpVT, Op1); + } + // 0-x == y --> x+y == 0 // 0-x != y --> x+y != 0 if (Op0.getOpcode() == ISD::SUB && isNullConstant(Op0.getOperand(0)) && @@ -58071,12 +58080,24 @@ static SDValue combineX86CloadCstore(SDNode *N, SelectionDAG &DAG) { Ops[3] = Op1.getOperand(0); Ops[4] = Op1.getOperand(1); } else if (Op1.getOpcode() == ISD::AND && Sub.getValue(0).use_empty()) { + SDValue Src = Op1; + SDValue Op10 = Op1.getOperand(0); + if (Op10.getOpcode() == ISD::XOR && isAllOnesConstant(Op10.getOperand(1))) { + // res, flags2 = sub 0, (and (xor X, -1), Y) + // cload/cstore ..., cond_ne, flag2 + // -> + // res, flags2 = sub 0, (and X, Y) + // cload/cstore ..., cond_e, flag2 + Src = DAG.getNode(ISD::AND, DL, Op1.getValueType(), Op10.getOperand(0), + Op1.getOperand(1)); + Ops[3] = DAG.getTargetConstant(X86::COND_E, DL, MVT::i8); + } // res, flags2 = sub 0, (and X, Y) - // cload/cstore ..., cond_ne, flag2 + // cload/cstore ..., cc, flag2 // -> // res, flags2 = cmp (and X, Y), 0 - // cload/cstore ..., cond_ne, flag2 - Ops[4] = DAG.getNode(X86ISD::CMP, DL, MVT::i32, Op1, Sub.getOperand(0)); + // cload/cstore ..., cc, flag2 + Ops[4] = DAG.getNode(X86ISD::CMP, DL, MVT::i32, Src, Sub.getOperand(0)); } else { return SDValue(); } diff --git a/llvm/lib/Target/X86/X86ISelLoweringCall.cpp b/llvm/lib/Target/X86/X86ISelLoweringCall.cpp index b4639ac..5862c7e 100644 --- a/llvm/lib/Target/X86/X86ISelLoweringCall.cpp +++ b/llvm/lib/Target/X86/X86ISelLoweringCall.cpp @@ -2060,6 +2060,10 @@ X86TargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, if (CallConv == CallingConv::X86_INTR) report_fatal_error("X86 interrupts may not be called directly"); + // Set type id for call site info. + if (MF.getTarget().Options.EmitCallGraphSection && CB && CB->isIndirectCall()) + CSInfo = MachineFunction::CallSiteInfo(*CB); + if (IsIndirectCall && !IsWin64 && M->getModuleFlag("import-call-optimization")) errorUnsupported(DAG, dl, diff --git a/llvm/lib/Target/Xtensa/MCTargetDesc/XtensaAsmBackend.cpp b/llvm/lib/Target/Xtensa/MCTargetDesc/XtensaAsmBackend.cpp index 9167794..08936ad 100644 --- a/llvm/lib/Target/Xtensa/MCTargetDesc/XtensaAsmBackend.cpp +++ b/llvm/lib/Target/Xtensa/MCTargetDesc/XtensaAsmBackend.cpp @@ -37,8 +37,7 @@ public: std::optional<bool> evaluateFixup(const MCFragment &, MCFixup &, MCValue &, uint64_t &) override; void applyFixup(const MCFragment &, const MCFixup &, const MCValue &Target, - MutableArrayRef<char> Data, uint64_t Value, - bool IsResolved) override; + uint8_t *Data, uint64_t Value, bool IsResolved) override; bool writeNopData(raw_ostream &OS, uint64_t Count, const MCSubtargetInfo *STI) const override; @@ -153,9 +152,8 @@ std::optional<bool> XtensaAsmBackend::evaluateFixup(const MCFragment &F, } void XtensaAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, - const MCValue &Target, - MutableArrayRef<char> Data, uint64_t Value, - bool IsResolved) { + const MCValue &Target, uint8_t *Data, + uint64_t Value, bool IsResolved) { maybeAddReloc(F, Fixup, Target, Value, IsResolved); MCContext &Ctx = getContext(); MCFixupKindInfo Info = getFixupKindInfo(Fixup.getKind()); @@ -168,11 +166,10 @@ void XtensaAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, if (!Value) return; // Doesn't change encoding. 
- unsigned Offset = Fixup.getOffset(); unsigned FullSize = getSize(Fixup.getKind()); for (unsigned i = 0; i != FullSize; ++i) { - Data[Offset + i] |= uint8_t((Value >> (i * 8)) & 0xff); + Data[i] |= uint8_t((Value >> (i * 8)) & 0xff); } } diff --git a/llvm/lib/TargetParser/Host.cpp b/llvm/lib/TargetParser/Host.cpp index 78bd5b4..7e09d30 100644 --- a/llvm/lib/TargetParser/Host.cpp +++ b/llvm/lib/TargetParser/Host.cpp @@ -587,8 +587,9 @@ StringRef sys::detail::getHostCPUNameForBPF() { #endif } -#if defined(__i386__) || defined(_M_IX86) || defined(__x86_64__) || \ - defined(_M_X64) +#if (defined(__i386__) || defined(_M_IX86) || defined(__x86_64__) || \ + defined(_M_X64)) && \ + !defined(_M_ARM64EC) /// getX86CpuIDAndInfo - Execute the specified cpuid and return the 4 values in /// the specified arguments. If we can't run cpuid on the host, return true. @@ -1853,8 +1854,9 @@ VendorSignatures getVendorSignature(unsigned *MaxLeaf) { } // namespace llvm #endif -#if defined(__i386__) || defined(_M_IX86) || \ - defined(__x86_64__) || defined(_M_X64) +#if (defined(__i386__) || defined(_M_IX86) || defined(__x86_64__) || \ + defined(_M_X64)) && \ + !defined(_M_ARM64EC) StringMap<bool> sys::getHostCPUFeatures() { unsigned EAX = 0, EBX = 0, ECX = 0, EDX = 0; unsigned MaxLevel; @@ -2147,7 +2149,8 @@ StringMap<bool> sys::getHostCPUFeatures() { return Features; } -#elif defined(_WIN32) && (defined(__aarch64__) || defined(_M_ARM64)) +#elif defined(_WIN32) && (defined(__aarch64__) || defined(_M_ARM64) || \ + defined(__arm64ec__) || defined(_M_ARM64EC)) StringMap<bool> sys::getHostCPUFeatures() { StringMap<bool> Features; diff --git a/llvm/lib/TargetParser/TargetParser.cpp b/llvm/lib/TargetParser/TargetParser.cpp index e5c896f..e5d2e1c 100644 --- a/llvm/lib/TargetParser/TargetParser.cpp +++ b/llvm/lib/TargetParser/TargetParser.cpp @@ -444,8 +444,10 @@ void AMDGPU::fillAMDGPUFeatureMap(StringRef GPU, const Triple &T, Features["bitop3-insts"] = true; Features["prng-inst"] = true; Features["tanh-insts"] = true; + Features["tensor-cvt-lut-insts"] = true; Features["transpose-load-f4f6-insts"] = true; Features["bf16-trans-insts"] = true; + Features["bf16-cvt-insts"] = true; Features["fp8-conversion-insts"] = true; Features["fp8e5m3-insts"] = true; Features["permlane16-swap"] = true; diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp index 7af5ba4..40a7f80 100644 --- a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp +++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp @@ -458,29 +458,19 @@ static bool foldSqrt(CallInst *Call, LibFunc Func, TargetTransformInfo &TTI, // Check if this array of constants represents a cttz table. // Iterate over the elements from \p Table by trying to find/match all // the numbers from 0 to \p InputBits that should represent cttz results. -static bool isCTTZTable(const ConstantDataArray &Table, uint64_t Mul, - uint64_t Shift, uint64_t InputBits) { - unsigned Length = Table.getNumElements(); - if (Length < InputBits || Length > InputBits * 2) - return false; - - APInt Mask = APInt::getBitsSetFrom(InputBits, Shift); - unsigned Matched = 0; - - for (unsigned i = 0; i < Length; i++) { - uint64_t Element = Table.getElementAsInteger(i); - if (Element >= InputBits) - continue; - - // Check if \p Element matches a concrete answer. It could fail for some - // elements that are never accessed, so we keep iterating over each element - // from the table. 
The number of matched elements should be equal to the - // number of potential right answers which is \p InputBits actually. - if ((((Mul << Element) & Mask.getZExtValue()) >> Shift) == i) - Matched++; +static bool isCTTZTable(Constant *Table, const APInt &Mul, const APInt &Shift, + const APInt &AndMask, Type *AccessTy, + unsigned InputBits, const APInt &GEPIdxFactor, + const DataLayout &DL) { + for (unsigned Idx = 0; Idx < InputBits; Idx++) { + APInt Index = (APInt(InputBits, 1).shl(Idx) * Mul).lshr(Shift) & AndMask; + ConstantInt *C = dyn_cast_or_null<ConstantInt>( + ConstantFoldLoadFromConst(Table, AccessTy, Index * GEPIdxFactor, DL)); + if (!C || C->getValue() != Idx) + return false; } - return Matched == InputBits; + return true; } // Try to recognize table-based ctz implementation. @@ -495,6 +485,11 @@ static bool isCTTZTable(const ConstantDataArray &Table, uint64_t Mul, // this can be lowered to `cttz` instruction. // There is also a special case when the element is 0. // +// The (x & -x) isolates the lowest set bit of x. The multiply is by a +// de Bruijn constant, which contains every possible bit pattern of the +// index width as a window of consecutive bits. The shift extracts +// the top bits after the multiply, and that value indexes the table, +// which maps it to the number of trailing zeros in the original number. +// // Here are some examples of LLVM IR for a 64-bit target: // // CASE 1: @@ -536,8 +531,8 @@ // i64 %shr // %0 = load i8, i8* %arrayidx, align 1, !tbaa !8 // -// All this can be lowered to @llvm.cttz.i32/64 intrinsic. -static bool tryToRecognizeTableBasedCttz(Instruction &I) { +// All these can be lowered to @llvm.cttz.i32/64 intrinsics. +static bool tryToRecognizeTableBasedCttz(Instruction &I, const DataLayout &DL) { LoadInst *LI = dyn_cast<LoadInst>(&I); if (!LI) return false; @@ -547,53 +542,47 @@ return false; GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getPointerOperand()); - if (!GEP || !GEP->hasNoUnsignedSignedWrap() || GEP->getNumIndices() != 2) - return false; - - if (!GEP->getSourceElementType()->isArrayTy()) - return false; - - uint64_t ArraySize = GEP->getSourceElementType()->getArrayNumElements(); - if (ArraySize != 32 && ArraySize != 64) + if (!GEP || !GEP->hasNoUnsignedSignedWrap()) return false; GlobalVariable *GVTable = dyn_cast<GlobalVariable>(GEP->getPointerOperand()); if (!GVTable || !GVTable->hasInitializer() || !GVTable->isConstant()) return false; - ConstantDataArray *ConstData = - dyn_cast<ConstantDataArray>(GVTable->getInitializer()); - if (!ConstData) - return false; - - if (!match(GEP->idx_begin()->get(), m_ZeroInt())) + unsigned BW = DL.getIndexTypeSizeInBits(GEP->getType()); + APInt ModOffset(BW, 0); + SmallMapVector<Value *, APInt, 4> VarOffsets; + if (!GEP->collectOffset(DL, BW, VarOffsets, ModOffset) || + VarOffsets.size() != 1 || ModOffset != 0) return false; + auto [GepIdx, GEPScale] = VarOffsets.front(); - Value *Idx2 = std::next(GEP->idx_begin())->get(); Value *X1; - uint64_t MulConst, ShiftConst; - // FIXME: 64-bit targets have `i64` type for the GEP index, so this match will - // probably fail for other (e.g. 32-bit) targets. - if (!match(Idx2, m_ZExtOrSelf( - m_LShr(m_Mul(m_c_And(m_Neg(m_Value(X1)), m_Deferred(X1)), - m_ConstantInt(MulConst)), - m_ConstantInt(ShiftConst))))) + const APInt *MulConst, *ShiftConst, *AndCst = nullptr; + // Check that the gep variable index is ((x & -x) * MulConst) >> ShiftConst.
+ // This might be extended to the pointer index type, and if the gep index type + // has been replaced with an i8 then a new And (and different ShiftConst) will + // be present. + auto MatchInner = m_LShr( + m_Mul(m_c_And(m_Neg(m_Value(X1)), m_Deferred(X1)), m_APInt(MulConst)), + m_APInt(ShiftConst)); + if (!match(GepIdx, m_CastOrSelf(MatchInner)) && + !match(GepIdx, m_CastOrSelf(m_And(MatchInner, m_APInt(AndCst))))) return false; unsigned InputBits = X1->getType()->getScalarSizeInBits(); - if (InputBits != 32 && InputBits != 64) - return false; - - // Shift should extract top 5..7 bits. - if (InputBits - Log2_32(InputBits) != ShiftConst && - InputBits - Log2_32(InputBits) - 1 != ShiftConst) + if (InputBits != 16 && InputBits != 32 && InputBits != 64 && InputBits != 128) return false; - if (!isCTTZTable(*ConstData, MulConst, ShiftConst, InputBits)) + if (!GEPScale.isIntN(InputBits) || + !isCTTZTable(GVTable->getInitializer(), *MulConst, *ShiftConst, + AndCst ? *AndCst : APInt::getAllOnes(InputBits), AccessType, + InputBits, GEPScale.zextOrTrunc(InputBits), DL)) return false; - auto ZeroTableElem = ConstData->getElementAsInteger(0); - bool DefinedForZero = ZeroTableElem == InputBits; + ConstantInt *ZeroTableElem = cast<ConstantInt>( + ConstantFoldLoadFromConst(GVTable->getInitializer(), AccessType, DL)); + bool DefinedForZero = ZeroTableElem->getZExtValue() == InputBits; IRBuilder<> B(LI); ConstantInt *BoolConst = B.getInt1(!DefinedForZero); @@ -607,8 +596,7 @@ static bool tryToRecognizeTableBasedCttz(Instruction &I) { // If the value in elem 0 isn't the same as InputBits, we still want to // produce the value from the table. auto Cmp = B.CreateICmpEQ(X1, ConstantInt::get(XType, 0)); - auto Select = - B.CreateSelect(Cmp, ConstantInt::get(XType, ZeroTableElem), Cttz); + auto Select = B.CreateSelect(Cmp, B.CreateZExt(ZeroTableElem, XType), Cttz); // NOTE: If the table[0] is 0, but the cttz(0) is defined by the Target // it should be handled as: `cttz(x) & (typeSize - 1)`. 
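For context, a reference version of the source-level idiom being recognized is the classic de Bruijn trailing-zero count. The following is a sketch using the well-known 32-bit constant 0x077CB531 and its table (not code taken from the patch):

#include <cassert>
#include <cstdint>

// Table[((1u << i) * 0x077CB531u) >> 27] == i for every i in [0, 32).
static const uint8_t Table[32] = {
    0,  1,  28, 2,  29, 14, 24, 3,  30, 22, 20, 15, 25, 17, 4,  8,
    31, 27, 13, 23, 21, 19, 16, 7,  26, 12, 18, 6,  11, 5,  10, 9};

unsigned cttz32(uint32_t X) {
  // (X & -X) isolates the lowest set bit; multiplying by the de Bruijn
  // constant places a unique 5-bit pattern in the top bits, which the
  // shift extracts as a table index.
  return Table[((X & -X) * 0x077CB531u) >> 27];
}

int main() {
  for (unsigned I = 0; I < 32; ++I)
    assert(cttz32(1u << I) == I);
  return 0;
}

The new isCTTZTable verifies exactly this property: for each single-bit input it constant-folds a load from the table at the computed index and checks that the loaded value equals the bit position.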
@@ -1477,7 +1465,7 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT, MadeChange |= foldGuardedFunnelShift(I, DT); MadeChange |= tryToRecognizePopCount(I); MadeChange |= tryToFPToSat(I, TTI); - MadeChange |= tryToRecognizeTableBasedCttz(I); + MadeChange |= tryToRecognizeTableBasedCttz(I, DL); MadeChange |= foldConsecutiveLoads(I, DL, TTI, AA, DT); MadeChange |= foldPatternedLoads(I, DL); MadeChange |= foldICmpOrChain(I, DL, TTI, AA, DT); diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp index 64b33e4..ab906f9 100644 --- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp @@ -1568,7 +1568,7 @@ private: if (DebugLoc SuspendLoc = S->getDebugLoc()) { std::string LabelName = ("__coro_resume_" + Twine(SuspendIndex)).str(); - DILocation &DILoc = *SuspendLoc.get(); + DILocation &DILoc = *SuspendLoc; DILabel *ResumeLabel = DBuilder.createLabel(DIS, LabelName, DILoc.getFile(), SuspendLoc.getLine(), SuspendLoc.getCol(), diff --git a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp index 3c24d2e..01da012 100644 --- a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp +++ b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp @@ -13404,7 +13404,7 @@ struct AAAllocationInfoImpl : public AAAllocationInfo { return indicatePessimisticFixpoint(); if (BinSize == 0) { - auto NewAllocationSize = std::optional<TypeSize>(TypeSize(0, false)); + auto NewAllocationSize = std::make_optional<TypeSize>(0, false); if (!changeAllocationSize(NewAllocationSize)) return ChangeStatus::UNCHANGED; return ChangeStatus::CHANGED; @@ -13422,8 +13422,7 @@ struct AAAllocationInfoImpl : public AAAllocationInfo { if (SizeOfBin >= *AllocationSize) return indicatePessimisticFixpoint(); - auto NewAllocationSize = - std::optional<TypeSize>(TypeSize(SizeOfBin * 8, false)); + auto NewAllocationSize = std::make_optional<TypeSize>(SizeOfBin * 8, false); if (!changeAllocationSize(NewAllocationSize)) return ChangeStatus::UNCHANGED; diff --git a/llvm/lib/Transforms/IPO/LowerTypeTests.cpp b/llvm/lib/Transforms/IPO/LowerTypeTests.cpp index 486205c..57844a1 100644 --- a/llvm/lib/Transforms/IPO/LowerTypeTests.cpp +++ b/llvm/lib/Transforms/IPO/LowerTypeTests.cpp @@ -502,8 +502,7 @@ class LowerTypeTestsModule { uint8_t *exportTypeId(StringRef TypeId, const TypeIdLowering &TIL); TypeIdLowering importTypeId(StringRef TypeId); void importTypeTest(CallInst *CI); - void importFunction(Function *F, bool isJumpTableCanonical, - std::vector<GlobalAlias *> &AliasesToErase); + void importFunction(Function *F, bool isJumpTableCanonical); BitSetInfo buildBitSet(Metadata *TypeId, @@ -1103,9 +1102,8 @@ void LowerTypeTestsModule::maybeReplaceComdat(Function *F, // ThinLTO backend: the function F has a jump table entry; update this module // accordingly. isJumpTableCanonical describes the type of the jump table entry. 
-void LowerTypeTestsModule::importFunction( - Function *F, bool isJumpTableCanonical, - std::vector<GlobalAlias *> &AliasesToErase) { +void LowerTypeTestsModule::importFunction(Function *F, + bool isJumpTableCanonical) { assert(F->getType()->getAddressSpace() == 0); GlobalValue::VisibilityTypes Visibility = F->getVisibility(); @@ -1135,23 +1133,23 @@ void LowerTypeTestsModule::importFunction( } else { F->setName(Name + ".cfi"); maybeReplaceComdat(F, Name); - F->setLinkage(GlobalValue::ExternalLinkage); FDecl = Function::Create(F->getFunctionType(), GlobalValue::ExternalLinkage, F->getAddressSpace(), Name, &M); FDecl->setVisibility(Visibility); Visibility = GlobalValue::HiddenVisibility; - // Delete aliases pointing to this function, they'll be re-created in the - // merged output. Don't do it yet though because ScopedSaveAliaseesAndUsed - // will want to reset the aliasees first. + // Update aliases pointing to this function to also include the ".cfi" suffix. + // We expect the jump table entry to point either to the real function or to an + // alias. Redirect all other users to the jump table entry. for (auto &U : F->uses()) { if (auto *A = dyn_cast<GlobalAlias>(U.getUser())) { + std::string AliasName = A->getName().str() + ".cfi"; Function *AliasDecl = Function::Create( F->getFunctionType(), GlobalValue::ExternalLinkage, F->getAddressSpace(), "", &M); AliasDecl->takeName(A); A->replaceAllUsesWith(AliasDecl); - AliasesToErase.push_back(A); + A->setName(AliasName); } } } @@ -2077,16 +2075,13 @@ bool LowerTypeTestsModule::lower() { Decls.push_back(&F); } - std::vector<GlobalAlias *> AliasesToErase; { ScopedSaveAliaseesAndUsed S(M); for (auto *F : Defs) - importFunction(F, /*isJumpTableCanonical*/ true, AliasesToErase); + importFunction(F, /*isJumpTableCanonical*/ true); for (auto *F : Decls) - importFunction(F, /*isJumpTableCanonical*/ false, AliasesToErase); + importFunction(F, /*isJumpTableCanonical*/ false); } - for (GlobalAlias *GA : AliasesToErase) - GA->eraseFromParent(); return true; } @@ -2137,6 +2132,18 @@ bool LowerTypeTestsModule::lower() { if (auto Alias = dyn_cast<AliasSummary>(RefGVS.get())) AddressTaken.insert(Alias->getAliaseeGUID()); } + auto IsAddressTaken = [&](GlobalValue::GUID GUID) { + if (AddressTaken.count(GUID)) + return true; + auto VI = ExportSummary->getValueInfo(GUID); + if (!VI) + return false; + for (auto &I : VI.getSummaryList()) + if (auto Alias = dyn_cast<AliasSummary>(I.get())) + if (AddressTaken.count(Alias->getAliaseeGUID())) + return true; + return false; + }; for (auto *FuncMD : CfiFunctionsMD->operands()) { assert(FuncMD->getNumOperands() >= 2); StringRef FunctionName = @@ -2153,7 +2160,7 @@ bool LowerTypeTestsModule::lower() { // have no live references (and are not exported with cross-DSO CFI). if (!ExportSummary->isGUIDLive(GUID)) continue; - if (!AddressTaken.count(GUID)) { + if (!IsAddressTaken(GUID)) { if (!CrossDsoCfi || Linkage != CFL_Definition) continue; @@ -2227,6 +2234,43 @@ bool LowerTypeTestsModule::lower() { } } + struct AliasToCreate { + Function *Alias; + std::string TargetName; + }; + std::vector<AliasToCreate> AliasesToCreate; + + // Parse alias data to replace stand-in function declarations for aliases + // with an alias to the intended target.
+ if (ExportSummary) { + if (NamedMDNode *AliasesMD = M.getNamedMetadata("aliases")) { + for (auto *AliasMD : AliasesMD->operands()) { + SmallVector<Function *> Aliases; + for (Metadata *MD : AliasMD->operands()) { + auto *MDS = dyn_cast<MDString>(MD); + if (!MDS) + continue; + StringRef AliasName = MDS->getString(); + if (!ExportedFunctions.count(AliasName)) + continue; + auto *AliasF = M.getFunction(AliasName); + if (AliasF) + Aliases.push_back(AliasF); + } + + if (Aliases.empty()) + continue; + + for (unsigned I = 1; I != Aliases.size(); ++I) { + auto *AliasF = Aliases[I]; + ExportedFunctions.erase(AliasF->getName()); + AliasesToCreate.push_back( + {AliasF, std::string(Aliases[0]->getName())}); + } + } + } + } + DenseMap<GlobalObject *, GlobalTypeMember *> GlobalTypeMembers; for (GlobalObject &GO : M.global_objects()) { if (isa<GlobalVariable>(GO) && GO.isDeclarationForLinker()) @@ -2414,47 +2458,16 @@ bool LowerTypeTestsModule::lower() { allocateByteArrays(); - // Parse alias data to replace stand-in function declarations for aliases - // with an alias to the intended target. - if (ExportSummary) { - if (NamedMDNode *AliasesMD = M.getNamedMetadata("aliases")) { - for (auto *AliasMD : AliasesMD->operands()) { - assert(AliasMD->getNumOperands() >= 4); - StringRef AliasName = - cast<MDString>(AliasMD->getOperand(0))->getString(); - StringRef Aliasee = cast<MDString>(AliasMD->getOperand(1))->getString(); - - if (auto It = ExportedFunctions.find(Aliasee); - It == ExportedFunctions.end() || - It->second.Linkage != CFL_Definition || !M.getNamedAlias(Aliasee)) - continue; - - GlobalValue::VisibilityTypes Visibility = - static_cast<GlobalValue::VisibilityTypes>( - cast<ConstantAsMetadata>(AliasMD->getOperand(2)) - ->getValue() - ->getUniqueInteger() - .getZExtValue()); - bool Weak = - static_cast<bool>(cast<ConstantAsMetadata>(AliasMD->getOperand(3)) - ->getValue() - ->getUniqueInteger() - .getZExtValue()); - - auto *Alias = GlobalAlias::create("", M.getNamedAlias(Aliasee)); - Alias->setVisibility(Visibility); - if (Weak) - Alias->setLinkage(GlobalValue::WeakAnyLinkage); - - if (auto *F = M.getFunction(AliasName)) { - Alias->takeName(F); - F->replaceAllUsesWith(Alias); - F->eraseFromParent(); - } else { - Alias->setName(AliasName); - } - } - } + for (auto A : AliasesToCreate) { + auto *Target = M.getNamedValue(A.TargetName); + if (!isa<GlobalAlias>(Target)) + continue; + auto *AliasGA = GlobalAlias::create("", Target); + AliasGA->setVisibility(A.Alias->getVisibility()); + AliasGA->setLinkage(A.Alias->getLinkage()); + AliasGA->takeName(A.Alias); + A.Alias->replaceAllUsesWith(AliasGA); + A.Alias->eraseFromParent(); } // Emit .symver directives for exported functions, if they exist. 
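The loop above consumes the reworked !aliases named metadata: each operand is a tuple of strings whose first entry names the aliasee kept in the merged module and whose remaining entries name the aliases to re-create (the producer side is the ThinLTOBitcodeWriter change later in this patch). A minimal sketch of the producer shape, with hypothetical names "f" and "f_alias":

#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"

// Adds one !aliases operand to module M: {"f", "f_alias"} means
// "f_alias" should become an alias of "f" in the merged output.
void emitAliasesOperand(llvm::Module &M) {
  llvm::LLVMContext &Ctx = M.getContext();
  llvm::NamedMDNode *NMD = M.getOrInsertNamedMetadata("aliases");
  llvm::Metadata *Elts[] = {llvm::MDString::get(Ctx, "f"),
                            llvm::MDString::get(Ctx, "f_alias")};
  NMD->addOperand(llvm::MDTuple::get(Ctx, Elts));
}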
diff --git a/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp b/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp
index c009c1e..b8c99f1 100644
--- a/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp
+++ b/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp
@@ -99,6 +99,9 @@ STATISTIC(SkippedCallsCloning,
           "Number of calls skipped during cloning due to unexpected operand");
 STATISTIC(MismatchedCloneAssignments,
           "Number of callsites assigned to call multiple non-matching clones");
+STATISTIC(TotalMergeInvokes, "Number of merge invocations for nodes");
+STATISTIC(TotalMergeIters, "Number of merge iterations for nodes");
+STATISTIC(MaxMergeIters, "Max merge iterations for nodes");
 
 static cl::opt<std::string> DotFilePathPrefix(
     "memprof-dot-file-path-prefix", cl::init(""), cl::Hidden,
@@ -109,6 +112,11 @@ static cl::opt<bool> ExportToDot("memprof-export-to-dot", cl::init(false),
                                  cl::Hidden,
                                  cl::desc("Export graph to dot files."));
 
+// TODO: Remove this option once new handling is validated more widely.
+static cl::opt<bool> DoMergeIteration(
+    "memprof-merge-iteration", cl::init(true), cl::Hidden,
+    cl::desc("Iteratively apply merging on a node to catch new callers"));
+
 // How much of the graph to export to dot.
 enum DotScope {
   All, // The full CCG graph.
@@ -3995,7 +4003,7 @@ IndexCallsiteContextGraph::getAllocationCallType(const CallInfo &Call) const {
 
 void ModuleCallsiteContextGraph::updateCall(CallInfo &CallerCall,
                                             FuncInfo CalleeFunc) {
-  auto *CurF = cast<CallBase>(CallerCall.call())->getCalledFunction();
+  auto *CurF = getCalleeFunc(CallerCall.call());
   auto NewCalleeCloneNo = CalleeFunc.cloneNo();
   if (isMemProfClone(*CurF)) {
     // If we already assigned this callsite to call a specific non-default
@@ -4191,16 +4199,36 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::mergeClones(
   if (!Inserted.second)
     return;
 
-  // Make a copy since the recursive call may move a caller edge to a new
-  // callee, messing up the iterator.
-  auto CallerEdges = Node->CallerEdges;
-  for (auto CallerEdge : CallerEdges) {
-    // Skip any caller edge moved onto a different callee during recursion.
-    if (CallerEdge->Callee != Node)
-      continue;
-    mergeClones(CallerEdge->Caller, Visited, ContextIdToAllocationNode);
+  // Iteratively perform merging on this node to handle new caller nodes created
+  // during the recursive traversal. We could do something more elegant such as
+  // maintain a worklist, but this is a simple approach that doesn't cause a
+  // measurable compile time effect, as most nodes don't have many caller
+  // edges to check.
+  bool FoundUnvisited = true;
+  unsigned Iters = 0;
+  while (FoundUnvisited) {
+    Iters++;
+    FoundUnvisited = false;
+    // Make a copy since the recursive call may move a caller edge to a new
+    // callee, messing up the iterator.
+    auto CallerEdges = Node->CallerEdges;
+    for (auto CallerEdge : CallerEdges) {
+      // Skip any caller edge moved onto a different callee during recursion.
+      if (CallerEdge->Callee != Node)
+        continue;
+      // If we found an unvisited caller, note that we should check the caller
+      // edges again as mergeClones may add or change caller nodes.
+      if (DoMergeIteration && !Visited.contains(CallerEdge->Caller))
+        FoundUnvisited = true;
+      mergeClones(CallerEdge->Caller, Visited, ContextIdToAllocationNode);
+    }
   }
+  TotalMergeInvokes++;
+  TotalMergeIters += Iters;
+  if (Iters > MaxMergeIters)
+    MaxMergeIters = Iters;
+
+  // Merge for this node after we handle its callers.
mergeNodeCalleeClones(Node, Visited, ContextIdToAllocationNode); } diff --git a/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp b/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp index e276376..4387c38 100644 --- a/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp +++ b/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp @@ -384,6 +384,10 @@ void splitAndWriteThinLTOBitcode( for (auto &F : M) if ((!F.hasLocalLinkage() || F.hasAddressTaken()) && HasTypeMetadata(&F)) CfiFunctions.insert(&F); + for (auto &A : M.aliases()) + if (auto *F = dyn_cast<Function>(A.getAliasee())) + if (HasTypeMetadata(F)) + CfiFunctions.insert(&A); // Remove all globals with type metadata, globals with comdats that live in // MergedM, and aliases pointing to such globals from the thin LTO module. @@ -403,12 +407,12 @@ void splitAndWriteThinLTOBitcode( auto &Ctx = MergedM->getContext(); SmallVector<MDNode *, 8> CfiFunctionMDs; for (auto *V : CfiFunctions) { - Function &F = *cast<Function>(V); + Function &F = *cast<Function>(V->getAliaseeObject()); SmallVector<MDNode *, 2> Types; F.getMetadata(LLVMContext::MD_type, Types); SmallVector<Metadata *, 4> Elts; - Elts.push_back(MDString::get(Ctx, F.getName())); + Elts.push_back(MDString::get(Ctx, V->getName())); CfiFunctionLinkage Linkage; if (lowertypetests::isJumpTableCanonical(&F)) Linkage = CFL_Definition; @@ -428,29 +432,24 @@ void splitAndWriteThinLTOBitcode( NMD->addOperand(MD); } - SmallVector<MDNode *, 8> FunctionAliases; + MapVector<Function *, std::vector<GlobalAlias *>> FunctionAliases; for (auto &A : M.aliases()) { if (!isa<Function>(A.getAliasee())) continue; auto *F = cast<Function>(A.getAliasee()); - - Metadata *Elts[] = { - MDString::get(Ctx, A.getName()), - MDString::get(Ctx, F->getName()), - ConstantAsMetadata::get( - ConstantInt::get(Type::getInt8Ty(Ctx), A.getVisibility())), - ConstantAsMetadata::get( - ConstantInt::get(Type::getInt8Ty(Ctx), A.isWeakForLinker())), - }; - - FunctionAliases.push_back(MDTuple::get(Ctx, Elts)); + FunctionAliases[F].push_back(&A); } if (!FunctionAliases.empty()) { NamedMDNode *NMD = MergedM->getOrInsertNamedMetadata("aliases"); - for (auto *MD : FunctionAliases) - NMD->addOperand(MD); + for (auto &Alias : FunctionAliases) { + SmallVector<Metadata *> Elts; + Elts.push_back(MDString::get(Ctx, Alias.first->getName())); + for (auto *A : Alias.second) + Elts.push_back(MDString::get(Ctx, A->getName())); + NMD->addOperand(MDTuple::get(Ctx, Elts)); + } } SmallVector<MDNode *, 8> Symvers; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index b231c04..d7971e8 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -11,10 +11,13 @@ //===----------------------------------------------------------------------===// #include "InstCombineInternal.h" +#include "llvm/ADT/SmallBitVector.h" #include "llvm/Analysis/CmpInstAnalysis.h" #include "llvm/Analysis/FloatingPointPredicateUtils.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/IR/ConstantRange.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Transforms/InstCombine/InstCombiner.h" @@ -3589,6 +3592,154 @@ static Value *foldOrOfInversions(BinaryOperator &I, return nullptr; } +/// Match \p V as "shufflevector -> bitcast" or "extractelement -> zext -> shl" +/// patterns, which extract vector elements and pack them in the same 
relative +/// positions. +/// +/// \p Vec is the underlying vector being extracted from. +/// \p Mask is a bitmask identifying which packed elements are obtained from the +/// vector. +/// \p VecOffset is the vector element corresponding to index 0 of the +/// mask. +static bool matchSubIntegerPackFromVector(Value *V, Value *&Vec, + int64_t &VecOffset, + SmallBitVector &Mask, + const DataLayout &DL) { + static const auto m_ConstShlOrSelf = [](const auto &Base, uint64_t &ShlAmt) { + ShlAmt = 0; + return m_CombineOr(m_Shl(Base, m_ConstantInt(ShlAmt)), Base); + }; + + // First try to match extractelement -> zext -> shl + uint64_t VecIdx, ShlAmt; + if (match(V, m_ConstShlOrSelf(m_ZExtOrSelf(m_ExtractElt( + m_Value(Vec), m_ConstantInt(VecIdx))), + ShlAmt))) { + auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType()); + if (!VecTy) + return false; + auto *EltTy = dyn_cast<IntegerType>(VecTy->getElementType()); + if (!EltTy) + return false; + + const unsigned EltBitWidth = EltTy->getBitWidth(); + const unsigned TargetBitWidth = V->getType()->getIntegerBitWidth(); + if (TargetBitWidth % EltBitWidth != 0 || ShlAmt % EltBitWidth != 0) + return false; + const unsigned TargetEltWidth = TargetBitWidth / EltBitWidth; + const unsigned ShlEltAmt = ShlAmt / EltBitWidth; + + const unsigned MaskIdx = + DL.isLittleEndian() ? ShlEltAmt : TargetEltWidth - ShlEltAmt - 1; + + VecOffset = static_cast<int64_t>(VecIdx) - static_cast<int64_t>(MaskIdx); + Mask.resize(TargetEltWidth); + Mask.set(MaskIdx); + return true; + } + + // Now try to match a bitcasted subvector. + Instruction *SrcVecI; + if (!match(V, m_BitCast(m_Instruction(SrcVecI)))) + return false; + + auto *SrcTy = dyn_cast<FixedVectorType>(SrcVecI->getType()); + if (!SrcTy) + return false; + + Mask.resize(SrcTy->getNumElements()); + + // First check for a subvector obtained from a shufflevector. + if (isa<ShuffleVectorInst>(SrcVecI)) { + Constant *ConstVec; + ArrayRef<int> ShuffleMask; + if (!match(SrcVecI, m_Shuffle(m_Value(Vec), m_Constant(ConstVec), + m_Mask(ShuffleMask)))) + return false; + + auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType()); + if (!VecTy) + return false; + + const unsigned NumVecElts = VecTy->getNumElements(); + bool FoundVecOffset = false; + for (unsigned Idx = 0; Idx < ShuffleMask.size(); ++Idx) { + if (ShuffleMask[Idx] == PoisonMaskElem) + return false; + const unsigned ShuffleIdx = ShuffleMask[Idx]; + if (ShuffleIdx >= NumVecElts) { + const unsigned ConstIdx = ShuffleIdx - NumVecElts; + auto *ConstElt = + dyn_cast<ConstantInt>(ConstVec->getAggregateElement(ConstIdx)); + if (!ConstElt || !ConstElt->isNullValue()) + return false; + continue; + } + + if (FoundVecOffset) { + if (VecOffset + Idx != ShuffleIdx) + return false; + } else { + if (ShuffleIdx < Idx) + return false; + VecOffset = ShuffleIdx - Idx; + FoundVecOffset = true; + } + Mask.set(Idx); + } + return FoundVecOffset; + } + + // Check for a subvector obtained as an (insertelement V, 0, idx) + uint64_t InsertIdx; + if (!match(SrcVecI, + m_InsertElt(m_Value(Vec), m_Zero(), m_ConstantInt(InsertIdx)))) + return false; + + auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType()); + if (!VecTy) + return false; + VecOffset = 0; + bool AlreadyInsertedMaskedElt = Mask.test(InsertIdx); + Mask.set(); + if (!AlreadyInsertedMaskedElt) + Mask.reset(InsertIdx); + return true; +} + +/// Try to fold the join of two scalar integers whose contents are packed +/// elements of the same vector. 
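+/// For example (little-endian layout; the values are illustrative):
+///   %e0 = extractelement <4 x i8> %v, i64 0
+///   %e1 = extractelement <4 x i8> %v, i64 1
+///   %z0 = zext i8 %e0 to i16
+///   %z1 = zext i8 %e1 to i16
+///   %hi = shl i16 %z1, 8
+///   %r  = or i16 %z0, %hi
+/// is folded to:
+///   %r.v = shufflevector <4 x i8> %v, <4 x i8> zeroinitializer,
+///                        <2 x i32> <i32 0, i32 1>
+///   %r   = bitcast <2 x i8> %r.v to i16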
+static Instruction *foldIntegerPackFromVector(Instruction &I,
+                                              InstCombiner::BuilderTy &Builder,
+                                              const DataLayout &DL) {
+  assert(I.getOpcode() == Instruction::Or);
+  Value *LhsVec, *RhsVec;
+  int64_t LhsVecOffset, RhsVecOffset;
+  SmallBitVector Mask;
+  if (!matchSubIntegerPackFromVector(I.getOperand(0), LhsVec, LhsVecOffset,
+                                     Mask, DL))
+    return nullptr;
+  if (!matchSubIntegerPackFromVector(I.getOperand(1), RhsVec, RhsVecOffset,
+                                     Mask, DL))
+    return nullptr;
+  if (LhsVec != RhsVec || LhsVecOffset != RhsVecOffset)
+    return nullptr;
+
+  // Convert into shufflevector -> bitcast.
+  const unsigned ZeroVecIdx =
+      cast<FixedVectorType>(LhsVec->getType())->getNumElements();
+  SmallVector<int> ShuffleMask(Mask.size(), ZeroVecIdx);
+  for (unsigned Idx : Mask.set_bits()) {
+    assert(LhsVecOffset + Idx >= 0);
+    ShuffleMask[Idx] = LhsVecOffset + Idx;
+  }
+
+  Value *MaskedVec = Builder.CreateShuffleVector(
+      LhsVec, Constant::getNullValue(LhsVec->getType()), ShuffleMask,
+      I.getName() + ".v");
+  return CastInst::Create(Instruction::BitCast, MaskedVec, I.getType());
+}
+
 // A decomposition of ((X & Mask) * Factor). The NUW / NSW bools
 // track these properties for preservation. Note that we can decompose
 // equivalent select form of this expression (e.g. (!(X & Mask) ? 0 : Mask *
@@ -3766,6 +3917,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
   if (Instruction *X = foldComplexAndOrPatterns(I, Builder))
     return X;
 
+  if (Instruction *X = foldIntegerPackFromVector(I, Builder, DL))
+    return X;
+
   // (A & B) | (C & D) -> A ^ D where A == ~C && B == ~D
   // (A & B) | (C & D) -> A ^ C where A == ~D && B == ~C
   if (Value *V = foldOrOfInversions(I, Builder))
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 1b78ace..47e017e 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -3891,16 +3891,20 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     }
   }
 
-  // Try to fold intrinsic into select operands. This is legal if:
+  // Try to fold intrinsic into select/phi operands. This is legal if:
   //  * The intrinsic is speculatable.
   //  * The select condition is not a vector, or the intrinsic does not
   //    perform cross-lane operations.
   if (isSafeToSpeculativelyExecuteWithVariableReplaced(&CI) &&
       isNotCrossLaneOperation(II))
-    for (Value *Op : II->args())
+    for (Value *Op : II->args()) {
       if (auto *Sel = dyn_cast<SelectInst>(Op))
        if (Instruction *R = FoldOpIntoSelect(*II, Sel))
          return R;
+      if (auto *Phi = dyn_cast<PHINode>(Op))
+        if (Instruction *R = foldOpIntoPhi(*II, Phi))
+          return R;
+    }
 
   if (Instruction *Shuf = foldShuffledIntrinsicOperands(II))
     return Shuf;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index da9b126..b268fea 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -163,6 +163,11 @@ Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal(
     LaterIndices.push_back(IdxVal);
   }
 
+  Value *Idx = GEP->getOperand(2);
+  // If the index type is non-canonical, wait for it to be canonicalized.
+  if (Idx->getType() != DL.getIndexType(GEP->getType()))
+    return nullptr;
+
   enum { Overdefined = -3, Undefined = -2 };
 
   // Variables for our state machines.
@@ -290,17 +295,6 @@ Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal( // Now that we've scanned the entire array, emit our new comparison(s). We // order the state machines in complexity of the generated code. - Value *Idx = GEP->getOperand(2); - - // If the index is larger than the pointer offset size of the target, truncate - // the index down like the GEP would do implicitly. We don't have to do this - // for an inbounds GEP because the index can't be out of range. - if (!GEP->isInBounds()) { - Type *PtrIdxTy = DL.getIndexType(GEP->getType()); - unsigned OffsetSize = PtrIdxTy->getIntegerBitWidth(); - if (Idx->getType()->getPrimitiveSizeInBits().getFixedValue() > OffsetSize) - Idx = Builder.CreateTrunc(Idx, PtrIdxTy); - } // If inbounds keyword is not present, Idx * ElementSize can overflow. // Let's assume that ElementSize is 2 and the wanted value is at offset 0. diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index 00b877b..fe0f308 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -462,6 +462,13 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) { return ScalarPHI; } + // If SrcVec is a subvector starting at index 0, extract from the + // wider source vector + Value *V; + if (match(SrcVec, + m_Intrinsic<Intrinsic::vector_extract>(m_Value(V), m_Zero()))) + return ExtractElementInst::Create(V, Index); + // TODO come up with a n-ary matcher that subsumes both unary and // binary matchers. UnaryOperator *UO; diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index 873b104..5ee3bb1 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -1994,6 +1994,8 @@ Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN, } Clone = InsertNewInstBefore(Clone, OpBB->getTerminator()->getIterator()); Clones.insert({OpBB, Clone}); + // We may have speculated the instruction. + Clone->dropUBImplyingAttrsAndMetadata(); } NewPhiValues[OpIndex] = Clone; @@ -2009,12 +2011,17 @@ Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN, NewPN->addIncoming(NewPhiValues[i], PN->getIncomingBlock(i)); if (IdenticalUsers) { - for (User *U : make_early_inc_range(PN->users())) { + // Collect and deduplicate users up-front to avoid iterator invalidation. 
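+    // (Illustrative note: an instruction may use PN in more than one operand
+    // and therefore appear several times in PN's use list; erasing it on the
+    // first visit would leave the remaining use-list iteration dangling, even
+    // in the early-increment form. Collecting into a SmallSetVector first
+    // both deduplicates users and decouples iteration from mutation.)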
+    SmallSetVector<Instruction *, 4> ToReplace;
+    for (User *U : PN->users()) {
       Instruction *User = cast<Instruction>(U);
       if (User == &I)
         continue;
-      replaceInstUsesWith(*User, NewPN);
-      eraseInstFromFunction(*User);
+      ToReplace.insert(User);
+    }
+    for (Instruction *I : ToReplace) {
+      replaceInstUsesWith(*I, NewPN);
+      eraseInstFromFunction(*I);
     }
     OneUse = true;
   }
@@ -2652,9 +2659,18 @@ static Instruction *canonicalizeGEPOfConstGEPI8(GetElementPtrInst &GEP,
   APInt NewOffset = TypeSize * *C2 + *C1;
   if (NewOffset.isZero() ||
       (Src->hasOneUse() && GEP.getOperand(1)->hasOneUse())) {
+    GEPNoWrapFlags Flags = GEPNoWrapFlags::none();
+    if (GEP.hasNoUnsignedWrap() &&
+        cast<GEPOperator>(Src)->hasNoUnsignedWrap() &&
+        match(GEP.getOperand(1), m_NUWAddLike(m_Value(), m_Value()))) {
+      Flags |= GEPNoWrapFlags::noUnsignedWrap();
+      if (GEP.isInBounds() && cast<GEPOperator>(Src)->isInBounds())
+        Flags |= GEPNoWrapFlags::inBounds();
+    }
+
     Value *GEPConst =
-        IC.Builder.CreatePtrAdd(Base, IC.Builder.getInt(NewOffset));
-    return GetElementPtrInst::Create(BaseType, GEPConst, VarIndex);
+        IC.Builder.CreatePtrAdd(Base, IC.Builder.getInt(NewOffset), "", Flags);
+    return GetElementPtrInst::Create(BaseType, GEPConst, VarIndex, Flags);
   }
 
   return nullptr;
@@ -3182,7 +3198,16 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) {
       // If we are using a wider index than needed for this platform, shrink
       // it to what we need. If narrower, sign-extend it to what we need.
       // This explicit cast can make subsequent optimizations more obvious.
-      *I = Builder.CreateIntCast(*I, NewIndexType, true);
+      if (IndexTy->getScalarSizeInBits() <
+          NewIndexType->getScalarSizeInBits()) {
+        if (GEP.hasNoUnsignedWrap() && GEP.hasNoUnsignedSignedWrap())
+          *I = Builder.CreateZExt(*I, NewIndexType, "", /*IsNonNeg=*/true);
+        else
+          *I = Builder.CreateSExt(*I, NewIndexType);
+      } else {
+        *I = Builder.CreateTrunc(*I, NewIndexType, "", GEP.hasNoUnsignedWrap(),
+                                 GEP.hasNoUnsignedSignedWrap());
+      }
       MadeChange = true;
     }
   }
@@ -3205,6 +3230,14 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) {
     return replaceInstUsesWith(GEP, NewGEP);
   }
 
+  // Strip trailing zero indices.
+  auto *LastIdx = dyn_cast<Constant>(Indices.back());
+  if (LastIdx && LastIdx->isNullValue() && !LastIdx->getType()->isVectorTy()) {
+    return replaceInstUsesWith(
+        GEP, Builder.CreateGEP(GEP.getSourceElementType(), PtrOp,
+                               drop_end(Indices), "", GEP.getNoWrapFlags()));
+  }
+
   // Scalarize vector operands; prefer splat-of-gep as canonical form.
   // Note that this loses information about undef lanes; we run it after
   // demanded bits to partially mitigate that loss.
diff --git a/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp
index 4e5a8d1..bcb90d6 100644
--- a/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp
@@ -160,6 +160,16 @@ static cl::opt<bool> ClGenerateTagsWithCalls(
 static cl::opt<bool> ClGlobals("hwasan-globals", cl::desc("Instrument globals"),
                                cl::Hidden, cl::init(false));
 
+static cl::opt<bool> ClAllGlobals(
+    "hwasan-all-globals",
+    cl::desc(
+        "Instrument globals, even those within user-defined sections. 
Warning: " + "This may break existing code which walks globals via linker-generated " + "symbols, expects certain globals to be contiguous with each other, or " + "makes other assumptions which are invalidated by HWASan " + "instrumentation."), + cl::Hidden, cl::init(false)); + static cl::opt<int> ClMatchAllTag( "hwasan-match-all-tag", cl::desc("don't report bad accesses via pointers with this tag"), @@ -681,11 +691,11 @@ void HWAddressSanitizer::initializeModule() { !CompileKernel && !UsePageAliases && optOr(ClGlobals, NewRuntime); if (!CompileKernel) { - createHwasanCtorComdat(); - if (InstrumentGlobals) instrumentGlobals(); + createHwasanCtorComdat(); + bool InstrumentPersonalityFunctions = optOr(ClInstrumentPersonalityFunctions, NewRuntime); if (InstrumentPersonalityFunctions) @@ -1772,11 +1782,17 @@ void HWAddressSanitizer::instrumentGlobals() { if (GV.hasCommonLinkage()) continue; - // Globals with custom sections may be used in __start_/__stop_ enumeration, - // which would be broken both by adding tags and potentially by the extra - // padding/alignment that we insert. - if (GV.hasSection()) - continue; + if (ClAllGlobals) { + // Avoid instrumenting intrinsic global variables. + if (GV.getSection() == "llvm.metadata") + continue; + } else { + // Globals with custom sections may be used in __start_/__stop_ + // enumeration, which would be broken both by adding tags and potentially + // by the extra padding/alignment that we insert. + if (GV.hasSection()) + continue; + } Globals.push_back(&GV); } diff --git a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp index df31f07..54d9a83 100644 --- a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp @@ -4769,6 +4769,79 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { setOriginForNaryOp(I); } + // Approximately handle AVX Galois Field Affine Transformation + // + // e.g., + // <16 x i8> @llvm.x86.vgf2p8affineqb.128(<16 x i8>, <16 x i8>, i8) + // <32 x i8> @llvm.x86.vgf2p8affineqb.256(<32 x i8>, <32 x i8>, i8) + // <64 x i8> @llvm.x86.vgf2p8affineqb.512(<64 x i8>, <64 x i8>, i8) + // Out A x b + // where A and x are packed matrices, b is a vector, + // Out = A * x + b in GF(2) + // + // Multiplication in GF(2) is equivalent to bitwise AND. However, the matrix + // computation also includes a parity calculation. + // + // For the bitwise AND of bits V1 and V2, the exact shadow is: + // Out_Shadow = (V1_Shadow & V2_Shadow) + // | (V1 & V2_Shadow) + // | (V1_Shadow & V2 ) + // + // We approximate the shadow of gf2p8affineqb using: + // Out_Shadow = gf2p8affineqb(x_Shadow, A_shadow, 0) + // | gf2p8affineqb(x, A_shadow, 0) + // | gf2p8affineqb(x_Shadow, A, 0) + // | set1_epi8(b_Shadow) + // + // This approximation has false negatives: if an intermediate dot-product + // contains an even number of 1's, the parity is 0. + // It has no false positives. 
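+  //
+  // For example (illustrative): if one output bit is the parity of exactly
+  // two product terms and both terms are uninitialized, the shadow
+  // computation gf2p8affineqb(x_Shadow, A, 0) XORs the two shadow bits
+  // (1 ^ 1 = 0), so that bit is reported as initialized even though its
+  // actual value depends on uninitialized inputs.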
+ void handleAVXGF2P8Affine(IntrinsicInst &I) { + IRBuilder<> IRB(&I); + + assert(I.arg_size() == 3); + Value *A = I.getOperand(0); + Value *X = I.getOperand(1); + Value *B = I.getOperand(2); + + assert(isFixedIntVector(A)); + assert(cast<VectorType>(A->getType()) + ->getElementType() + ->getScalarSizeInBits() == 8); + + assert(A->getType() == X->getType()); + + assert(B->getType()->isIntegerTy()); + assert(B->getType()->getScalarSizeInBits() == 8); + + assert(I.getType() == A->getType()); + + Value *AShadow = getShadow(A); + Value *XShadow = getShadow(X); + Value *BZeroShadow = getCleanShadow(B); + + CallInst *AShadowXShadow = IRB.CreateIntrinsic( + I.getType(), I.getIntrinsicID(), {XShadow, AShadow, BZeroShadow}); + CallInst *AShadowX = IRB.CreateIntrinsic(I.getType(), I.getIntrinsicID(), + {X, AShadow, BZeroShadow}); + CallInst *XShadowA = IRB.CreateIntrinsic(I.getType(), I.getIntrinsicID(), + {XShadow, A, BZeroShadow}); + + unsigned NumElements = cast<FixedVectorType>(I.getType())->getNumElements(); + Value *BShadow = getShadow(B); + Value *BBroadcastShadow = getCleanShadow(AShadow); + // There is no LLVM IR intrinsic for _mm512_set1_epi8. + // This loop generates a lot of LLVM IR, which we expect that CodeGen will + // lower appropriately (e.g., VPBROADCASTB). + // Besides, b is often a constant, in which case it is fully initialized. + for (unsigned i = 0; i < NumElements; i++) + BBroadcastShadow = IRB.CreateInsertElement(BBroadcastShadow, BShadow, i); + + setShadow(&I, IRB.CreateOr( + {AShadowXShadow, AShadowX, XShadowA, BBroadcastShadow})); + setOriginForNaryOp(I); + } + // Handle Arm NEON vector load intrinsics (vld*). // // The WithLane instructions (ld[234]lane) are similar to: @@ -5604,6 +5677,13 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { break; } + // AVX Galois Field New Instructions + case Intrinsic::x86_vgf2p8affineqb_128: + case Intrinsic::x86_vgf2p8affineqb_256: + case Intrinsic::x86_vgf2p8affineqb_512: + handleAVXGF2P8Affine(I); + break; + case Intrinsic::fshl: case Intrinsic::fshr: handleFunnelShift(I); diff --git a/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp b/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp index e706a6f..deff79b 100644 --- a/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp +++ b/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp @@ -237,8 +237,7 @@ class InductiveRangeCheckElimination { DominatorTree &DT; LoopInfo &LI; - using GetBFIFunc = - std::optional<llvm::function_ref<llvm::BlockFrequencyInfo &()>>; + using GetBFIFunc = llvm::function_ref<llvm::BlockFrequencyInfo &()>; GetBFIFunc GetBFI; // Returns the estimated number of iterations based on block frequency info if @@ -249,7 +248,7 @@ class InductiveRangeCheckElimination { public: InductiveRangeCheckElimination(ScalarEvolution &SE, BranchProbabilityInfo *BPI, DominatorTree &DT, - LoopInfo &LI, GetBFIFunc GetBFI = std::nullopt) + LoopInfo &LI, GetBFIFunc GetBFI = nullptr) : SE(SE), BPI(BPI), DT(DT), LI(LI), GetBFI(GetBFI) {} bool run(Loop *L, function_ref<void(Loop *, bool)> LPMAddNewLoop); @@ -959,7 +958,7 @@ PreservedAnalyses IRCEPass::run(Function &F, FunctionAnalysisManager &AM) { std::optional<uint64_t> InductiveRangeCheckElimination::estimatedTripCount(const Loop &L) { if (GetBFI) { - BlockFrequencyInfo &BFI = (*GetBFI)(); + BlockFrequencyInfo &BFI = GetBFI(); uint64_t hFreq = BFI.getBlockFreq(L.getHeader()).getFrequency(); uint64_t phFreq = BFI.getBlockFreq(L.getLoopPreheader()).getFrequency(); if 
(phFreq == 0 || hFreq == 0) diff --git a/llvm/lib/Transforms/Scalar/LICM.cpp b/llvm/lib/Transforms/Scalar/LICM.cpp index 68094c3..c3f80f9 100644 --- a/llvm/lib/Transforms/Scalar/LICM.cpp +++ b/llvm/lib/Transforms/Scalar/LICM.cpp @@ -2508,6 +2508,12 @@ static bool hoistGEP(Instruction &I, Loop &L, ICFLoopSafetyInfo &SafetyInfo, if (!GEP) return false; + // Do not try to hoist a constant GEP out of the loop via reassociation. + // Constant GEPs can often be folded into addressing modes, and reassociating + // them may inhibit CSE of a common base. + if (GEP->hasAllConstantIndices()) + return false; + auto *Src = dyn_cast<GetElementPtrInst>(GEP->getPointerOperand()); if (!Src || !Src->hasOneUse() || !L.contains(Src)) return false; diff --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp index f3e992c..04039b8 100644 --- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp +++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp @@ -1009,7 +1009,8 @@ PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM, // in simplified form, and also needs LCSSA. Running // this pass will simplify all loops that contain inner loops, // regardless of whether anything ends up being flattened. - LoopAccessInfoManager LAIM(AR.SE, AR.AA, AR.DT, AR.LI, &AR.TTI, nullptr); + LoopAccessInfoManager LAIM(AR.SE, AR.AA, AR.DT, AR.LI, &AR.TTI, nullptr, + &AR.AC); for (Loop *InnerLoop : LN.getLoops()) { auto *OuterLoop = InnerLoop->getParentLoop(); if (!OuterLoop) diff --git a/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp b/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp index 4f2bfb0..448dc2b 100644 --- a/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp +++ b/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp @@ -551,7 +551,7 @@ PreservedAnalyses LoopVersioningLICMPass::run(Loop &L, LoopAnalysisManager &AM, const Function *F = L.getHeader()->getParent(); OptimizationRemarkEmitter ORE(F); - LoopAccessInfoManager LAIs(*SE, *AA, *DT, LAR.LI, nullptr, nullptr); + LoopAccessInfoManager LAIs(*SE, *AA, *DT, LAR.LI, nullptr, nullptr, &LAR.AC); if (!LoopVersioningLICM(AA, SE, &ORE, LAIs, LAR.LI, &L).run(DT)) return PreservedAnalyses::all(); return getLoopPassPreservedAnalyses(); diff --git a/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp b/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp index 320b792..6ffe841 100644 --- a/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp +++ b/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp @@ -79,8 +79,7 @@ // ld.global.f32 %f4, [%rl6+132]; // much better // // Another improvement enabled by the LowerGEP flag is to lower a GEP with -// multiple indices to either multiple GEPs with a single index or arithmetic -// operations (depending on whether the target uses alias analysis in codegen). +// multiple indices to multiple GEPs with a single index. // Such transformation can have following benefits: // (1) It can always extract constants in the indices of structure type. // (2) After such Lowering, there are more optimization opportunities such as @@ -88,59 +87,33 @@ // // E.g. The following GEPs have multiple indices: // BB1: -// %p = getelementptr [10 x %struct]* %ptr, i64 %i, i64 %j1, i32 3 +// %p = getelementptr [10 x %struct], ptr %ptr, i64 %i, i64 %j1, i32 3 // load %p // ... // BB2: -// %p2 = getelementptr [10 x %struct]* %ptr, i64 %i, i64 %j1, i32 2 +// %p2 = getelementptr [10 x %struct], ptr %ptr, i64 %i, i64 %j1, i32 2 // load %p2 // ... 
// // We can not do CSE to the common part related to index "i64 %i". Lowering // GEPs can achieve such goals. -// If the target does not use alias analysis in codegen, this pass will -// lower a GEP with multiple indices into arithmetic operations: -// BB1: -// %1 = ptrtoint [10 x %struct]* %ptr to i64 ; CSE opportunity -// %2 = mul i64 %i, length_of_10xstruct ; CSE opportunity -// %3 = add i64 %1, %2 ; CSE opportunity -// %4 = mul i64 %j1, length_of_struct -// %5 = add i64 %3, %4 -// %6 = add i64 %3, struct_field_3 ; Constant offset -// %p = inttoptr i64 %6 to i32* -// load %p -// ... -// BB2: -// %7 = ptrtoint [10 x %struct]* %ptr to i64 ; CSE opportunity -// %8 = mul i64 %i, length_of_10xstruct ; CSE opportunity -// %9 = add i64 %7, %8 ; CSE opportunity -// %10 = mul i64 %j2, length_of_struct -// %11 = add i64 %9, %10 -// %12 = add i64 %11, struct_field_2 ; Constant offset -// %p = inttoptr i64 %12 to i32* -// load %p2 -// ... // -// If the target uses alias analysis in codegen, this pass will lower a GEP -// with multiple indices into multiple GEPs with a single index: +// This pass will lower a GEP with multiple indices into multiple GEPs with a +// single index: // BB1: -// %1 = bitcast [10 x %struct]* %ptr to i8* ; CSE opportunity -// %2 = mul i64 %i, length_of_10xstruct ; CSE opportunity -// %3 = getelementptr i8* %1, i64 %2 ; CSE opportunity +// %2 = mul i64 %i, length_of_10xstruct ; CSE opportunity +// %3 = getelementptr i8, ptr %ptr, i64 %2 ; CSE opportunity // %4 = mul i64 %j1, length_of_struct -// %5 = getelementptr i8* %3, i64 %4 -// %6 = getelementptr i8* %5, struct_field_3 ; Constant offset -// %p = bitcast i8* %6 to i32* +// %5 = getelementptr i8, ptr %3, i64 %4 +// %p = getelementptr i8, ptr %5, struct_field_3 ; Constant offset // load %p // ... // BB2: -// %7 = bitcast [10 x %struct]* %ptr to i8* ; CSE opportunity -// %8 = mul i64 %i, length_of_10xstruct ; CSE opportunity -// %9 = getelementptr i8* %7, i64 %8 ; CSE opportunity +// %8 = mul i64 %i, length_of_10xstruct ; CSE opportunity +// %9 = getelementptr i8, ptr %ptr, i64 %8 ; CSE opportunity // %10 = mul i64 %j2, length_of_struct -// %11 = getelementptr i8* %9, i64 %10 -// %12 = getelementptr i8* %11, struct_field_2 ; Constant offset -// %p2 = bitcast i8* %12 to i32* +// %11 = getelementptr i8, ptr %9, i64 %10 +// %p2 = getelementptr i8, ptr %11, struct_field_2 ; Constant offset // load %p2 // ... // @@ -408,16 +381,6 @@ private: void lowerToSingleIndexGEPs(GetElementPtrInst *Variadic, int64_t AccumulativeByteOffset); - /// Lower a GEP with multiple indices into ptrtoint+arithmetics+inttoptr form. - /// Function splitGEP already split the original GEP into a variadic part and - /// a constant offset (i.e., AccumulativeByteOffset). This function lowers the - /// variadic part into a set of arithmetic operations and applies - /// AccumulativeByteOffset to it. - /// \p Variadic The variadic part of the original GEP. - /// \p AccumulativeByteOffset The constant offset. - void lowerToArithmetics(GetElementPtrInst *Variadic, - int64_t AccumulativeByteOffset); - /// Finds the constant offset within each index and accumulates them. If /// LowerGEP is true, it finds in indices of both sequential and structure /// types, otherwise it only finds in sequential indices. 
The output @@ -951,55 +914,6 @@ void SeparateConstOffsetFromGEP::lowerToSingleIndexGEPs( Variadic->eraseFromParent(); } -void -SeparateConstOffsetFromGEP::lowerToArithmetics(GetElementPtrInst *Variadic, - int64_t AccumulativeByteOffset) { - IRBuilder<> Builder(Variadic); - Type *IntPtrTy = DL->getIntPtrType(Variadic->getType()); - assert(IntPtrTy == DL->getIndexType(Variadic->getType()) && - "Pointer type must match index type for arithmetic-based lowering of " - "split GEPs"); - - Value *ResultPtr = Builder.CreatePtrToInt(Variadic->getOperand(0), IntPtrTy); - gep_type_iterator GTI = gep_type_begin(*Variadic); - // Create ADD/SHL/MUL arithmetic operations for each sequential indices. We - // don't create arithmetics for structure indices, as they are accumulated - // in the constant offset index. - for (unsigned I = 1, E = Variadic->getNumOperands(); I != E; ++I, ++GTI) { - if (GTI.isSequential()) { - Value *Idx = Variadic->getOperand(I); - // Skip zero indices. - if (ConstantInt *CI = dyn_cast<ConstantInt>(Idx)) - if (CI->isZero()) - continue; - - APInt ElementSize = APInt(IntPtrTy->getIntegerBitWidth(), - GTI.getSequentialElementStride(*DL)); - // Scale the index by element size. - if (ElementSize != 1) { - if (ElementSize.isPowerOf2()) { - Idx = Builder.CreateShl( - Idx, ConstantInt::get(IntPtrTy, ElementSize.logBase2())); - } else { - Idx = Builder.CreateMul(Idx, ConstantInt::get(IntPtrTy, ElementSize)); - } - } - // Create an ADD for each index. - ResultPtr = Builder.CreateAdd(ResultPtr, Idx); - } - } - - // Create an ADD for the constant offset index. - if (AccumulativeByteOffset != 0) { - ResultPtr = Builder.CreateAdd( - ResultPtr, ConstantInt::get(IntPtrTy, AccumulativeByteOffset)); - } - - ResultPtr = Builder.CreateIntToPtr(ResultPtr, Variadic->getType()); - Variadic->replaceAllUsesWith(ResultPtr); - Variadic->eraseFromParent(); -} - bool SeparateConstOffsetFromGEP::reorderGEP(GetElementPtrInst *GEP, TargetTransformInfo &TTI) { auto PtrGEP = dyn_cast<GetElementPtrInst>(GEP->getPointerOperand()); @@ -1091,8 +1005,8 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) { // Notice that we don't remove struct field indices here. If LowerGEP is // disabled, a structure index is not accumulated and we still use the old // one. If LowerGEP is enabled, a structure index is accumulated in the - // constant offset. LowerToSingleIndexGEPs or lowerToArithmetics will later - // handle the constant offset and won't need a new structure index. + // constant offset. LowerToSingleIndexGEPs will later handle the constant + // offset and won't need a new structure index. gep_type_iterator GTI = gep_type_begin(*GEP); for (unsigned I = 1, E = GEP->getNumOperands(); I != E; ++I, ++GTI) { if (GTI.isSequential()) { @@ -1167,22 +1081,9 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) { GEP->setNoWrapFlags(NewGEPFlags); - // Lowers a GEP to either GEPs with a single index or arithmetic operations. + // Lowers a GEP to GEPs with a single index. if (LowerGEP) { - // As currently BasicAA does not analyze ptrtoint/inttoptr, do not lower to - // arithmetic operations if the target uses alias analysis in codegen. - // Additionally, pointers that aren't integral (and so can't be safely - // converted to integers) or those whose offset size is different from their - // pointer size (which means that doing integer arithmetic on them could - // affect that data) can't be lowered in this way. 
-    unsigned AddrSpace = GEP->getPointerAddressSpace();
-    bool PointerHasExtraData = DL->getPointerSizeInBits(AddrSpace) !=
-                               DL->getIndexSizeInBits(AddrSpace);
-    if (TTI.useAA() || DL->isNonIntegralAddressSpace(AddrSpace) ||
-        PointerHasExtraData)
-      lowerToSingleIndexGEPs(GEP, AccumulativeByteOffset);
-    else
-      lowerToArithmetics(GEP, AccumulativeByteOffset);
+    lowerToSingleIndexGEPs(GEP, AccumulativeByteOffset);
     return true;
   }
diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
index 571fa11..1eb8996 100644
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -1249,7 +1249,8 @@ Value *SCEVExpander::tryToReuseLCSSAPhi(const SCEVAddRecExpr *S) {
     // offset, if the offset is simpler.
     const SCEV *Diff = SE.getMinusSCEV(S, ExitSCEV);
     const SCEV *Op = Diff;
-    match(Diff, m_scev_Mul(m_scev_AllOnes(), m_SCEV(Op)));
+    match(Op, m_scev_Add(m_SCEVConstant(), m_SCEV(Op)));
+    match(Op, m_scev_Mul(m_scev_AllOnes(), m_SCEV(Op)));
     match(Op, m_scev_PtrToInt(m_SCEV(Op)));
     if (!isa<SCEVConstant, SCEVUnknown>(Op))
       continue;
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 94b0ab8..674de57 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -198,6 +198,11 @@ static cl::opt<unsigned> MaxSwitchCasesPerResult(
     "max-switch-cases-per-result", cl::Hidden, cl::init(16),
     cl::desc("Limit cases to analyze when converting a switch to select"));
 
+static cl::opt<unsigned> MaxJumpThreadingLiveBlocks(
+    "max-jump-threading-live-blocks", cl::Hidden, cl::init(24),
+    cl::desc("Limit number of blocks a define in a threaded block is allowed "
+             "to be live in"));
+
 STATISTIC(NumBitMaps, "Number of switch instructions turned into bitmaps");
 STATISTIC(NumLinearMaps,
           "Number of switch instructions turned into linear mapping");
@@ -3390,8 +3395,27 @@ bool SimplifyCFGOpt::speculativelyExecuteBB(BranchInst *BI,
   return true;
 }
 
+using BlocksSet = SmallPtrSet<BasicBlock *, 8>;
+
+// Return false if the number of blocks searched exceeds the limit.
+static bool findReaching(BasicBlock *BB, BasicBlock *DefBB,
+                         BlocksSet &ReachesNonLocalUses) {
+  if (BB == DefBB)
+    return true;
+  if (!ReachesNonLocalUses.insert(BB).second)
+    return true;
+
+  if (ReachesNonLocalUses.size() > MaxJumpThreadingLiveBlocks)
+    return false;
+  for (BasicBlock *Pred : predecessors(BB))
+    if (!findReaching(Pred, DefBB, ReachesNonLocalUses))
+      return false;
+  return true;
+}
+
 /// Return true if we can thread a branch across this block.
-static bool blockIsSimpleEnoughToThreadThrough(BasicBlock *BB) {
+static bool blockIsSimpleEnoughToThreadThrough(BasicBlock *BB,
+                                               BlocksSet &NonLocalUseBlocks) {
   int Size = 0;
   EphemeralValueTracker EphTracker;
 
@@ -3411,12 +3435,16 @@ static bool blockIsSimpleEnoughToThreadThrough(BasicBlock *BB) {
       return false; // Don't clone large BB's.
   }
 
-    // We can only support instructions that do not define values that are
-    // live outside of the current basic block.
+    // Record blocks with non-local uses of values defined in the current basic
+    // block.
    for (User *U : I.users()) {
      Instruction *UI = cast<Instruction>(U);
-      if (UI->getParent() != BB || isa<PHINode>(UI))
-        return false;
+      BasicBlock *UsedInBB = UI->getParent();
+      if (UsedInBB == BB) {
+        if (isa<PHINode>(UI))
+          return false;
+      } else
+        NonLocalUseBlocks.insert(UsedInBB);
    }
 
    // Looks ok, continue checking.
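To illustrate the restriction being set up here and completed in the next hunk (hand-written IR, names invented): a value defined in the candidate block may be live in one successor but not the other, and edges may only be revectored toward destinations that no such value can reach.

  bb:                                  ; candidate block to thread across
    %v = add i32 %x, 1
    br i1 %c, label %left, label %right
  left:                                ; %v is live here, so edges must not
    call void @use(i32 %v)             ; be revectored into %left
    br label %join
  right:                               ; no use of %v is reachable from here;
    br label %join                     ; threading an edge to %right is fine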
@@ -3475,18 +3503,37 @@ foldCondBranchOnValueKnownInPredecessorImpl(BranchInst *BI, DomTreeUpdater *DTU, return false; // Now we know that this block has multiple preds and two succs. - // Check that the block is small enough and values defined in the block are - // not used outside of it. - if (!blockIsSimpleEnoughToThreadThrough(BB)) + // Check that the block is small enough and record which non-local blocks use + // values defined in the block. + + BlocksSet NonLocalUseBlocks; + BlocksSet ReachesNonLocalUseBlocks; + if (!blockIsSimpleEnoughToThreadThrough(BB, NonLocalUseBlocks)) return false; + // Jump-threading can only be done to destinations where no values defined + // in BB are live. + + // Quickly check if both destinations have uses. If so, jump-threading cannot + // be done. + if (NonLocalUseBlocks.contains(BI->getSuccessor(0)) && + NonLocalUseBlocks.contains(BI->getSuccessor(1))) + return false; + + // Search backward from NonLocalUseBlocks to find which blocks + // reach non-local uses. + for (BasicBlock *UseBB : NonLocalUseBlocks) + // Give up if too many blocks are searched. + if (!findReaching(UseBB, BB, ReachesNonLocalUseBlocks)) + return false; + for (const auto &Pair : KnownValues) { - // Okay, we now know that all edges from PredBB should be revectored to - // branch to RealDest. ConstantInt *CB = Pair.first; ArrayRef<BasicBlock *> PredBBs = Pair.second.getArrayRef(); BasicBlock *RealDest = BI->getSuccessor(!CB->getZExtValue()); + // Okay, we now know that all edges from PredBB should be revectored to + // branch to RealDest. if (RealDest == BB) continue; // Skip self loops. @@ -3496,6 +3543,10 @@ foldCondBranchOnValueKnownInPredecessorImpl(BranchInst *BI, DomTreeUpdater *DTU, })) continue; + // Only revector to RealDest if no values defined in BB are live. + if (ReachesNonLocalUseBlocks.contains(RealDest)) + continue; + LLVM_DEBUG({ dbgs() << "Condition " << *Cond << " in " << BB->getName() << " has value " << *Pair.first << " in predecessors:\n"; diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp index 969d225..c47fd942 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp @@ -1665,13 +1665,12 @@ bool LoopVectorizationLegality::isVectorizableEarlyExitLoop() { // Keep a record of all the exiting blocks. 
   SmallVector<const SCEVPredicate *, 4> Predicates;
-  std::optional<std::pair<BasicBlock *, BasicBlock *>> SingleUncountableEdge;
+  BasicBlock *SingleUncountableExitingBlock = nullptr;
   for (BasicBlock *BB : ExitingBlocks) {
     const SCEV *EC =
         PSE.getSE()->getPredicatedExitCount(TheLoop, BB, &Predicates);
     if (isa<SCEVCouldNotCompute>(EC)) {
-      SmallVector<BasicBlock *, 2> Succs(successors(BB));
-      if (Succs.size() != 2) {
+      if (size(successors(BB)) != 2) {
         reportVectorizationFailure(
             "Early exiting block does not have exactly two successors",
             "Incorrect number of successors from early exiting block",
@@ -1679,15 +1678,7 @@ bool LoopVectorizationLegality::isVectorizableEarlyExitLoop() {
         return false;
       }
 
-      BasicBlock *ExitBlock;
-      if (!TheLoop->contains(Succs[0]))
-        ExitBlock = Succs[0];
-      else {
-        assert(!TheLoop->contains(Succs[1]));
-        ExitBlock = Succs[1];
-      }
-
-      if (SingleUncountableEdge) {
+      if (SingleUncountableExitingBlock) {
         reportVectorizationFailure(
             "Loop has too many uncountable exits",
             "Cannot vectorize early exit loop with more than one early exit",
@@ -1695,7 +1686,7 @@ bool LoopVectorizationLegality::isVectorizableEarlyExitLoop() {
         return false;
       }
 
-      SingleUncountableEdge = {BB, ExitBlock};
+      SingleUncountableExitingBlock = BB;
     } else
       CountableExitingBlocks.push_back(BB);
   }
@@ -1705,7 +1696,7 @@ bool LoopVectorizationLegality::isVectorizableEarlyExitLoop() {
   // PSE.getSymbolicMaxBackedgeTakenCount() below.
   Predicates.clear();
 
-  if (!SingleUncountableEdge) {
+  if (!SingleUncountableExitingBlock) {
     LLVM_DEBUG(dbgs() << "LV: Could not find any uncountable exits");
     return false;
   }
@@ -1713,7 +1704,7 @@ bool LoopVectorizationLegality::isVectorizableEarlyExitLoop() {
   // The only supported early exit loops so far are ones where the early
   // exiting block is a unique predecessor of the latch block.
   BasicBlock *LatchPredBB = LatchBB->getUniquePredecessor();
-  if (LatchPredBB != SingleUncountableEdge->first) {
+  if (LatchPredBB != SingleUncountableExitingBlock) {
     reportVectorizationFailure("Early exit is not the latch predecessor",
                                "Cannot vectorize early exit loop",
                                "EarlyExitNotLatchPredecessor", ORE, TheLoop);
@@ -1766,7 +1757,7 @@ bool LoopVectorizationLegality::isVectorizableEarlyExitLoop() {
   }
 
   // The vectoriser cannot handle loads that occur after the early exit block.
-  assert(LatchBB->getUniquePredecessor() == SingleUncountableEdge->first &&
+  assert(LatchBB->getUniquePredecessor() == SingleUncountableExitingBlock &&
          "Expected latch predecessor to be the early exiting block");
 
   // TODO: Handle loops that may fault.
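For orientation, a minimal loop of the shape this legality check accepts (hand-written IR, all names illustrative): one countable exit from the latch, plus a single uncountable exit taken from the block that is the latch's unique predecessor.

  loop.header:                               ; the single uncountable exiting block
    %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop.latch ]
    %gep = getelementptr inbounds i8, ptr %p, i64 %iv
    %ld = load i8, ptr %gep
    %found = icmp eq i8 %ld, 42
    br i1 %found, label %exit.early, label %loop.latch  ; exit count unknown to SCEV
  loop.latch:
    %iv.next = add nuw nsw i64 %iv, 1
    %done = icmp eq i64 %iv.next, %n
    br i1 %done, label %exit, label %loop.header        ; countable exit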
@@ -1789,7 +1780,7 @@ bool LoopVectorizationLegality::isVectorizableEarlyExitLoop() { LLVM_DEBUG(dbgs() << "LV: Found an early exit loop with symbolic max " "backedge taken count: " << *SymbolicMaxBTC << '\n'); - UncountableEdge = SingleUncountableEdge; + UncountableExitingBB = SingleUncountableExitingBlock; return true; } @@ -1861,7 +1852,8 @@ bool LoopVectorizationLegality::canVectorize(bool UseVPlanNativePath) { return false; } else { if (!isVectorizableEarlyExitLoop()) { - UncountableEdge = std::nullopt; + assert(!hasUncountableEarlyExit() && + "Must be false without vectorizable early-exit loop"); if (DoExtraAnalysis) Result = false; else diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h index f57ce0c..ea0fa06 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h +++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h @@ -170,8 +170,7 @@ public: new VPInstruction(Opcode, Operands, Flags, DL, Name)); } - VPInstruction *createNaryOp(unsigned Opcode, - std::initializer_list<VPValue *> Operands, + VPInstruction *createNaryOp(unsigned Opcode, ArrayRef<VPValue *> Operands, Type *ResultTy, const VPIRFlags &Flags = {}, DebugLoc DL = DebugLoc::getUnknown(), const Twine &Name = "") { @@ -180,7 +179,7 @@ public: } VPInstruction *createOverflowingOp(unsigned Opcode, - std::initializer_list<VPValue *> Operands, + ArrayRef<VPValue *> Operands, VPRecipeWithIRFlags::WrapFlagsTy WrapFlags, DebugLoc DL = DebugLoc::getUnknown(), const Twine &Name = "") { diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index fe93fcd..b4ea70e 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -1624,7 +1624,7 @@ private: /// presence of a cost for an instruction in the mapping indicates that the /// instruction will be scalarized when vectorizing with the associated /// vectorization factor. The entries are VF-ScalarCostTy pairs. - DenseMap<ElementCount, ScalarCostsTy> InstsToScalarize; + MapVector<ElementCount, ScalarCostsTy> InstsToScalarize; /// Holds the instructions known to be uniform after vectorization. /// The data is collected per VF. @@ -3150,7 +3150,7 @@ bool LoopVectorizationCostModel::interleavedAccessCanBeWidened( isa<LoadInst>(I) && Group->requiresScalarEpilogue() && !isScalarEpilogueAllowed(); bool StoreAccessWithGapsRequiresMasking = - isa<StoreInst>(I) && (Group->getNumMembers() < Group->getFactor()); + isa<StoreInst>(I) && !Group->isFull(); if (!PredicatedAccessRequiresMasking && !LoadAccessWithGapsRequiresEpilogMasking && !StoreAccessWithGapsRequiresMasking) @@ -5372,7 +5372,7 @@ LoopVectorizationCostModel::getInterleaveGroupCost(Instruction *I, // Calculate the cost of the whole interleaved group. 
bool UseMaskForGaps = (Group->requiresScalarEpilogue() && !isScalarEpilogueAllowed()) || - (isa<StoreInst>(I) && (Group->getNumMembers() < Group->getFactor())); + (isa<StoreInst>(I) && !Group->isFull()); InstructionCost Cost = TTI.getInterleavedMemoryOpCost( InsertPos->getOpcode(), WideVecTy, Group->getFactor(), Indices, Group->getAlign(), AS, CostKind, Legal->isMaskRequired(I), @@ -9788,6 +9788,10 @@ preparePlanForEpilogueVectorLoop(VPlan &Plan, Loop *L, match( P.getIncomingValueForBlock(EPI.MainLoopIterationCountCheck), m_SpecificInt(0)) && + any_of(P.incoming_values(), + [&EPI](Value *Inc) { + return Inc == EPI.VectorTripCount; + }) && all_of(P.incoming_values(), [&EPI](Value *Inc) { return Inc == EPI.VectorTripCount || match(Inc, m_SpecificInt(0)); diff --git a/llvm/lib/Transforms/Vectorize/VPlanCFG.h b/llvm/lib/Transforms/Vectorize/VPlanCFG.h index b77aa9d..c79485c 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanCFG.h +++ b/llvm/lib/Transforms/Vectorize/VPlanCFG.h @@ -231,6 +231,13 @@ vp_post_order_shallow(VPBlockBase *G) { } /// Returns an iterator range to traverse the graph starting at \p G in +/// post order while traversing through region blocks. +inline iterator_range<po_iterator<VPBlockDeepTraversalWrapper<VPBlockBase *>>> +vp_post_order_deep(VPBlockBase *G) { + return post_order(VPBlockDeepTraversalWrapper<VPBlockBase *>(G)); +} + +/// Returns an iterator range to traverse the graph starting at \p G in /// depth-first order while traversing through region blocks. inline iterator_range<df_iterator<VPBlockDeepTraversalWrapper<VPBlockBase *>>> vp_depth_first_deep(VPBlockBase *G) { diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index 68e7c20..98d11f0 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -2530,8 +2530,8 @@ void VPReductionRecipe::execute(VPTransformState &State) { NextInChain = createMinMaxOp(State.Builder, Kind, NewRed, PrevInChain); else NextInChain = State.Builder.CreateBinOp( - (Instruction::BinaryOps)RecurrenceDescriptor::getOpcode(Kind), NewRed, - PrevInChain); + (Instruction::BinaryOps)RecurrenceDescriptor::getOpcode(Kind), + PrevInChain, NewRed); } State.set(this, NextInChain, /*IsScalar*/ true); } @@ -2836,12 +2836,12 @@ static void scalarizeInstruction(const Instruction *Instr, Instruction *Cloned = Instr->clone(); if (!IsVoidRetTy) { Cloned->setName(Instr->getName() + ".cloned"); -#if !defined(NDEBUG) - // Verify that VPlan type inference results agree with the type of the - // generated values. - assert(State.TypeAnalysis.inferScalarType(RepRecipe) == Cloned->getType() && - "inferred type and type from generated instructions do not match"); -#endif + Type *ResultTy = State.TypeAnalysis.inferScalarType(RepRecipe); + // The operands of the replicate recipe may have been narrowed, resulting in + // a narrower result type. Update the type of the cloned instruction to the + // correct type. + if (ResultTy != Cloned->getType()) + Cloned->mutateType(ResultTy); } RepRecipe->applyFlags(*Cloned); @@ -3548,6 +3548,8 @@ void VPInterleaveRecipe::execute(VPTransformState &State) { // Vectorize the interleaved store group. 
Value *MaskForGaps = createBitMaskForGaps(State.Builder, State.VF.getKnownMinValue(), *Group); + assert(((MaskForGaps != nullptr) == NeedsMaskForGaps) && + "Mismatch between NeedsMaskForGaps and MaskForGaps"); assert((!MaskForGaps || !State.VF.isScalable()) && "masking gaps for scalable vectors is not yet supported."); ArrayRef<VPValue *> StoredValues = getStoredValues(); diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp index 47a9ff0..373b2f7 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -32,11 +32,11 @@ #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/MDBuilder.h" -#include "llvm/IR/PatternMatch.h" #include "llvm/Support/Casting.h" #include "llvm/Support/TypeSize.h" using namespace llvm; +using namespace VPlanPatternMatch; bool VPlanTransforms::tryToConvertVPInstructionsToVPRecipes( VPlanPtr &Plan, @@ -528,13 +528,11 @@ static void removeRedundantCanonicalIVs(VPlan &Plan) { /// Returns true if \p R is dead and can be removed. static bool isDeadRecipe(VPRecipeBase &R) { - using namespace llvm::PatternMatch; // Do remove conditional assume instructions as their conditions may be // flattened. auto *RepR = dyn_cast<VPReplicateRecipe>(&R); - bool IsConditionalAssume = - RepR && RepR->isPredicated() && - match(RepR->getUnderlyingInstr(), m_Intrinsic<Intrinsic::assume>()); + bool IsConditionalAssume = RepR && RepR->isPredicated() && + match(RepR, m_Intrinsic<Intrinsic::assume>()); if (IsConditionalAssume) return true; @@ -547,10 +545,8 @@ static bool isDeadRecipe(VPRecipeBase &R) { } void VPlanTransforms::removeDeadRecipes(VPlan &Plan) { - ReversePostOrderTraversal<VPBlockDeepTraversalWrapper<VPBlockBase *>> RPOT( - Plan.getEntry()); - - for (VPBasicBlock *VPBB : reverse(VPBlockUtils::blocksOnly<VPBasicBlock>(RPOT))) { + for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>( + vp_post_order_deep(Plan.getEntry()))) { // The recipes in the block are processed in reverse order, to catch chains // of dead recipes. for (VPRecipeBase &R : make_early_inc_range(reverse(*VPBB))) { @@ -625,7 +621,6 @@ static SmallVector<VPUser *> collectUsersRecursively(VPValue *V) { /// original IV's users. This is an optional optimization to reduce the needs of /// vector extracts. static void legalizeAndOptimizeInductions(VPlan &Plan) { - using namespace llvm::VPlanPatternMatch; VPBasicBlock *HeaderVPBB = Plan.getVectorLoopRegion()->getEntryBasicBlock(); bool HasOnlyVectorVFs = !Plan.hasScalarVFOnly(); VPBuilder Builder(HeaderVPBB, HeaderVPBB->getFirstNonPhi()); @@ -727,7 +722,6 @@ static VPWidenInductionRecipe *getOptimizableIVOf(VPValue *VPV) { return nullptr; auto IsWideIVInc = [&]() { - using namespace VPlanPatternMatch; auto &ID = WideIV->getInductionDescriptor(); // Check if VPV increments the induction by the induction step. 
@@ -771,8 +765,6 @@ static VPValue *optimizeEarlyExitInductionUser(VPlan &Plan, VPTypeAnalysis &TypeInfo, VPBlockBase *PredVPBB, VPValue *Op) { - using namespace VPlanPatternMatch; - VPValue *Incoming, *Mask; if (!match(Op, m_VPInstruction<VPInstruction::ExtractLane>( m_VPInstruction<VPInstruction::FirstActiveLane>( @@ -827,8 +819,6 @@ static VPValue * optimizeLatchExitInductionUser(VPlan &Plan, VPTypeAnalysis &TypeInfo, VPBlockBase *PredVPBB, VPValue *Op, DenseMap<VPValue *, VPValue *> &EndValues) { - using namespace VPlanPatternMatch; - VPValue *Incoming; if (!match(Op, m_VPInstruction<VPInstruction::ExtractLastElement>( m_VPValue(Incoming)))) @@ -986,7 +976,6 @@ static Value *tryToFoldLiveIns(const VPRecipeBase &R, unsigned Opcode, /// Try to simplify recipe \p R. static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) { - using namespace llvm::VPlanPatternMatch; VPlan *Plan = R.getParent()->getPlan(); auto *Def = dyn_cast<VPSingleDefRecipe>(&R); @@ -1269,7 +1258,6 @@ static void narrowToSingleScalarRecipes(VPlan &Plan) { /// Normalize and simplify VPBlendRecipes. Should be run after simplifyRecipes /// to make sure the masks are simplified. static void simplifyBlends(VPlan &Plan) { - using namespace llvm::VPlanPatternMatch; for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>( vp_depth_first_shallow(Plan.getVectorLoopRegion()->getEntry()))) { for (VPRecipeBase &R : make_early_inc_range(*VPBB)) { @@ -1393,7 +1381,6 @@ static bool optimizeVectorInductionWidthForTCAndVFUF(VPlan &Plan, // Currently only handle cases where the single user is a header-mask // comparison with the backedge-taken-count. - using namespace VPlanPatternMatch; if (!match( *WideIV->user_begin(), m_Binary<Instruction::ICmp>( @@ -1424,8 +1411,7 @@ static bool optimizeVectorInductionWidthForTCAndVFUF(VPlan &Plan, static bool isConditionTrueViaVFAndUF(VPValue *Cond, VPlan &Plan, ElementCount BestVF, unsigned BestUF, ScalarEvolution &SE) { - using namespace llvm::VPlanPatternMatch; - if (match(Cond, m_Binary<Instruction::Or>(m_VPValue(), m_VPValue()))) + if (match(Cond, m_BinaryOr(m_VPValue(), m_VPValue()))) return any_of(Cond->getDefiningRecipe()->operands(), [&Plan, BestVF, BestUF, &SE](VPValue *C) { return isConditionTrueViaVFAndUF(C, Plan, BestVF, BestUF, SE); @@ -1443,15 +1429,15 @@ static bool isConditionTrueViaVFAndUF(VPValue *Cond, VPlan &Plan, // count is not conveniently available as SCEV so far, so we compare directly // against the original trip count. This is stricter than necessary, as we // will only return true if the trip count == vector trip count. - // TODO: Use SCEV for vector trip count once available, to cover cases where - // vector trip count == UF * VF, but original trip count != UF * VF. 
- const SCEV *TripCount = - vputils::getSCEVExprForVPValue(Plan.getTripCount(), SE); - assert(!isa<SCEVCouldNotCompute>(TripCount) && + const SCEV *VectorTripCount = + vputils::getSCEVExprForVPValue(&Plan.getVectorTripCount(), SE); + if (isa<SCEVCouldNotCompute>(VectorTripCount)) + VectorTripCount = vputils::getSCEVExprForVPValue(Plan.getTripCount(), SE); + assert(!isa<SCEVCouldNotCompute>(VectorTripCount) && "Trip count SCEV must be computable"); ElementCount NumElements = BestVF.multiplyCoefficientBy(BestUF); - const SCEV *C = SE.getElementCount(TripCount->getType(), NumElements); - return SE.isKnownPredicate(CmpInst::ICMP_EQ, TripCount, C); + const SCEV *C = SE.getElementCount(VectorTripCount->getType(), NumElements); + return SE.isKnownPredicate(CmpInst::ICMP_EQ, VectorTripCount, C); } /// Try to simplify the branch condition of \p Plan. This may restrict the @@ -1464,7 +1450,6 @@ static bool simplifyBranchConditionForVFAndUF(VPlan &Plan, ElementCount BestVF, auto *Term = &ExitingVPBB->back(); VPValue *Cond; ScalarEvolution &SE = *PSE.getSE(); - using namespace llvm::VPlanPatternMatch; if (match(Term, m_BranchOnCount(m_VPValue(), m_VPValue())) || match(Term, m_BranchOnCond( m_Not(m_ActiveLaneMask(m_VPValue(), m_VPValue()))))) { @@ -1496,11 +1481,11 @@ static bool simplifyBranchConditionForVFAndUF(VPlan &Plan, ElementCount BestVF, auto *CanIVTy = Plan.getCanonicalIV()->getScalarType(); if (all_of(Header->phis(), IsaPred<VPCanonicalIVPHIRecipe, VPEVLBasedIVPHIRecipe, - VPFirstOrderRecurrencePHIRecipe>)) { + VPFirstOrderRecurrencePHIRecipe, VPPhi>)) { for (VPRecipeBase &HeaderR : make_early_inc_range(Header->phis())) { - auto *HeaderPhiR = cast<VPHeaderPHIRecipe>(&HeaderR); - HeaderPhiR->replaceAllUsesWith(HeaderPhiR->getStartValue()); - HeaderPhiR->eraseFromParent(); + auto *Phi = cast<VPPhiAccessors>(&HeaderR); + HeaderR.getVPSingleValue()->replaceAllUsesWith(Phi->getIncomingValue(0)); + HeaderR.eraseFromParent(); } VPBlockBase *Preheader = VectorRegion->getSinglePredecessor(); @@ -1847,7 +1832,6 @@ void VPlanTransforms::truncateToMinimalBitwidths( if (auto *VPW = dyn_cast<VPRecipeWithIRFlags>(&R)) VPW->dropPoisonGeneratingFlags(); - using namespace llvm::VPlanPatternMatch; if (OldResSizeInBits != NewResSizeInBits && !match(&R, m_Binary<Instruction::ICmp>(m_VPValue(), m_VPValue()))) { // Extend result to original width. @@ -1897,7 +1881,6 @@ void VPlanTransforms::truncateToMinimalBitwidths( } void VPlanTransforms::removeBranchOnConst(VPlan &Plan) { - using namespace llvm::VPlanPatternMatch; for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>( vp_depth_first_shallow(Plan.getEntry()))) { VPValue *Cond; @@ -2143,7 +2126,6 @@ static VPRecipeBase *optimizeMaskToEVL(VPValue *HeaderMask, VPRecipeBase &CurRecipe, VPTypeAnalysis &TypeInfo, VPValue &AllOneMask, VPValue &EVL) { - using namespace llvm::VPlanPatternMatch; auto GetNewMask = [&](VPValue *OrigMask) -> VPValue * { assert(OrigMask && "Unmasked recipe when folding tail"); // HeaderMask will be handled using EVL. @@ -2223,7 +2205,6 @@ static void transformRecipestoEVLRecipes(VPlan &Plan, VPValue &EVL) { for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>( vp_depth_first_deep(Plan.getVectorLoopRegion()->getEntry()))) { for (VPRecipeBase &R : *VPBB) { - using namespace VPlanPatternMatch; VPValue *V1, *V2; if (!match(&R, m_VPInstruction<VPInstruction::FirstOrderRecurrenceSplice>( @@ -2309,10 +2290,12 @@ static void transformRecipestoEVLRecipes(VPlan &Plan, VPValue &EVL) { /// ... 
/// %EVLPhi = EXPLICIT-VECTOR-LENGTH-BASED-IV-PHI [ %StartV, %vector.ph ], /// [ %NextEVLIV, %vector.body ] -/// %AVL = sub original TC, %EVLPhi +/// %AVL = phi [ trip-count, %vector.ph ], [ %NextAVL, %vector.body ] /// %VPEVL = EXPLICIT-VECTOR-LENGTH %AVL /// ... -/// %NextEVLIV = add IVSize (cast i32 %VPEVVL to IVSize), %EVLPhi +/// %OpEVL = cast i32 %VPEVL to IVSize +/// %NextEVLIV = add IVSize %OpEVL, %EVLPhi +/// %NextAVL = sub IVSize nuw %AVL, %OpEVL /// ... /// /// If MaxSafeElements is provided, the function adds the following recipes: @@ -2323,12 +2306,14 @@ static void transformRecipestoEVLRecipes(VPlan &Plan, VPValue &EVL) { /// ... /// %EVLPhi = EXPLICIT-VECTOR-LENGTH-BASED-IV-PHI [ %StartV, %vector.ph ], /// [ %NextEVLIV, %vector.body ] -/// %AVL = sub original TC, %EVLPhi +/// %AVL = phi [ trip-count, %vector.ph ], [ %NextAVL, %vector.body ] /// %cmp = cmp ult %AVL, MaxSafeElements /// %SAFE_AVL = select %cmp, %AVL, MaxSafeElements /// %VPEVL = EXPLICIT-VECTOR-LENGTH %SAFE_AVL /// ... -/// %NextEVLIV = add IVSize (cast i32 %VPEVL to IVSize), %EVLPhi +/// %OpEVL = cast i32 %VPEVL to IVSize +/// %NextEVLIV = add IVSize %OpEVL, %EVLPhi +/// %NextAVL = sub IVSize nuw %AVL, %OpEVL /// ... /// bool VPlanTransforms::tryAddExplicitVectorLength( @@ -2350,9 +2335,12 @@ bool VPlanTransforms::tryAddExplicitVectorLength( auto *EVLPhi = new VPEVLBasedIVPHIRecipe(StartV, DebugLoc()); EVLPhi->insertAfter(CanonicalIVPHI); VPBuilder Builder(Header, Header->getFirstNonPhi()); - // Compute original TC - IV as the AVL (application vector length). - VPValue *AVL = Builder.createNaryOp( - Instruction::Sub, {Plan.getTripCount(), EVLPhi}, DebugLoc(), "avl"); + // Create the AVL (application vector length), starting from TC -> 0 in steps + // of EVL. + VPPhi *AVLPhi = Builder.createScalarPhi( + {Plan.getTripCount()}, DebugLoc::getCompilerGenerated(), "avl"); + VPValue *AVL = AVLPhi; + if (MaxSafeElements) { // Support for MaxSafeDist for correct loop emission. VPValue *AVLSafe = @@ -2379,6 +2367,11 @@ bool VPlanTransforms::tryAddExplicitVectorLength( CanonicalIVIncrement->getDebugLoc(), "index.evl.next"); EVLPhi->addOperand(NextEVLIV); + VPValue *NextAVL = Builder.createOverflowingOp( + Instruction::Sub, {AVLPhi, OpVPEVL}, {/*hasNUW=*/true, /*hasNSW=*/false}, + DebugLoc::getCompilerGenerated(), "avl.next"); + AVLPhi->addOperand(NextAVL); + transformRecipestoEVLRecipes(Plan, *VPEVL); // Replace all uses of VPCanonicalIVPHIRecipe by @@ -2391,7 +2384,6 @@ bool VPlanTransforms::tryAddExplicitVectorLength( } void VPlanTransforms::canonicalizeEVLLoops(VPlan &Plan) { - using namespace llvm::VPlanPatternMatch; // Find EVL loop entries by locating VPEVLBasedIVPHIRecipe. // There should be only one EVL PHI in the entire plan. VPEVLBasedIVPHIRecipe *EVLPhi = nullptr; @@ -2480,7 +2472,6 @@ void VPlanTransforms::dropPoisonGeneratingRecipes( // drop them directly. if (auto *RecWithFlags = dyn_cast<VPRecipeWithIRFlags>(CurRec)) { VPValue *A, *B; - using namespace llvm::VPlanPatternMatch; // Dropping disjoint from an OR may yield incorrect results, as some // analysis may have converted it to an Add implicitly (e.g. SCEV used // for dependence analysis). Instead, replace it with an equivalent Add. 
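The comment above relies on the identity that a disjoint OR is carry-free addition. A minimal standalone C++ sketch of that equivalence (illustrative only, not LLVM code; the function name is made up):

#include <cassert>
#include <cstdint>

// If no bit is set in both A and B, OR and ADD agree bit-for-bit, because
// no carries can be generated; this is what lets the transform swap the
// disjoint OR for an equivalent Add instead of merely dropping the flag.
uint32_t disjointOrAsAdd(uint32_t A, uint32_t B) {
  assert((A & B) == 0 && "operands must share no set bits");
  uint32_t Or = A | B;
  uint32_t Add = A + B;
  assert(Or == Add);
  return Add;
}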
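For the tryAddExplicitVectorLength change documented earlier in this hunk, the rewritten induction structure corresponds to the scalar loop shape below. This is a sketch under assumed names: evlLoopShape and VLMax are illustrative stand-ins, not LLVM APIs; the comments map each variable back to the VPlan recipes from the doc comment.

#include <algorithm>
#include <cstdint>

void evlLoopShape(uint64_t TripCount, uint64_t VLMax) {
  uint64_t IV = 0;          // %EVLPhi: elements already processed
  uint64_t AVL = TripCount; // %AVL phi: elements remaining, counts down
  while (AVL > 0) {
    // EXPLICIT-VECTOR-LENGTH %AVL: never request more lanes than remain.
    uint64_t EVL = std::min(AVL, VLMax);
    // ... vector body processes EVL elements starting at index IV ...
    IV += EVL;  // %NextEVLIV = add %OpEVL, %EVLPhi
    AVL -= EVL; // %NextAVL = sub nuw %AVL, %OpEVL (cannot underflow)
  }
}

With MaxSafeElements, EVL would additionally be clamped by the cmp/select shown in the second comment block before being used.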
@@ -2570,7 +2561,8 @@ void VPlanTransforms::createInterleaveGroups(
     }

     bool NeedsMaskForGaps =
-        IG->requiresScalarEpilogue() && !ScalarEpilogueAllowed;
+        (IG->requiresScalarEpilogue() && !ScalarEpilogueAllowed) ||
+        (!StoredValues.empty() && !IG->isFull());

     Instruction *IRInsertPos = IG->getInsertPos();
     auto *InsertPos =
@@ -2774,8 +2766,6 @@ void VPlanTransforms::dissolveLoopRegions(VPlan &Plan) {

 void VPlanTransforms::convertToConcreteRecipes(VPlan &Plan,
                                                Type &CanonicalIVTy) {
-  using namespace llvm::VPlanPatternMatch;
-
   VPTypeAnalysis TypeInfo(&CanonicalIVTy);
   SmallVector<VPRecipeBase *> ToRemove;
   for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
@@ -2852,8 +2842,6 @@ void VPlanTransforms::convertToConcreteRecipes(VPlan &Plan,
 void VPlanTransforms::handleUncountableEarlyExit(
     VPBasicBlock *EarlyExitingVPBB, VPBasicBlock *EarlyExitVPBB, VPlan &Plan,
     VPBasicBlock *HeaderVPBB, VPBasicBlock *LatchVPBB, VFRange &Range) {
-  using namespace llvm::VPlanPatternMatch;
-
   VPBlockBase *MiddleVPBB = LatchVPBB->getSuccessors()[0];
   if (!EarlyExitVPBB->getSinglePredecessor() &&
       EarlyExitVPBB->getPredecessors()[1] == MiddleVPBB) {
@@ -2947,8 +2935,6 @@ void VPlanTransforms::handleUncountableEarlyExit(
 static VPExpressionRecipe *
 tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx,
                                      VFRange &Range) {
-  using namespace VPlanPatternMatch;
-
   Type *RedTy = Ctx.Types.inferScalarType(Red);
   VPValue *VecOp = Red->getVecOp();
@@ -2994,8 +2980,6 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx,
 static VPExpressionRecipe *
 tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
                                           VPCostContext &Ctx, VFRange &Range) {
-  using namespace VPlanPatternMatch;
-
   unsigned Opcode = RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind());
   if (Opcode != Instruction::Add)
     return nullptr;
@@ -3209,9 +3193,7 @@ static bool canNarrowLoad(VPWidenRecipe *WideMember0, unsigned OpIdx,
     return !W->getMask() && WideMember0->getOperand(OpIdx) == OpV;

   if (auto *IR = dyn_cast<VPInterleaveRecipe>(DefR))
-    return IR->getInterleaveGroup()->getFactor() ==
-               IR->getInterleaveGroup()->getNumMembers() &&
-           IR->getVPValue(Idx) == OpV;
+    return IR->getInterleaveGroup()->isFull() && IR->getVPValue(Idx) == OpV;
   return false;
 }
@@ -3258,7 +3240,6 @@ static bool isAlreadyNarrow(VPValue *VPV) {

 void VPlanTransforms::narrowInterleaveGroups(VPlan &Plan, ElementCount VF,
                                              unsigned VectorRegWidth) {
-  using namespace llvm::VPlanPatternMatch;
   VPRegionBlock *VectorLoop = Plan.getVectorLoopRegion();
   if (VF.isScalable() || !VectorLoop)
     return;
@@ -3328,9 +3309,7 @@ void VPlanTransforms::narrowInterleaveGroups(VPlan &Plan, ElementCount VF,
           if (!DefR)
             return false;
           auto *IR = dyn_cast<VPInterleaveRecipe>(DefR);
-          return IR &&
-                 IR->getInterleaveGroup()->getFactor() ==
-                     IR->getInterleaveGroup()->getNumMembers() &&
+          return IR && IR->getInterleaveGroup()->isFull() &&
                  IR->getVPValue(Op.index()) == Op.value();
         })) {
       StoreGroups.push_back(InterleaveR);
diff --git a/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp b/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp
index 81bd21b..14f20c6 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp
@@ -73,8 +73,11 @@ bool vputils::isHeaderMask(const VPValue *V, VPlan &Plan) {
 }

 const SCEV *vputils::getSCEVExprForVPValue(VPValue *V, ScalarEvolution &SE) {
-  if (V->isLiveIn())
-    return SE.getSCEV(V->getLiveInIRValue());
+  if (V->isLiveIn()) {
+    if (Value *LiveIn = V->getLiveInIRValue())
+      return SE.getSCEV(LiveIn);
+    return SE.getCouldNotCompute();
+  }

   // TODO: Support constructing SCEVs for more recipes as needed.
   return TypeSwitch<const VPRecipeBase *, const SCEV *>(V->getDefiningRecipe())
diff --git a/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp b/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp
index 57d01cb..14ae4f2 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp
@@ -79,9 +79,8 @@ bool VPlanVerifier::verifyPhiRecipes(const VPBasicBlock *VPBB) {
     if (isa<VPActiveLaneMaskPHIRecipe>(RecipeI))
       NumActiveLaneMaskPhiRecipes++;

-    if (IsHeaderVPBB && !isa<VPHeaderPHIRecipe, VPWidenPHIRecipe>(*RecipeI) &&
-        !isa<VPInstruction>(*RecipeI) &&
-        cast<VPInstruction>(RecipeI)->getOpcode() == Instruction::PHI) {
+    if (IsHeaderVPBB &&
+        !isa<VPHeaderPHIRecipe, VPWidenPHIRecipe, VPPhi>(*RecipeI)) {
       errs() << "Found non-header PHI recipe in header VPBB";
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
       errs() << ": ";
@@ -173,7 +172,8 @@ bool VPlanVerifier::verifyEVLRecipe(const VPInstruction &EVL) const {
           [&](const VPInstructionWithType *S) { return VerifyEVLUse(*S, 0); })
           .Case<VPInstruction>([&](const VPInstruction *I) {
             if (I->getOpcode() == Instruction::PHI ||
-                I->getOpcode() == Instruction::ICmp)
+                I->getOpcode() == Instruction::ICmp ||
+                I->getOpcode() == Instruction::Sub)
               return VerifyEVLUse(*I, 1);
             switch (I->getOpcode()) {
             case Instruction::Add:
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 6252f4f..6345b18 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1664,6 +1664,8 @@ static Align computeAlignmentAfterScalarization(Align VectorAlignment,
 //   %1 = getelementptr inbounds i32, i32* %0, i64 0, i64 1
 //   store i32 %b, i32* %1
 bool VectorCombine::foldSingleElementStore(Instruction &I) {
+  if (!TTI.allowVectorElementIndexingUsingGEP())
+    return false;
   auto *SI = cast<StoreInst>(&I);
   if (!SI->isSimple() || !isa<VectorType>(SI->getValueOperand()->getType()))
     return false;
@@ -1719,6 +1721,9 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) {

 /// Try to scalarize vector loads feeding extractelement instructions.
 bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
+  if (!TTI.allowVectorElementIndexingUsingGEP())
+    return false;
+
   Value *Ptr;
   if (!match(&I, m_Load(m_Value(Ptr))))
     return false;
@@ -1827,6 +1832,8 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
 }

 bool VectorCombine::scalarizeExtExtract(Instruction &I) {
+  if (!TTI.allowVectorElementIndexingUsingGEP())
+    return false;
 auto *Ext = dyn_cast<ZExtInst>(&I);
   if (!Ext)
     return false;
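The three VectorCombine guards above share one requirement: rewriting whole-vector memory accesses into lane-addressed scalar accesses is only sound when the target lets a GEP index individual vector elements. A rough C++ analogy of the scalarizeLoadExtract rewrite follows; Vec4 and the function names are illustrative stand-ins, not LLVM code.

#include <cstdint>

struct Vec4 { int32_t Lane[4]; }; // stand-in for <4 x i32> in memory

int32_t beforeRewrite(const Vec4 *P) {
  Vec4 V = *P;      // load <4 x i32>: whole-vector load
  return V.Lane[2]; // extractelement %V, i64 2
}

int32_t afterRewrite(const Vec4 *P) {
  // getelementptr to lane 2 plus a scalar load; VectorCombine now
  // performs this only when the new TTI hook returns true.
  return P->Lane[2];
}

A target whose vector elements are not individually addressable in memory would presumably override allowVectorElementIndexingUsingGEP() to return false and keep the whole-vector form.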