Diffstat (limited to 'llvm/lib')
47 files changed, 1422 insertions, 516 deletions
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp index 87fae92..47dccde 100644 --- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp +++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp @@ -234,9 +234,14 @@ static bool evaluatePtrAddRecAtMaxBTCWillNotWrap( // Check if we have a suitable dereferencable assumption we can use. if (!StartPtrV->canBeFreed()) { + Instruction *CtxI = &*L->getHeader()->getFirstNonPHIIt(); + if (BasicBlock *LoopPred = L->getLoopPredecessor()) { + if (isa<BranchInst>(LoopPred->getTerminator())) + CtxI = LoopPred->getTerminator(); + } + RetainedKnowledge DerefRK = getKnowledgeValidInContext( - StartPtrV, {Attribute::Dereferenceable}, *AC, - L->getLoopPredecessor()->getTerminator(), DT); + StartPtrV, {Attribute::Dereferenceable}, *AC, CtxI, DT); if (DerefRK) { DerefBytesSCEV = SE.getUMaxExpr(DerefBytesSCEV, SE.getSCEV(DerefRK.IRArgValue)); @@ -2856,8 +2861,9 @@ void LoopAccessInfo::emitUnsafeDependenceRemark() { } } -bool LoopAccessInfo::blockNeedsPredication(BasicBlock *BB, Loop *TheLoop, - DominatorTree *DT) { +bool LoopAccessInfo::blockNeedsPredication(const BasicBlock *BB, + const Loop *TheLoop, + const DominatorTree *DT) { assert(TheLoop->contains(BB) && "Unknown block used"); // Blocks that do not dominate the latch need predication. diff --git a/llvm/lib/CodeGen/PeepholeOptimizer.cpp b/llvm/lib/CodeGen/PeepholeOptimizer.cpp index fb3e648..729a57e 100644 --- a/llvm/lib/CodeGen/PeepholeOptimizer.cpp +++ b/llvm/lib/CodeGen/PeepholeOptimizer.cpp @@ -1203,6 +1203,18 @@ bool PeepholeOptimizer::optimizeCoalescableCopyImpl(Rewriter &&CpyRewriter) { if (!NewSrc.Reg) continue; + if (NewSrc.SubReg) { + // Verify the register class supports the subregister index. ARM's + // copy-like queries return register:subreg pairs where the register's + // current class does not directly support the subregister index. + const TargetRegisterClass *RC = MRI->getRegClass(NewSrc.Reg); + const TargetRegisterClass *WithSubRC = + TRI->getSubClassWithSubReg(RC, NewSrc.SubReg); + if (!MRI->constrainRegClass(NewSrc.Reg, WithSubRC)) + continue; + Changed = true; + } + // Rewrite source. if (CpyRewriter.RewriteCurrentSource(NewSrc.Reg, NewSrc.SubReg)) { // We may have extended the live-range of NewSrc, account for that. @@ -1275,6 +1287,18 @@ MachineInstr &PeepholeOptimizer::rewriteSource(MachineInstr &CopyLike, const TargetRegisterClass *DefRC = MRI->getRegClass(Def.Reg); Register NewVReg = MRI->createVirtualRegister(DefRC); + if (NewSrc.SubReg) { + const TargetRegisterClass *NewSrcRC = MRI->getRegClass(NewSrc.Reg); + const TargetRegisterClass *WithSubRC = + TRI->getSubClassWithSubReg(NewSrcRC, NewSrc.SubReg); + + // The new source may not directly support the subregister, but we should be + // able to assume it is constrainable to support the subregister (otherwise + // ValueTracker was lying and reported a useless value). 
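Both PeepholeOptimizer hunks above and below reduce to the same query: ask TRI for a subclass that supports the subregister index, then try to constrain the vreg to it. A minimal standalone sketch, assuming only the stock CodeGen APIs (the helper name is hypothetical):

    #include "llvm/CodeGen/MachineRegisterInfo.h"
    #include "llvm/CodeGen/TargetRegisterInfo.h"
    using namespace llvm;

    // Returns true iff Reg's class supports SubReg after constraining, e.g.
    // narrowing an ARM DPR to DPR_VFP2 before reading ssub_0.
    // getSubClassWithSubReg() returns null when no such class exists.
    static bool constrainForSubReg(MachineRegisterInfo &MRI,
                                   const TargetRegisterInfo &TRI,
                                   Register Reg, unsigned SubReg) {
      const TargetRegisterClass *RC = MRI.getRegClass(Reg);
      const TargetRegisterClass *WithSubRC =
          TRI.getSubClassWithSubReg(RC, SubReg);
      return WithSubRC && MRI.constrainRegClass(Reg, WithSubRC);
    }

The first hunk treats a failed constraint as "skip this copy"; the hunk below treats it as unreachable, since ValueTracker only reports sources it has already proven usable.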
+ if (!MRI->constrainRegClass(NewSrc.Reg, WithSubRC)) + llvm_unreachable("replacement register cannot support subregister"); + } + MachineInstr *NewCopy = BuildMI(*CopyLike.getParent(), &CopyLike, CopyLike.getDebugLoc(), TII->get(TargetOpcode::COPY), NewVReg) diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp index 1a51830..54b92c9 100644 --- a/llvm/lib/IR/AsmWriter.cpp +++ b/llvm/lib/IR/AsmWriter.cpp @@ -516,19 +516,15 @@ static void PrintShuffleMask(raw_ostream &Out, Type *Ty, ArrayRef<int> Mask) { if (isa<ScalableVectorType>(Ty)) Out << "vscale x "; Out << Mask.size() << " x i32> "; - bool FirstElt = true; if (all_of(Mask, [](int Elt) { return Elt == 0; })) { Out << "zeroinitializer"; } else if (all_of(Mask, [](int Elt) { return Elt == PoisonMaskElem; })) { Out << "poison"; } else { Out << "<"; + ListSeparator LS; for (int Elt : Mask) { - if (FirstElt) - FirstElt = false; - else - Out << ", "; - Out << "i32 "; + Out << LS << "i32 "; if (Elt == PoisonMaskElem) Out << "poison"; else @@ -1700,14 +1696,12 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV, if (const ConstantArray *CA = dyn_cast<ConstantArray>(CV)) { Type *ETy = CA->getType()->getElementType(); Out << '['; - WriterCtx.TypePrinter->print(ETy, Out); - Out << ' '; - WriteAsOperandInternal(Out, CA->getOperand(0), WriterCtx); - for (unsigned i = 1, e = CA->getNumOperands(); i != e; ++i) { - Out << ", "; + ListSeparator LS; + for (const Value *Op : CA->operands()) { + Out << LS; WriterCtx.TypePrinter->print(ETy, Out); Out << ' '; - WriteAsOperandInternal(Out, CA->getOperand(i), WriterCtx); + WriteAsOperandInternal(Out, Op, WriterCtx); } Out << ']'; return; @@ -1725,11 +1719,9 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV, Type *ETy = CA->getType()->getElementType(); Out << '['; - WriterCtx.TypePrinter->print(ETy, Out); - Out << ' '; - WriteAsOperandInternal(Out, CA->getElementAsConstant(0), WriterCtx); - for (uint64_t i = 1, e = CA->getNumElements(); i != e; ++i) { - Out << ", "; + ListSeparator LS; + for (uint64_t i = 0, e = CA->getNumElements(); i != e; ++i) { + Out << LS; WriterCtx.TypePrinter->print(ETy, Out); Out << ' '; WriteAsOperandInternal(Out, CA->getElementAsConstant(i), WriterCtx); @@ -1742,24 +1734,17 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV, if (CS->getType()->isPacked()) Out << '<'; Out << '{'; - unsigned N = CS->getNumOperands(); - if (N) { - Out << ' '; - WriterCtx.TypePrinter->print(CS->getOperand(0)->getType(), Out); + if (CS->getNumOperands() != 0) { Out << ' '; - - WriteAsOperandInternal(Out, CS->getOperand(0), WriterCtx); - - for (unsigned i = 1; i < N; i++) { - Out << ", "; - WriterCtx.TypePrinter->print(CS->getOperand(i)->getType(), Out); + ListSeparator LS; + for (const Value *Op : CS->operands()) { + Out << LS; + WriterCtx.TypePrinter->print(Op->getType(), Out); Out << ' '; - - WriteAsOperandInternal(Out, CS->getOperand(i), WriterCtx); + WriteAsOperandInternal(Out, Op, WriterCtx); } Out << ' '; } - Out << '}'; if (CS->getType()->isPacked()) Out << '>'; @@ -1787,11 +1772,9 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV, } Out << '<'; - WriterCtx.TypePrinter->print(ETy, Out); - Out << ' '; - WriteAsOperandInternal(Out, CV->getAggregateElement(0U), WriterCtx); - for (unsigned i = 1, e = CVVTy->getNumElements(); i != e; ++i) { - Out << ", "; + ListSeparator LS; + for (unsigned i = 0, e = CVVTy->getNumElements(); i != e; ++i) { + Out << LS; WriterCtx.TypePrinter->print(ETy, Out); Out 
<< ' '; WriteAsOperandInternal(Out, CV->getAggregateElement(i), WriterCtx); @@ -1848,13 +1831,12 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV, Out << ", "; } - for (User::const_op_iterator OI = CE->op_begin(); OI != CE->op_end(); - ++OI) { - WriterCtx.TypePrinter->print((*OI)->getType(), Out); + ListSeparator LS; + for (const Value *Op : CE->operands()) { + Out << LS; + WriterCtx.TypePrinter->print(Op->getType(), Out); Out << ' '; - WriteAsOperandInternal(Out, *OI, WriterCtx); - if (OI+1 != CE->op_end()) - Out << ", "; + WriteAsOperandInternal(Out, Op, WriterCtx); } if (CE->isCast()) { @@ -1875,11 +1857,12 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV, static void writeMDTuple(raw_ostream &Out, const MDTuple *Node, AsmWriterContext &WriterCtx) { Out << "!{"; - for (unsigned mi = 0, me = Node->getNumOperands(); mi != me; ++mi) { - const Metadata *MD = Node->getOperand(mi); - if (!MD) + ListSeparator LS; + for (const Metadata *MD : Node->operands()) { + Out << LS; + if (!MD) { Out << "null"; - else if (auto *MDV = dyn_cast<ValueAsMetadata>(MD)) { + } else if (auto *MDV = dyn_cast<ValueAsMetadata>(MD)) { Value *V = MDV->getValue(); WriterCtx.TypePrinter->print(V->getType(), Out); Out << ' '; @@ -1888,8 +1871,6 @@ static void writeMDTuple(raw_ostream &Out, const MDTuple *Node, WriteAsOperandInternal(Out, MD, WriterCtx); WriterCtx.onWriteMetadataAsOperand(MD); } - if (mi + 1 != me) - Out << ", "; } Out << "}"; @@ -1897,24 +1878,9 @@ static void writeMDTuple(raw_ostream &Out, const MDTuple *Node, namespace { -struct FieldSeparator { - bool Skip = true; - const char *Sep; - - FieldSeparator(const char *Sep = ", ") : Sep(Sep) {} -}; - -raw_ostream &operator<<(raw_ostream &OS, FieldSeparator &FS) { - if (FS.Skip) { - FS.Skip = false; - return OS; - } - return OS << FS.Sep; -} - struct MDFieldPrinter { raw_ostream &Out; - FieldSeparator FS; + ListSeparator FS; AsmWriterContext &WriterCtx; explicit MDFieldPrinter(raw_ostream &Out) @@ -2051,7 +2017,7 @@ void MDFieldPrinter::printDIFlags(StringRef Name, DINode::DIFlags Flags) { SmallVector<DINode::DIFlags, 8> SplitFlags; auto Extra = DINode::splitFlags(Flags, SplitFlags); - FieldSeparator FlagsFS(" | "); + ListSeparator FlagsFS(" | "); for (auto F : SplitFlags) { auto StringF = DINode::getFlagString(F); assert(!StringF.empty() && "Expected valid flag"); @@ -2075,7 +2041,7 @@ void MDFieldPrinter::printDISPFlags(StringRef Name, SmallVector<DISubprogram::DISPFlags, 8> SplitFlags; auto Extra = DISubprogram::splitFlags(Flags, SplitFlags); - FieldSeparator FlagsFS(" | "); + ListSeparator FlagsFS(" | "); for (auto F : SplitFlags) { auto StringF = DISubprogram::getFlagString(F); assert(!StringF.empty() && "Expected valid flag"); @@ -2124,7 +2090,7 @@ static void writeGenericDINode(raw_ostream &Out, const GenericDINode *N, Printer.printString("header", N->getHeader()); if (N->getNumDwarfOperands()) { Out << Printer.FS << "operands: {"; - FieldSeparator IFS; + ListSeparator IFS; for (auto &I : N->dwarf_operands()) { Out << IFS; writeMetadataAsOperand(Out, I, WriterCtx); @@ -2638,7 +2604,7 @@ static void writeDILabel(raw_ostream &Out, const DILabel *N, static void writeDIExpression(raw_ostream &Out, const DIExpression *N, AsmWriterContext &WriterCtx) { Out << "!DIExpression("; - FieldSeparator FS; + ListSeparator FS; if (N->isValid()) { for (const DIExpression::ExprOperand &Op : N->expr_ops()) { auto OpStr = dwarf::OperationEncodingString(Op.getOp()); @@ -2666,7 +2632,7 @@ static void 
writeDIArgList(raw_ostream &Out, const DIArgList *N, assert(FromValue && "Unexpected DIArgList metadata outside of value argument"); Out << "!DIArgList("; - FieldSeparator FS; + ListSeparator FS; MDFieldPrinter Printer(Out, WriterCtx); for (Metadata *Arg : N->getArgs()) { Out << FS; @@ -3073,15 +3039,11 @@ void AssemblyWriter::writeOperandBundles(const CallBase *Call) { Out << " [ "; - bool FirstBundle = true; + ListSeparator LS; for (unsigned i = 0, e = Call->getNumOperandBundles(); i != e; ++i) { OperandBundleUse BU = Call->getOperandBundleAt(i); - if (!FirstBundle) - Out << ", "; - FirstBundle = false; - - Out << '"'; + Out << LS << '"'; printEscapedString(BU.getTagName(), Out); Out << '"'; @@ -3229,7 +3191,7 @@ void AssemblyWriter::printModuleSummaryIndex() { Out << "path: \""; printEscapedString(ModPair.first, Out); Out << "\", hash: ("; - FieldSeparator FS; + ListSeparator FS; for (auto Hash : ModPair.second) Out << FS << Hash; Out << "))\n"; @@ -3347,7 +3309,7 @@ void AssemblyWriter::printTypeIdSummary(const TypeIdSummary &TIS) { printTypeTestResolution(TIS.TTRes); if (!TIS.WPDRes.empty()) { Out << ", wpdResolutions: ("; - FieldSeparator FS; + ListSeparator FS; for (auto &WPDRes : TIS.WPDRes) { Out << FS; Out << "(offset: " << WPDRes.first << ", "; @@ -3362,7 +3324,7 @@ void AssemblyWriter::printTypeIdSummary(const TypeIdSummary &TIS) { void AssemblyWriter::printTypeIdCompatibleVtableSummary( const TypeIdCompatibleVtableInfo &TI) { Out << ", summary: ("; - FieldSeparator FS; + ListSeparator FS; for (auto &P : TI) { Out << FS; Out << "(offset: " << P.AddressPointOffset << ", "; @@ -3374,7 +3336,7 @@ void AssemblyWriter::printTypeIdCompatibleVtableSummary( void AssemblyWriter::printArgs(const std::vector<uint64_t> &Args) { Out << "args: ("; - FieldSeparator FS; + ListSeparator FS; for (auto arg : Args) { Out << FS; Out << arg; @@ -3391,7 +3353,7 @@ void AssemblyWriter::printWPDRes(const WholeProgramDevirtResolution &WPDRes) { if (!WPDRes.ResByArg.empty()) { Out << ", resByArg: ("; - FieldSeparator FS; + ListSeparator FS; for (auto &ResByArg : WPDRes.ResByArg) { Out << FS; printArgs(ResByArg.first); @@ -3451,7 +3413,7 @@ void AssemblyWriter::printGlobalVarSummary(const GlobalVarSummary *GS) { if (!VTableFuncs.empty()) { Out << ", vTableFuncs: ("; - FieldSeparator FS; + ListSeparator FS; for (auto &P : VTableFuncs) { Out << FS; Out << "(virtFunc: ^" << Machine.getGUIDSlot(P.FuncVI.getGUID()) @@ -3528,7 +3490,7 @@ void AssemblyWriter::printFunctionSummary(const FunctionSummary *FS) { if (!FS->calls().empty()) { Out << ", calls: ("; - FieldSeparator IFS; + ListSeparator IFS; for (auto &Call : FS->calls()) { Out << IFS; Out << "(callee: ^" << Machine.getGUIDSlot(Call.first.getGUID()); @@ -3566,22 +3528,22 @@ void AssemblyWriter::printFunctionSummary(const FunctionSummary *FS) { if (!FS->allocs().empty()) { Out << ", allocs: ("; - FieldSeparator AFS; + ListSeparator AFS; for (auto &AI : FS->allocs()) { Out << AFS; Out << "(versions: ("; - FieldSeparator VFS; + ListSeparator VFS; for (auto V : AI.Versions) { Out << VFS; Out << AllocTypeName(V); } Out << "), memProf: ("; - FieldSeparator MIBFS; + ListSeparator MIBFS; for (auto &MIB : AI.MIBs) { Out << MIBFS; Out << "(type: " << AllocTypeName((uint8_t)MIB.AllocType); Out << ", stackIds: ("; - FieldSeparator SIDFS; + ListSeparator SIDFS; for (auto Id : MIB.StackIdIndices) { Out << SIDFS; Out << TheIndex->getStackIdAtIndex(Id); @@ -3595,7 +3557,7 @@ void AssemblyWriter::printFunctionSummary(const FunctionSummary *FS) { if 
(!FS->callsites().empty()) { Out << ", callsites: ("; - FieldSeparator SNFS; + ListSeparator SNFS; for (auto &CI : FS->callsites()) { Out << SNFS; if (CI.Callee) @@ -3603,13 +3565,13 @@ void AssemblyWriter::printFunctionSummary(const FunctionSummary *FS) { else Out << "(callee: null"; Out << ", clones: ("; - FieldSeparator VFS; + ListSeparator VFS; for (auto V : CI.Clones) { Out << VFS; Out << V; } Out << "), stackIds: ("; - FieldSeparator SIDFS; + ListSeparator SIDFS; for (auto Id : CI.StackIdIndices) { Out << SIDFS; Out << TheIndex->getStackIdAtIndex(Id); @@ -3625,7 +3587,7 @@ void AssemblyWriter::printFunctionSummary(const FunctionSummary *FS) { if (!FS->paramAccesses().empty()) { Out << ", params: ("; - FieldSeparator IFS; + ListSeparator IFS; for (auto &PS : FS->paramAccesses()) { Out << IFS; Out << "(param: " << PS.ParamNo; @@ -3633,7 +3595,7 @@ void AssemblyWriter::printFunctionSummary(const FunctionSummary *FS) { PrintRange(PS.Use); if (!PS.Calls.empty()) { Out << ", calls: ("; - FieldSeparator IFS; + ListSeparator IFS; for (auto &Call : PS.Calls) { Out << IFS; Out << "(callee: ^" << Machine.getGUIDSlot(Call.Callee.getGUID()); @@ -3653,11 +3615,11 @@ void AssemblyWriter::printFunctionSummary(const FunctionSummary *FS) { void AssemblyWriter::printTypeIdInfo( const FunctionSummary::TypeIdInfo &TIDInfo) { Out << ", typeIdInfo: ("; - FieldSeparator TIDFS; + ListSeparator TIDFS; if (!TIDInfo.TypeTests.empty()) { Out << TIDFS; Out << "typeTests: ("; - FieldSeparator FS; + ListSeparator FS; for (auto &GUID : TIDInfo.TypeTests) { auto TidIter = TheIndex->typeIds().equal_range(GUID); if (TidIter.first == TidIter.second) { @@ -3706,7 +3668,7 @@ void AssemblyWriter::printVFuncId(const FunctionSummary::VFuncId VFId) { return; } // Print all type id that correspond to this GUID. - FieldSeparator FS; + ListSeparator FS; for (const auto &[GUID, TypeIdPair] : make_range(TidIter)) { Out << FS; Out << "vFuncId: ("; @@ -3721,7 +3683,7 @@ void AssemblyWriter::printVFuncId(const FunctionSummary::VFuncId VFId) { void AssemblyWriter::printNonConstVCalls( const std::vector<FunctionSummary::VFuncId> &VCallList, const char *Tag) { Out << Tag << ": ("; - FieldSeparator FS; + ListSeparator FS; for (auto &VFuncId : VCallList) { Out << FS; printVFuncId(VFuncId); @@ -3733,7 +3695,7 @@ void AssemblyWriter::printConstVCalls( const std::vector<FunctionSummary::ConstVCall> &VCallList, const char *Tag) { Out << Tag << ": ("; - FieldSeparator FS; + ListSeparator FS; for (auto &ConstVCall : VCallList) { Out << FS; Out << "("; @@ -3774,7 +3736,7 @@ void AssemblyWriter::printSummary(const GlobalValueSummary &Summary) { auto RefList = Summary.refs(); if (!RefList.empty()) { Out << ", refs: ("; - FieldSeparator FS; + ListSeparator FS; for (auto &Ref : RefList) { Out << FS; if (Ref.isReadOnly()) @@ -3797,7 +3759,7 @@ void AssemblyWriter::printSummaryInfo(unsigned Slot, const ValueInfo &VI) { Out << "guid: " << VI.getGUID(); if (!VI.getSummaryList().empty()) { Out << ", summaries: ("; - FieldSeparator FS; + ListSeparator FS; for (auto &Summary : VI.getSummaryList()) { Out << FS; printSummary(*Summary); @@ -3835,13 +3797,11 @@ void AssemblyWriter::printNamedMDNode(const NamedMDNode *NMD) { Out << '!'; printMetadataIdentifier(NMD->getName(), Out); Out << " = !{"; - for (unsigned i = 0, e = NMD->getNumOperands(); i != e; ++i) { - if (i) - Out << ", "; - + ListSeparator LS; + for (const MDNode *Op : NMD->operands()) { + Out << LS; // Write DIExpressions inline. 
// FIXME: Ban DIExpressions in NamedMDNodes, they will serve no purpose. - MDNode *Op = NMD->getOperand(i); if (auto *Expr = dyn_cast<DIExpression>(Op)) { writeDIExpression(Out, Expr, AsmWriterContext::getEmpty()); continue; @@ -4192,11 +4152,10 @@ void AssemblyWriter::printFunction(const Function *F) { // Loop over the arguments, printing them... if (F->isDeclaration() && !IsForDebug) { // We're only interested in the type here - don't print argument names. + ListSeparator LS; for (unsigned I = 0, E = FT->getNumParams(); I != E; ++I) { - // Insert commas as we go... the first arg doesn't get a comma - if (I) - Out << ", "; - // Output type... + Out << LS; + // Output type. TypePrinter.print(FT->getParamType(I), Out); AttributeSet ArgAttrs = Attrs.getParamAttrs(I); @@ -4207,10 +4166,9 @@ void AssemblyWriter::printFunction(const Function *F) { } } else { // The arguments are meaningful here, print them in detail. + ListSeparator LS; for (const Argument &Arg : F->args()) { - // Insert commas as we go... the first arg doesn't get a comma - if (Arg.getArgNo() != 0) - Out << ", "; + Out << LS; printArgument(&Arg, Attrs.getParamAttrs(Arg.getArgNo())); } } @@ -4332,16 +4290,14 @@ void AssemblyWriter::printBasicBlock(const BasicBlock *BB) { // Output predecessors for the block. Out.PadToColumn(50); Out << ";"; - const_pred_iterator PI = pred_begin(BB), PE = pred_end(BB); - - if (PI == PE) { + if (pred_empty(BB)) { Out << " No predecessors!"; } else { Out << " preds = "; - writeOperand(*PI, false); - for (++PI; PI != PE; ++PI) { - Out << ", "; - writeOperand(*PI, false); + ListSeparator LS; + for (const BasicBlock *Pred : predecessors(BB)) { + Out << LS; + writeOperand(Pred, false); } } } @@ -4520,9 +4476,9 @@ void AssemblyWriter::printInstruction(const Instruction &I) { writeOperand(Operand, true); Out << ", ["; + ListSeparator LS; for (unsigned i = 1, e = I.getNumOperands(); i != e; ++i) { - if (i != 1) - Out << ", "; + Out << LS; writeOperand(I.getOperand(i), true); } Out << ']'; @@ -4531,9 +4487,9 @@ void AssemblyWriter::printInstruction(const Instruction &I) { TypePrinter.print(I.getType(), Out); Out << ' '; + ListSeparator LS; for (unsigned op = 0, Eop = PN->getNumIncomingValues(); op < Eop; ++op) { - if (op) Out << ", "; - Out << "[ "; + Out << LS << "[ "; writeOperand(PN->getIncomingValue(op), false); Out << ", "; writeOperand(PN->getIncomingBlock(op), false); Out << " ]"; } @@ -4570,12 +4526,10 @@ void AssemblyWriter::printInstruction(const Instruction &I) { Out << " within "; writeOperand(CatchSwitch->getParentPad(), /*PrintType=*/false); Out << " ["; - unsigned Op = 0; + ListSeparator LS; for (const BasicBlock *PadBB : CatchSwitch->handlers()) { - if (Op > 0) - Out << ", "; + Out << LS; writeOperand(PadBB, /*PrintType=*/true); - ++Op; } Out << "] unwind "; if (const BasicBlock *UnwindDest = CatchSwitch->getUnwindDest()) @@ -4586,10 +4540,10 @@ void AssemblyWriter::printInstruction(const Instruction &I) { Out << " within "; writeOperand(FPI->getParentPad(), /*PrintType=*/false); Out << " ["; - for (unsigned Op = 0, NumOps = FPI->arg_size(); Op < NumOps; ++Op) { - if (Op > 0) - Out << ", "; - writeOperand(FPI->getArgOperand(Op), /*PrintType=*/true); + ListSeparator LS; + for (const Value *Op : FPI->arg_operands()) { + Out << LS; + writeOperand(Op, /*PrintType=*/true); } Out << ']'; } else if (isa<ReturnInst>(I) && !Operand) { @@ -4635,9 +4589,9 @@ void AssemblyWriter::printInstruction(const Instruction &I) { Out << ' '; writeOperand(Operand, false); Out << '('; + ListSeparator LS; for 
(unsigned op = 0, Eop = CI->arg_size(); op < Eop; ++op) { - if (op > 0) - Out << ", "; + Out << LS; writeParamOperand(CI->getArgOperand(op), PAL.getParamAttrs(op)); } @@ -4683,9 +4637,9 @@ void AssemblyWriter::printInstruction(const Instruction &I) { Out << ' '; writeOperand(Operand, false); Out << '('; + ListSeparator LS; for (unsigned op = 0, Eop = II->arg_size(); op < Eop; ++op) { - if (op) - Out << ", "; + Out << LS; writeParamOperand(II->getArgOperand(op), PAL.getParamAttrs(op)); } @@ -4723,9 +4677,9 @@ void AssemblyWriter::printInstruction(const Instruction &I) { Out << ' '; writeOperand(Operand, false); Out << '('; + ListSeparator ArgLS; for (unsigned op = 0, Eop = CBI->arg_size(); op < Eop; ++op) { - if (op) - Out << ", "; + Out << ArgLS; writeParamOperand(CBI->getArgOperand(op), PAL.getParamAttrs(op)); } @@ -4738,10 +4692,10 @@ void AssemblyWriter::printInstruction(const Instruction &I) { Out << "\n to "; writeOperand(CBI->getDefaultDest(), true); Out << " ["; - for (unsigned i = 0, e = CBI->getNumIndirectDests(); i != e; ++i) { - if (i != 0) - Out << ", "; - writeOperand(CBI->getIndirectDest(i), true); + ListSeparator DestLS; + for (const BasicBlock *Dest : CBI->getIndirectDests()) { + Out << DestLS; + writeOperand(Dest, true); } Out << ']'; } else if (const AllocaInst *AI = dyn_cast<AllocaInst>(&I)) { @@ -4824,9 +4778,10 @@ void AssemblyWriter::printInstruction(const Instruction &I) { } Out << ' '; - for (unsigned i = 0, E = I.getNumOperands(); i != E; ++i) { - if (i) Out << ", "; - writeOperand(I.getOperand(i), PrintAllTypes); + ListSeparator LS; + for (const Value *Op : I.operands()) { + Out << LS; + writeOperand(Op, PrintAllTypes); } } diff --git a/llvm/lib/IR/Value.cpp b/llvm/lib/IR/Value.cpp index 4e8f359..e5e062d 100644 --- a/llvm/lib/IR/Value.cpp +++ b/llvm/lib/IR/Value.cpp @@ -1000,14 +1000,12 @@ Align Value::getPointerAlignment(const DataLayout &DL) const { ConstantInt *CI = mdconst::extract<ConstantInt>(MD->getOperand(0)); return Align(CI->getLimitedValue()); } - } else if (auto *CstPtr = dyn_cast<Constant>(this)) { - // Strip pointer casts to avoid creating unnecessary ptrtoint expression - // if the only "reduction" is combining a bitcast + ptrtoint. - CstPtr = CstPtr->stripPointerCasts(); - if (auto *CstInt = dyn_cast_or_null<ConstantInt>(ConstantExpr::getPtrToInt( - const_cast<Constant *>(CstPtr), DL.getIntPtrType(getType()), - /*OnlyIfReduced=*/true))) { - size_t TrailingZeros = CstInt->getValue().countr_zero(); + } else if (auto *CE = dyn_cast<ConstantExpr>(this)) { + // Determine the alignment of inttoptr(C). + if (CE->getOpcode() == Instruction::IntToPtr && + isa<ConstantInt>(CE->getOperand(0))) { + ConstantInt *IntPtr = cast<ConstantInt>(CE->getOperand(0)); + size_t TrailingZeros = IntPtr->getValue().countr_zero(); // While the actual alignment may be large, elsewhere we have // an arbitrary upper alignmet limit, so let's clamp to it. 
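The clamp the comment above describes boils down to the following computation (a sketch; APInt::countr_zero() and Value::MaxAlignmentExponent are the real APIs, the helper itself is illustrative):

    #include <algorithm>
    #include "llvm/ADT/APInt.h"
    #include "llvm/IR/Value.h"
    #include "llvm/Support/Alignment.h"
    using namespace llvm;

    // Alignment of inttoptr(C): the trailing-zero count of C is the provable
    // log2 alignment, clamped to LLVM's maximum alignment exponent.
    static Align alignFromIntToPtr(const APInt &C) {
      unsigned TZ = C.countr_zero(); // C == 0 yields the full bit width
      return Align(uint64_t(1) << std::min(TZ, Value::MaxAlignmentExponent));
    }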
return Align(TrailingZeros < Value::MaxAlignmentExponent diff --git a/llvm/lib/Object/OffloadBundle.cpp b/llvm/lib/Object/OffloadBundle.cpp index 0dd378e..a6a9628a 100644 --- a/llvm/lib/Object/OffloadBundle.cpp +++ b/llvm/lib/Object/OffloadBundle.cpp @@ -120,14 +120,15 @@ OffloadBundleFatBin::create(MemoryBufferRef Buf, uint64_t SectionOffset, if (identify_magic(Buf.getBuffer()) != file_magic::offload_bundle) return errorCodeToError(object_error::parse_failed); - OffloadBundleFatBin *TheBundle = new OffloadBundleFatBin(Buf, FileName); + std::unique_ptr<OffloadBundleFatBin> TheBundle( + new OffloadBundleFatBin(Buf, FileName)); // Read the Bundle Entries Error Err = TheBundle->readEntries(Buf.getBuffer(), SectionOffset); if (Err) return Err; - return std::unique_ptr<OffloadBundleFatBin>(TheBundle); + return TheBundle; } Error OffloadBundleFatBin::extractBundle(const ObjectFile &Source) { diff --git a/llvm/lib/Support/Mustache.cpp b/llvm/lib/Support/Mustache.cpp index 646d7a0..178f970 100644 --- a/llvm/lib/Support/Mustache.cpp +++ b/llvm/lib/Support/Mustache.cpp @@ -56,6 +56,33 @@ static Accessor splitMustacheString(StringRef Str) { namespace llvm::mustache { +class MustacheOutputStream : public raw_ostream { +public: + MustacheOutputStream() = default; + ~MustacheOutputStream() override = default; + + virtual void suspendIndentation() {} + virtual void resumeIndentation() {} + +private: + void anchor() override; +}; + +void MustacheOutputStream::anchor() {} + +class RawMustacheOutputStream : public MustacheOutputStream { +public: + RawMustacheOutputStream(raw_ostream &OS) : OS(OS) { SetUnbuffered(); } + +private: + raw_ostream &OS; + + void write_impl(const char *Ptr, size_t Size) override { + OS.write(Ptr, Size); + } + uint64_t current_pos() const override { return OS.tell(); } +}; + class Token { public: enum class Type { @@ -156,29 +183,31 @@ public: void setIndentation(size_t NewIndentation) { Indentation = NewIndentation; }; - void render(const llvm::json::Value &Data, llvm::raw_ostream &OS); + void render(const llvm::json::Value &Data, MustacheOutputStream &OS); private: - void renderLambdas(const llvm::json::Value &Contexts, llvm::raw_ostream &OS, - Lambda &L); + void renderLambdas(const llvm::json::Value &Contexts, + MustacheOutputStream &OS, Lambda &L); void renderSectionLambdas(const llvm::json::Value &Contexts, - llvm::raw_ostream &OS, SectionLambda &L); + MustacheOutputStream &OS, SectionLambda &L); - void renderPartial(const llvm::json::Value &Contexts, llvm::raw_ostream &OS, - ASTNode *Partial); + void renderPartial(const llvm::json::Value &Contexts, + MustacheOutputStream &OS, ASTNode *Partial); - void renderChild(const llvm::json::Value &Context, llvm::raw_ostream &OS); + void renderChild(const llvm::json::Value &Context, MustacheOutputStream &OS); const llvm::json::Value *findContext(); - void renderRoot(const json::Value &CurrentCtx, raw_ostream &OS); - void renderText(raw_ostream &OS); - void renderPartial(const json::Value &CurrentCtx, raw_ostream &OS); - void renderVariable(const json::Value &CurrentCtx, raw_ostream &OS); - void renderUnescapeVariable(const json::Value &CurrentCtx, raw_ostream &OS); - void renderSection(const json::Value &CurrentCtx, raw_ostream &OS); - void renderInvertSection(const json::Value &CurrentCtx, raw_ostream &OS); + void renderRoot(const json::Value &CurrentCtx, MustacheOutputStream &OS); + void renderText(MustacheOutputStream &OS); + void renderPartial(const json::Value &CurrentCtx, MustacheOutputStream &OS); + void renderVariable(const 
json::Value &CurrentCtx, MustacheOutputStream &OS); + void renderUnescapeVariable(const json::Value &CurrentCtx, + MustacheOutputStream &OS); + void renderSection(const json::Value &CurrentCtx, MustacheOutputStream &OS); + void renderInvertSection(const json::Value &CurrentCtx, + MustacheOutputStream &OS); MustacheContext &Ctx; Type Ty; @@ -300,6 +329,36 @@ struct Tag { size_t StartPosition = StringRef::npos; }; +static const char *tagKindToString(Tag::Kind K) { + switch (K) { + case Tag::Kind::None: + return "None"; + case Tag::Kind::Normal: + return "Normal"; + case Tag::Kind::Triple: + return "Triple"; + } + llvm_unreachable("Unknown Tag::Kind"); +} + +static const char *jsonKindToString(json::Value::Kind K) { + switch (K) { + case json::Value::Kind::Null: + return "JSON_KIND_NULL"; + case json::Value::Kind::Boolean: + return "JSON_KIND_BOOLEAN"; + case json::Value::Kind::Number: + return "JSON_KIND_NUMBER"; + case json::Value::Kind::String: + return "JSON_KIND_STRING"; + case json::Value::Kind::Array: + return "JSON_KIND_ARRAY"; + case json::Value::Kind::Object: + return "JSON_KIND_OBJECT"; + } + llvm_unreachable("Unknown json::Value::Kind"); +} + static Tag findNextTag(StringRef Template, size_t StartPos, StringRef Open, StringRef Close) { const StringLiteral TripleOpen("{{{"); @@ -344,11 +403,10 @@ static Tag findNextTag(StringRef Template, size_t StartPos, StringRef Open, static std::optional<std::pair<StringRef, StringRef>> processTag(const Tag &T, SmallVectorImpl<Token> &Tokens) { - LLVM_DEBUG(dbgs() << " Found tag: \"" << T.FullMatch << "\", Content: \"" - << T.Content << "\"\n"); + LLVM_DEBUG(dbgs() << "[Tag] " << T.FullMatch << ", Content: " << T.Content + << ", Kind: " << tagKindToString(T.TagKind) << "\n"); if (T.TagKind == Tag::Kind::Triple) { Tokens.emplace_back(T.FullMatch.str(), "&" + T.Content.str(), '&'); - LLVM_DEBUG(dbgs() << " Created UnescapeVariable token.\n"); return std::nullopt; } StringRef Interpolated = T.Content; @@ -356,7 +414,6 @@ processTag(const Tag &T, SmallVectorImpl<Token> &Tokens) { if (!Interpolated.trim().starts_with("=")) { char Front = Interpolated.empty() ? ' ' : Interpolated.trim().front(); Tokens.emplace_back(RawBody, Interpolated.str(), Front); - LLVM_DEBUG(dbgs() << " Created tag token of type '" << Front << "'\n"); return std::nullopt; } Tokens.emplace_back(RawBody, Interpolated.str(), '='); @@ -366,8 +423,8 @@ processTag(const Tag &T, SmallVectorImpl<Token> &Tokens) { DelimSpec = DelimSpec.trim(); std::pair<StringRef, StringRef> Ret = DelimSpec.split(' '); - LLVM_DEBUG(dbgs() << " Found Set Delimiter tag. NewOpen='" << Ret.first - << "', NewClose='" << Ret.second << "'\n"); + LLVM_DEBUG(dbgs() << "[Set Delimiter] NewOpen: " << Ret.first + << ", NewClose: " << Ret.second << "\n"); return Ret; } @@ -376,15 +433,15 @@ processTag(const Tag &T, SmallVectorImpl<Token> &Tokens) { // but we don't support that here. An unescape variable // is represented only by {{& variable}}. static SmallVector<Token> tokenize(StringRef Template) { - LLVM_DEBUG(dbgs() << "Tokenizing template: \"" << Template << "\"\n"); + LLVM_DEBUG(dbgs() << "[Tokenize Template] \"" << Template << "\"\n"); SmallVector<Token> Tokens; SmallString<8> Open("{{"); SmallString<8> Close("}}"); size_t Start = 0; while (Start < Template.size()) { - LLVM_DEBUG(dbgs() << "Loop start. 
Start=" << Start << ", Open='" << Open - << "', Close='" << Close << "'\n"); + LLVM_DEBUG(dbgs() << "[Tokenize Loop] Start:" << Start << ", Open:'" << Open + << "', Close:'" << Close << "'\n"); Tag T = findNextTag(Template, Start, Open, Close); if (T.TagKind == Tag::Kind::None) { @@ -399,7 +456,6 @@ static SmallVector<Token> tokenize(StringRef Template) { if (T.StartPosition > Start) { StringRef Text = Template.substr(Start, T.StartPosition - Start); Tokens.emplace_back(Text.str()); - LLVM_DEBUG(dbgs() << " Created Text token: \"" << Text << "\"\n"); } if (auto NewDelims = processTag(T, Tokens)) { @@ -450,12 +506,11 @@ static SmallVector<Token> tokenize(StringRef Template) { if ((!HasTextBehind && !HasTextAhead) || (!HasTextBehind && Idx == LastIdx)) stripTokenBefore(Tokens, Idx, CurrentToken, CurrentType); } - LLVM_DEBUG(dbgs() << "Tokenizing finished.\n"); return Tokens; } // Custom stream to escape strings. -class EscapeStringStream : public raw_ostream { +class EscapeStringStream : public MustacheOutputStream { public: explicit EscapeStringStream(llvm::raw_ostream &WrappedStream, EscapeMap &Escape) @@ -497,15 +552,18 @@ private: }; // Custom stream to add indentation used to for rendering partials. -class AddIndentationStringStream : public raw_ostream { +class AddIndentationStringStream : public MustacheOutputStream { public: - explicit AddIndentationStringStream(llvm::raw_ostream &WrappedStream, + explicit AddIndentationStringStream(raw_ostream &WrappedStream, size_t Indentation) : Indentation(Indentation), WrappedStream(WrappedStream), - NeedsIndent(true) { + NeedsIndent(true), IsSuspended(false) { SetUnbuffered(); } + void suspendIndentation() override { IsSuspended = true; } + void resumeIndentation() override { IsSuspended = false; } + protected: void write_impl(const char *Ptr, size_t Size) override { llvm::StringRef Data(Ptr, Size); @@ -513,12 +571,15 @@ protected: Indent.resize(Indentation, ' '); for (char C : Data) { + LLVM_DEBUG(dbgs() << "[Indentation Stream] NeedsIndent:" << NeedsIndent + << ", C:'" << C << "', Indentation:" << Indentation + << "\n"); if (NeedsIndent && C != '\n') { WrappedStream << Indent; NeedsIndent = false; } WrappedStream << C; - if (C == '\n') + if (C == '\n' && !IsSuspended) NeedsIndent = true; } } @@ -527,8 +588,9 @@ protected: private: size_t Indentation; - llvm::raw_ostream &WrappedStream; + raw_ostream &WrappedStream; bool NeedsIndent; + bool IsSuspended; }; class Parser { @@ -618,6 +680,9 @@ void Parser::parseMustache(ASTNode *Parent) { } } static void toMustacheString(const json::Value &Data, raw_ostream &OS) { + LLVM_DEBUG(dbgs() << "[To Mustache String] Kind: " + << jsonKindToString(Data.kind()) << ", Data: " << Data + << "\n"); switch (Data.kind()) { case json::Value::Null: return; @@ -649,19 +714,24 @@ static void toMustacheString(const json::Value &Data, raw_ostream &OS) { } } -void ASTNode::renderRoot(const json::Value &CurrentCtx, raw_ostream &OS) { +void ASTNode::renderRoot(const json::Value &CurrentCtx, + MustacheOutputStream &OS) { renderChild(CurrentCtx, OS); } -void ASTNode::renderText(raw_ostream &OS) { OS << Body; } +void ASTNode::renderText(MustacheOutputStream &OS) { OS << Body; } -void ASTNode::renderPartial(const json::Value &CurrentCtx, raw_ostream &OS) { +void ASTNode::renderPartial(const json::Value &CurrentCtx, + MustacheOutputStream &OS) { + LLVM_DEBUG(dbgs() << "[Render Partial] Accessor:" << AccessorValue[0] + << ", Indentation:" << Indentation << "\n"); auto Partial = Ctx.Partials.find(AccessorValue[0]); if (Partial 
!= Ctx.Partials.end()) renderPartial(CurrentCtx, OS, Partial->getValue().get()); } -void ASTNode::renderVariable(const json::Value &CurrentCtx, raw_ostream &OS) { +void ASTNode::renderVariable(const json::Value &CurrentCtx, + MustacheOutputStream &OS) { auto Lambda = Ctx.Lambdas.find(AccessorValue[0]); if (Lambda != Ctx.Lambdas.end()) { renderLambdas(CurrentCtx, OS, Lambda->getValue()); @@ -672,16 +742,21 @@ void ASTNode::renderVariable(const json::Value &CurrentCtx, raw_ostream &OS) { } void ASTNode::renderUnescapeVariable(const json::Value &CurrentCtx, - raw_ostream &OS) { + MustacheOutputStream &OS) { + LLVM_DEBUG(dbgs() << "[Render UnescapeVariable] Accessor:" << AccessorValue[0] + << "\n"); auto Lambda = Ctx.Lambdas.find(AccessorValue[0]); if (Lambda != Ctx.Lambdas.end()) { renderLambdas(CurrentCtx, OS, Lambda->getValue()); } else if (const json::Value *ContextPtr = findContext()) { + OS.suspendIndentation(); toMustacheString(*ContextPtr, OS); + OS.resumeIndentation(); } } -void ASTNode::renderSection(const json::Value &CurrentCtx, raw_ostream &OS) { +void ASTNode::renderSection(const json::Value &CurrentCtx, + MustacheOutputStream &OS) { auto SectionLambda = Ctx.SectionLambdas.find(AccessorValue[0]); if (SectionLambda != Ctx.SectionLambdas.end()) { renderSectionLambdas(CurrentCtx, OS, SectionLambda->getValue()); @@ -701,7 +776,7 @@ void ASTNode::renderSection(const json::Value &CurrentCtx, raw_ostream &OS) { } void ASTNode::renderInvertSection(const json::Value &CurrentCtx, - raw_ostream &OS) { + MustacheOutputStream &OS) { bool IsLambda = Ctx.SectionLambdas.contains(AccessorValue[0]); const json::Value *ContextPtr = findContext(); if (isContextFalsey(ContextPtr) && !IsLambda) { @@ -709,34 +784,34 @@ void ASTNode::renderInvertSection(const json::Value &CurrentCtx, } } -void ASTNode::render(const json::Value &CurrentCtx, raw_ostream &OS) { +void ASTNode::render(const llvm::json::Value &Data, MustacheOutputStream &OS) { if (Ty != Root && Ty != Text && AccessorValue.empty()) return; // Set the parent context to the incoming context so that we // can walk up the context tree correctly in findContext(). 
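The parent-context walk that comment refers to looks roughly like this (a simplified sketch with a stand-in Node type; the real findContext() also resolves dotted accessors and non-object contexts):

    #include "llvm/Support/JSON.h"
    using namespace llvm;

    struct Node {
      const Node *Parent = nullptr;               // stand-in for the AST link
      const json::Value *ParentContext = nullptr; // set on entry to render()
    };

    // Climb from a node toward the root, returning the first enclosing
    // context object that defines Key.
    static const json::Value *lookupUpwards(const Node *N, StringRef Key) {
      for (; N; N = N->Parent)
        if (N->ParentContext)
          if (const json::Object *O = N->ParentContext->getAsObject())
            if (const json::Value *V = O->get(Key))
              return V;
      return nullptr;
    }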
- ParentContext = &CurrentCtx; + ParentContext = &Data; switch (Ty) { case Root: - renderRoot(CurrentCtx, OS); + renderRoot(Data, OS); return; case Text: renderText(OS); return; case Partial: - renderPartial(CurrentCtx, OS); + renderPartial(Data, OS); return; case Variable: - renderVariable(CurrentCtx, OS); + renderVariable(Data, OS); return; case UnescapeVariable: - renderUnescapeVariable(CurrentCtx, OS); + renderUnescapeVariable(Data, OS); return; case Section: - renderSection(CurrentCtx, OS); + renderSection(Data, OS); return; case InvertSection: - renderInvertSection(CurrentCtx, OS); + renderInvertSection(Data, OS); return; } llvm_unreachable("Invalid ASTNode type"); @@ -781,19 +856,21 @@ const json::Value *ASTNode::findContext() { return Context; } -void ASTNode::renderChild(const json::Value &Contexts, llvm::raw_ostream &OS) { +void ASTNode::renderChild(const json::Value &Contexts, + MustacheOutputStream &OS) { for (AstPtr &Child : Children) Child->render(Contexts, OS); } -void ASTNode::renderPartial(const json::Value &Contexts, llvm::raw_ostream &OS, - ASTNode *Partial) { +void ASTNode::renderPartial(const json::Value &Contexts, + MustacheOutputStream &OS, ASTNode *Partial) { + LLVM_DEBUG(dbgs() << "[Render Partial Indentation] Indentation: " << Indentation << "\n"); AddIndentationStringStream IS(OS, Indentation); Partial->render(Contexts, IS); } -void ASTNode::renderLambdas(const json::Value &Contexts, llvm::raw_ostream &OS, - Lambda &L) { +void ASTNode::renderLambdas(const json::Value &Contexts, + MustacheOutputStream &OS, Lambda &L) { json::Value LambdaResult = L(); std::string LambdaStr; raw_string_ostream Output(LambdaStr); @@ -810,7 +887,7 @@ void ASTNode::renderLambdas(const json::Value &Contexts, llvm::raw_ostream &OS, } void ASTNode::renderSectionLambdas(const json::Value &Contexts, - llvm::raw_ostream &OS, SectionLambda &L) { + MustacheOutputStream &OS, SectionLambda &L) { json::Value Return = L(RawBody); if (isFalsey(Return)) return; @@ -823,7 +900,8 @@ void ASTNode::renderSectionLambdas(const json::Value &Contexts, } void Template::render(const json::Value &Data, llvm::raw_ostream &OS) { - Tree->render(Data, OS); + RawMustacheOutputStream MOS(OS); + Tree->render(Data, MOS); } void Template::registerPartial(std::string Name, std::string Partial) { diff --git a/llvm/lib/Support/VirtualFileSystem.cpp b/llvm/lib/Support/VirtualFileSystem.cpp index 7ff62d4..44d2ee7 100644 --- a/llvm/lib/Support/VirtualFileSystem.cpp +++ b/llvm/lib/Support/VirtualFileSystem.cpp @@ -1908,7 +1908,12 @@ private: FullPath = FS->getOverlayFileDir(); assert(!FullPath.empty() && "External contents prefix directory must exist"); - llvm::sys::path::append(FullPath, Value); + SmallString<256> AbsFullPath = Value; + if (FS->makeAbsolute(FullPath, AbsFullPath)) { + error(N, "failed to make 'external-contents' absolute"); + return nullptr; + } + FullPath = AbsFullPath; } else { FullPath = Value; } @@ -2204,7 +2209,7 @@ RedirectingFileSystem::create(std::unique_ptr<MemoryBuffer> Buffer, // FS->OverlayFileDir => /<absolute_path_to>/dummy.cache/vfs // SmallString<256> OverlayAbsDir = sys::path::parent_path(YAMLFilePath); - std::error_code EC = llvm::sys::fs::make_absolute(OverlayAbsDir); + std::error_code EC = FS->makeAbsolute(OverlayAbsDir); assert(!EC && "Overlay dir final path must be absolute"); (void)EC; FS->setOverlayFileDir(OverlayAbsDir); diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 9078675d..45f5235 100644 --- 
a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -18867,21 +18867,25 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, (!ST->hasSVE2p1() && !(ST->hasSME2() && ST->isStreaming()))) return SDValue(); - unsigned NumUses = N->use_size(); + // Count the number of users which are extract_vectors. + unsigned NumExts = count_if(N->users(), [](SDNode *Use) { + return Use->getOpcode() == ISD::EXTRACT_SUBVECTOR; + }); + auto MaskEC = N->getValueType(0).getVectorElementCount(); - if (!MaskEC.isKnownMultipleOf(NumUses)) + if (!MaskEC.isKnownMultipleOf(NumExts)) return SDValue(); - ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumUses); + ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumExts); if (ExtMinEC.getKnownMinValue() < 2) return SDValue(); - SmallVector<SDNode *> Extracts(NumUses, nullptr); + SmallVector<SDNode *> Extracts(NumExts, nullptr); for (SDNode *Use : N->users()) { if (Use->getOpcode() != ISD::EXTRACT_SUBVECTOR) - return SDValue(); + continue; - // Ensure the extract type is correct (e.g. if NumUses is 4 and + // Ensure the extract type is correct (e.g. if NumExts is 4 and // the mask return type is nxv8i1, each extract should be nxv2i1. if (Use->getValueType(0).getVectorElementCount() != ExtMinEC) return SDValue(); @@ -18902,32 +18906,39 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SDValue Idx = N->getOperand(0); SDValue TC = N->getOperand(1); - EVT OpVT = Idx.getValueType(); - if (OpVT != MVT::i64) { + if (Idx.getValueType() != MVT::i64) { Idx = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, Idx); TC = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, TC); } // Create the whilelo_x2 intrinsics from each pair of extracts EVT ExtVT = Extracts[0]->getValueType(0); + EVT DoubleExtVT = ExtVT.getDoubleNumVectorElementsVT(*DAG.getContext()); auto R = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT}, {ID, Idx, TC}); DCI.CombineTo(Extracts[0], R.getValue(0)); DCI.CombineTo(Extracts[1], R.getValue(1)); + SmallVector<SDValue> Concats = {DAG.getNode( + ISD::CONCAT_VECTORS, DL, DoubleExtVT, R.getValue(0), R.getValue(1))}; - if (NumUses == 2) - return SDValue(N, 0); + if (NumExts == 2) { + assert(N->getValueType(0) == DoubleExtVT); + return Concats[0]; + } - auto Elts = DAG.getElementCount(DL, OpVT, ExtVT.getVectorElementCount() * 2); - for (unsigned I = 2; I < NumUses; I += 2) { + auto Elts = + DAG.getElementCount(DL, MVT::i64, ExtVT.getVectorElementCount() * 2); + for (unsigned I = 2; I < NumExts; I += 2) { // After the first whilelo_x2, we need to increment the starting value. 
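In scalar terms, the increment described above is a saturating add, which is what the ISD::UADDSAT node built just below provides; a minimal sketch:

    #include <cstdint>
    #include <limits>

    // Advance the whilelo start index by the lanes covered per whilelo_x2
    // pair; saturate rather than wrap so the predicate stays all-false once
    // the index passes the trip count.
    static uint64_t advanceLaneIndex(uint64_t Idx, uint64_t ElemsPerPair) {
      uint64_t Max = std::numeric_limits<uint64_t>::max();
      return Idx > Max - ElemsPerPair ? Max : Idx + ElemsPerPair;
    }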
- Idx = DAG.getNode(ISD::UADDSAT, DL, OpVT, Idx, Elts); + Idx = DAG.getNode(ISD::UADDSAT, DL, MVT::i64, Idx, Elts); R = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT}, {ID, Idx, TC}); DCI.CombineTo(Extracts[I], R.getValue(0)); DCI.CombineTo(Extracts[I + 1], R.getValue(1)); + Concats.push_back(DAG.getNode(ISD::CONCAT_VECTORS, DL, DoubleExtVT, + R.getValue(0), R.getValue(1))); } - return SDValue(N, 0); + return DAG.getNode(ISD::CONCAT_VECTORS, DL, N->getValueType(0), Concats); } // Turn a v8i8/v16i8 extended vecreduce into a udot/sdot and vecreduce diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h index d8072d1..e472e7d 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -303,6 +303,16 @@ public: bool shouldFoldConstantShiftPairToMask(const SDNode *N, CombineLevel Level) const override; + /// Return true if it is profitable to fold a pair of shifts into a mask. + bool shouldFoldMaskToVariableShiftPair(SDValue Y) const override { + EVT VT = Y.getValueType(); + + if (VT.isVector()) + return false; + + return VT.getScalarSizeInBits() <= 64; + } + bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, EVT VT, unsigned SelectOpcode, SDValue X, SDValue Y) const override; diff --git a/llvm/lib/Target/AArch64/AArch64PrologueEpilogue.cpp b/llvm/lib/Target/AArch64/AArch64PrologueEpilogue.cpp index 7947469..09b3643 100644 --- a/llvm/lib/Target/AArch64/AArch64PrologueEpilogue.cpp +++ b/llvm/lib/Target/AArch64/AArch64PrologueEpilogue.cpp @@ -541,6 +541,13 @@ void AArch64PrologueEmitter::emitPrologue() { // to determine the end of the prologue. DebugLoc DL; + // In some cases, particularly with CallingConv::SwiftTail, it is possible to + // have a tail-call where the caller only needs to adjust the stack pointer in + // the epilogue. In this case, we still need to emit a SEH prologue sequence. + // See `seh-minimal-prologue-epilogue.ll` test cases. + if (AFI->getArgumentStackToRestore()) + HasWinCFI = true; + if (AFI->shouldSignReturnAddress(MF)) { // If pac-ret+leaf is in effect, PAUTH_PROLOGUE pseudo instructions // are inserted by emitPacRetPlusLeafHardening(). diff --git a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp index cced0fa..4749748 100644 --- a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp +++ b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp @@ -22,7 +22,7 @@ // To handle ZA state across control flow, we make use of edge bundling. This // assigns each block an "incoming" and "outgoing" edge bundle (representing // incoming and outgoing edges). Initially, these are unique to each block; -// then, in the process of forming bundles, the outgoing block of a block is +// then, in the process of forming bundles, the outgoing bundle of a block is // joined with the incoming bundle of all successors. The result is that each // bundle can be assigned a single ZA state, which ensures the state required by // all a blocks' successors is the same, and that each basic block will always diff --git a/llvm/lib/Target/AMDGPU/AMDGPU.td b/llvm/lib/Target/AMDGPU/AMDGPU.td index eaa1870..7003a40 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPU.td +++ b/llvm/lib/Target/AMDGPU/AMDGPU.td @@ -2589,6 +2589,8 @@ def NotHasTrue16BitInsts : True16PredicateClass<"!Subtarget->hasTrue16BitInsts() // only allow 32-bit registers in operands and use low halves thereof. 
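The NotUseRealTrue16Insts predicate defined just below replaces the repeated `foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts]` expansions later in this patch; the Boolean equivalence, sketched in C++ (the function is illustrative, not a real Subtarget API):

    #include <cassert>

    // UseReal = Has && Real; UseFake = Has && !Real; NotHas = !Has.
    // !(Has && Real) is exactly NotHas || UseFake, so one predicate can
    // stand in for the two the foreach used to instantiate.
    static bool notUseRealTrue16(bool Has, bool Real) {
      return !(Has && Real);
    }

    int main() {
      for (int Has = 0; Has < 2; ++Has)
        for (int Real = 0; Real < 2; ++Real)
          assert(notUseRealTrue16(Has, Real) == (!Has || (Has && !Real)));
    }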
def UseRealTrue16Insts : True16PredicateClass<"Subtarget->useRealTrue16Insts()">, AssemblerPredicate<(all_of FeatureTrue16BitInsts, FeatureRealTrue16Insts)>; +def NotUseRealTrue16Insts : True16PredicateClass<"!Subtarget->useRealTrue16Insts()">, + AssemblerPredicate<(not (all_of FeatureTrue16BitInsts, FeatureRealTrue16Insts))>; def UseFakeTrue16Insts : True16PredicateClass<"Subtarget->hasTrue16BitInsts() && " "!Subtarget->useRealTrue16Insts()">, AssemblerPredicate<(all_of FeatureTrue16BitInsts)>; diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp index 0776d14..f413bbc 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp @@ -840,7 +840,9 @@ RegBankLegalizeRules::RegBankLegalizeRules(const GCNSubtarget &_ST, .Any({{B128, Ptr32}, {{}, {VgprB128, VgprPtr32}}}); // clang-format on - addRulesForGOpcs({G_AMDGPU_BUFFER_LOAD}, StandardB) + addRulesForGOpcs({G_AMDGPU_BUFFER_LOAD, G_AMDGPU_BUFFER_LOAD_FORMAT, + G_AMDGPU_TBUFFER_LOAD_FORMAT}, + StandardB) .Div(B32, {{VgprB32}, {SgprV4S32_WF, Vgpr32, Vgpr32, Sgpr32_WF}}) .Uni(B32, {{UniInVgprB32}, {SgprV4S32_WF, Vgpr32, Vgpr32, Sgpr32_WF}}) .Div(B64, {{VgprB64}, {SgprV4S32_WF, Vgpr32, Vgpr32, Sgpr32_WF}}) diff --git a/llvm/lib/Target/AMDGPU/DSInstructions.td b/llvm/lib/Target/AMDGPU/DSInstructions.td index f2e432f..b2ff5a1 100644 --- a/llvm/lib/Target/AMDGPU/DSInstructions.td +++ b/llvm/lib/Target/AMDGPU/DSInstructions.td @@ -969,10 +969,9 @@ multiclass DSReadPat_t16<DS_Pseudo inst, ValueType vt, string frag> { } let OtherPredicates = [NotLDSRequiresM0Init] in { - foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in - let True16Predicate = p in { - def : DSReadPat<!cast<DS_Pseudo>(!cast<string>(inst)#"_gfx9"), vt, !cast<PatFrag>(frag)>; - } + let True16Predicate = NotUseRealTrue16Insts in { + def : DSReadPat<!cast<DS_Pseudo>(!cast<string>(inst)#"_gfx9"), vt, !cast<PatFrag>(frag)>; + } let True16Predicate = UseRealTrue16Insts in { def : DSReadPat<!cast<DS_Pseudo>(!cast<string>(inst)#"_t16"), vt, !cast<PatFrag>(frag)>; } @@ -1050,10 +1049,9 @@ multiclass DSWritePat_t16 <DS_Pseudo inst, ValueType vt, string frag> { } let OtherPredicates = [NotLDSRequiresM0Init] in { - foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in - let True16Predicate = p in { - def : DSWritePat<!cast<DS_Pseudo>(!cast<string>(inst)#"_gfx9"), vt, !cast<PatFrag>(frag)>; - } + let True16Predicate = NotUseRealTrue16Insts in { + def : DSWritePat<!cast<DS_Pseudo>(!cast<string>(inst)#"_gfx9"), vt, !cast<PatFrag>(frag)>; + } let True16Predicate = UseRealTrue16Insts in { def : DSWritePat<!cast<DS_Pseudo>(!cast<string>(inst)#"_t16"), vt, !cast<PatFrag>(frag)>; } diff --git a/llvm/lib/Target/AMDGPU/FLATInstructions.td b/llvm/lib/Target/AMDGPU/FLATInstructions.td index 9f33bac..5a22b23 100644 --- a/llvm/lib/Target/AMDGPU/FLATInstructions.td +++ b/llvm/lib/Target/AMDGPU/FLATInstructions.td @@ -1982,8 +1982,7 @@ defm : FlatLoadPats <FLAT_LOAD_SSHORT, sextloadi16_flat, i32>; defm : FlatLoadPats <FLAT_LOAD_SSHORT, atomic_load_sext_16_flat, i32>; defm : FlatLoadPats <FLAT_LOAD_DWORDX3, load_flat, v3i32>; -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in { +let True16Predicate = NotUseRealTrue16Insts in { defm : FlatLoadPats <FLAT_LOAD_UBYTE, extloadi8_flat, i16>; defm : FlatLoadPats <FLAT_LOAD_UBYTE, zextloadi8_flat, i16>; defm : FlatLoadPats <FLAT_LOAD_SBYTE, sextloadi8_flat, i16>; @@ -2127,8 +2126,7 @@ defm 
: GlobalFLATLoadPats <GLOBAL_LOAD_USHORT, extloadi16_global, i32>; defm : GlobalFLATLoadPats <GLOBAL_LOAD_USHORT, zextloadi16_global, i32>; defm : GlobalFLATLoadPats <GLOBAL_LOAD_SSHORT, sextloadi16_global, i32>; -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in { +let True16Predicate = NotUseRealTrue16Insts in { defm : GlobalFLATLoadPats <GLOBAL_LOAD_UBYTE, extloadi8_global, i16>; defm : GlobalFLATLoadPats <GLOBAL_LOAD_UBYTE, zextloadi8_global, i16>; defm : GlobalFLATLoadPats <GLOBAL_LOAD_SBYTE, sextloadi8_global, i16>; @@ -2187,8 +2185,7 @@ defm : GlobalFLATStorePats <GLOBAL_STORE_BYTE, truncstorei8_global, i32>; defm : GlobalFLATStorePats <GLOBAL_STORE_SHORT, truncstorei16_global, i32>; defm : GlobalFLATStorePats <GLOBAL_STORE_DWORDX3, store_global, v3i32>; -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let OtherPredicates = [HasFlatGlobalInsts], True16Predicate = p in { +let OtherPredicates = [HasFlatGlobalInsts], True16Predicate = NotUseRealTrue16Insts in { defm : GlobalFLATStorePats <GLOBAL_STORE_BYTE, truncstorei8_global, i16>; defm : GlobalFLATStorePats <GLOBAL_STORE_SHORT, store_global, i16>; defm : GlobalFLATStorePats <GLOBAL_STORE_BYTE, atomic_store_8_global, i16>; @@ -2356,8 +2353,7 @@ defm : ScratchFLATLoadPats <SCRATCH_LOAD_USHORT, extloadi16_private, i32>; defm : ScratchFLATLoadPats <SCRATCH_LOAD_USHORT, zextloadi16_private, i32>; defm : ScratchFLATLoadPats <SCRATCH_LOAD_SSHORT, sextloadi16_private, i32>; -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in { +let True16Predicate = NotUseRealTrue16Insts in { defm : ScratchFLATLoadPats <SCRATCH_LOAD_UBYTE, extloadi8_private, i16>; defm : ScratchFLATLoadPats <SCRATCH_LOAD_UBYTE, zextloadi8_private, i16>; defm : ScratchFLATLoadPats <SCRATCH_LOAD_SBYTE, sextloadi8_private, i16>; diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.h b/llvm/lib/Target/AMDGPU/SIInstrInfo.h index 31a2d55..c2252af 100644 --- a/llvm/lib/Target/AMDGPU/SIInstrInfo.h +++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.h @@ -1006,9 +1006,8 @@ public: Opcode == AMDGPU::S_BARRIER_INIT_M0 || Opcode == AMDGPU::S_BARRIER_INIT_IMM || Opcode == AMDGPU::S_BARRIER_JOIN_IMM || - Opcode == AMDGPU::S_BARRIER_LEAVE || - Opcode == AMDGPU::S_BARRIER_LEAVE_IMM || - Opcode == AMDGPU::DS_GWS_INIT || Opcode == AMDGPU::DS_GWS_BARRIER; + Opcode == AMDGPU::S_BARRIER_LEAVE || Opcode == AMDGPU::DS_GWS_INIT || + Opcode == AMDGPU::DS_GWS_BARRIER; } static bool isF16PseudoScalarTrans(unsigned Opcode) { diff --git a/llvm/lib/Target/AMDGPU/SIInstructions.td b/llvm/lib/Target/AMDGPU/SIInstructions.td index 59fd2f1..be084a9 100644 --- a/llvm/lib/Target/AMDGPU/SIInstructions.td +++ b/llvm/lib/Target/AMDGPU/SIInstructions.td @@ -1466,8 +1466,7 @@ class VOPSelectPat_t16 <ValueType vt> : GCNPat < def : VOPSelectModsPat <i32>; def : VOPSelectModsPat <f32>; -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in { +let True16Predicate = NotUseRealTrue16Insts in { def : VOPSelectPat <f16>; def : VOPSelectPat <i16>; } // End True16Predicate = p @@ -2137,8 +2136,7 @@ def : GCNPat < >; foreach fp16vt = [f16, bf16] in { -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let SubtargetPredicate = p in { +let SubtargetPredicate = NotUseRealTrue16Insts in { def : GCNPat < (fabs (fp16vt VGPR_32:$src)), (V_AND_B32_e64 (S_MOV_B32 (i32 0x00007fff)), VGPR_32:$src) @@ -2230,8 +2228,7 @@ def : GCNPat < } foreach fp16vt = [f16, bf16] in { -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let 
True16Predicate = p in { +let True16Predicate = NotUseRealTrue16Insts in { def : GCNPat < (fcopysign fp16vt:$src0, fp16vt:$src1), (V_BFI_B32_e64 (S_MOV_B32 (i32 0x00007fff)), $src0, $src1) @@ -2354,23 +2351,21 @@ def : GCNPat < (S_MOV_B32 $ga) >; -foreach pred = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in { - let True16Predicate = pred in { - def : GCNPat < - (VGPRImm<(i16 imm)>:$imm), - (V_MOV_B32_e32 imm:$imm) - >; +let True16Predicate = NotUseRealTrue16Insts in { + def : GCNPat < + (VGPRImm<(i16 imm)>:$imm), + (V_MOV_B32_e32 imm:$imm) + >; - // FIXME: Workaround for ordering issue with peephole optimizer where - // a register class copy interferes with immediate folding. Should - // use s_mov_b32, which can be shrunk to s_movk_i32 + // FIXME: Workaround for ordering issue with peephole optimizer where + // a register class copy interferes with immediate folding. Should + // use s_mov_b32, which can be shrunk to s_movk_i32 - foreach vt = [f16, bf16] in { - def : GCNPat < - (VGPRImm<(vt fpimm)>:$imm), - (V_MOV_B32_e32 (vt (bitcast_fpimm_to_i32 $imm))) - >; - } + foreach vt = [f16, bf16] in { + def : GCNPat < + (VGPRImm<(vt fpimm)>:$imm), + (V_MOV_B32_e32 (vt (bitcast_fpimm_to_i32 $imm))) + >; } } @@ -2859,8 +2854,7 @@ def : GCNPat< (i32 (DivergentSextInreg<i1> i32:$src)), (V_BFE_I32_e64 i32:$src, (i32 0), (i32 1))>; -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in { +let True16Predicate = NotUseRealTrue16Insts in { def : GCNPat < (i16 (DivergentSextInreg<i1> i16:$src)), (V_BFE_I32_e64 $src, (i32 0), (i32 1)) @@ -3205,8 +3199,7 @@ def : GCNPat< } } // AddedComplexity = 1 -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in { +let True16Predicate = NotUseRealTrue16Insts in { def : GCNPat< (i32 (DivergentUnaryFrag<zext> i16:$src)), (V_AND_B32_e64 (S_MOV_B32 (i32 0xffff)), $src) @@ -3416,8 +3409,7 @@ def : GCNPat < // Magic number: 1 | (0 << 8) | (12 << 16) | (12 << 24) // The 12s emit 0s. 
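How the 0x0c0c0001 selector used by the bswap pattern below is assembled (a sketch; per the comment above, selector 12 makes V_PERM emit a zero byte):

    #include <cstdint>

    // V_PERM_B32 takes four byte selectors packed little-endian into a word.
    constexpr uint32_t permSelector(uint8_t B0, uint8_t B1, uint8_t B2,
                                    uint8_t B3) {
      return uint32_t(B0) | (uint32_t(B1) << 8) | (uint32_t(B2) << 16) |
             (uint32_t(B3) << 24);
    }
    static_assert(permSelector(1, 0, 12, 12) == 0x0c0c0001,
                  "selector for (i16 (bswap i16:$a))");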
-foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in { +let True16Predicate = NotUseRealTrue16Insts in { def : GCNPat < (i16 (bswap i16:$a)), (V_PERM_B32_e64 (i32 0), VSrc_b32:$a, (S_MOV_B32 (i32 0x0c0c0001))) @@ -3670,8 +3662,7 @@ def : GCNPat < (S_LSHL_B32 SReg_32:$src1, (i16 16)) >; -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in { +let True16Predicate = NotUseRealTrue16Insts in { def : GCNPat < (v2i16 (DivergentBinFrag<build_vector> (i16 0), (i16 VGPR_32:$src1))), (v2i16 (V_LSHLREV_B32_e64 (i16 16), VGPR_32:$src1)) @@ -3707,8 +3698,7 @@ def : GCNPat < (COPY_TO_REGCLASS SReg_32:$src0, SReg_32) >; -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in { +let True16Predicate = NotUseRealTrue16Insts in { def : GCNPat < (vecTy (DivergentBinFrag<build_vector> (Ty VGPR_32:$src0), (Ty undef))), (COPY_TO_REGCLASS VGPR_32:$src0, VGPR_32) @@ -3735,8 +3725,7 @@ def : GCNPat < >; let SubtargetPredicate = HasVOP3PInsts in { -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in +let True16Predicate = NotUseRealTrue16Insts in def : GCNPat < (v2i16 (DivergentBinFrag<build_vector> (i16 VGPR_32:$src0), (i16 VGPR_32:$src1))), (v2i16 (V_LSHL_OR_B32_e64 $src1, (i32 16), (i32 (V_AND_B32_e64 (i32 (V_MOV_B32_e32 (i32 0xffff))), $src0)))) @@ -3766,8 +3755,7 @@ def : GCNPat < (S_PACK_LL_B32_B16 SReg_32:$src0, SReg_32:$src1) >; -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in { +let True16Predicate = NotUseRealTrue16Insts in { // Take the lower 16 bits from each VGPR_32 and concat them def : GCNPat < (vecTy (DivergentBinFrag<build_vector> (Ty VGPR_32:$a), (Ty VGPR_32:$b))), @@ -3838,8 +3826,7 @@ def : GCNPat < >; // Take the upper 16 bits from each VGPR_32 and concat them -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in +let True16Predicate = NotUseRealTrue16Insts in def : GCNPat < (vecTy (DivergentBinFrag<build_vector> (Ty !if(!eq(Ty, i16), @@ -3881,8 +3868,7 @@ def : GCNPat < (v2i16 (S_PACK_HL_B32_B16 SReg_32:$src0, SReg_32:$src1)) >; -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in { +let True16Predicate = NotUseRealTrue16Insts in { def : GCNPat < (v2f16 (scalar_to_vector f16:$src0)), (COPY $src0) diff --git a/llvm/lib/Target/AMDGPU/SOPInstructions.td b/llvm/lib/Target/AMDGPU/SOPInstructions.td index 296ce5a..b3fd8c7 100644 --- a/llvm/lib/Target/AMDGPU/SOPInstructions.td +++ b/llvm/lib/Target/AMDGPU/SOPInstructions.td @@ -1616,7 +1616,8 @@ def S_BARRIER_WAIT : SOPP_Pseudo <"s_barrier_wait", (ins i16imm:$simm16), "$simm let isConvergent = 1; } -def S_BARRIER_LEAVE : SOPP_Pseudo <"s_barrier_leave", (ins)> { + def S_BARRIER_LEAVE : SOPP_Pseudo <"s_barrier_leave", + (ins), "", [(int_amdgcn_s_barrier_leave (i16 srcvalue))] > { let SchedRW = [WriteBarrier]; let simm16 = 0; let fixed_imm = 1; @@ -1624,9 +1625,6 @@ def S_BARRIER_LEAVE : SOPP_Pseudo <"s_barrier_leave", (ins)> { let Defs = [SCC]; } -def S_BARRIER_LEAVE_IMM : SOPP_Pseudo <"s_barrier_leave", - (ins i16imm:$simm16), "$simm16", [(int_amdgcn_s_barrier_leave timm:$simm16)]>; - def S_WAKEUP : SOPP_Pseudo <"s_wakeup", (ins) > { let SubtargetPredicate = isGFX8Plus; let simm16 = 0; diff --git a/llvm/lib/Target/AMDGPU/VOP1Instructions.td b/llvm/lib/Target/AMDGPU/VOP1Instructions.td index 6230c17..77df721 100644 --- a/llvm/lib/Target/AMDGPU/VOP1Instructions.td +++ b/llvm/lib/Target/AMDGPU/VOP1Instructions.td @@ 
-1561,8 +1561,7 @@ def : GCNPat < } // End OtherPredicates = [isGFX8Plus] -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let OtherPredicates = [isGFX8Plus, p] in { +let OtherPredicates = [isGFX8Plus, NotUseRealTrue16Insts] in { def : GCNPat< (i32 (anyext i16:$src)), (COPY $src) diff --git a/llvm/lib/Target/AMDGPU/VOP2Instructions.td b/llvm/lib/Target/AMDGPU/VOP2Instructions.td index 37d92bc..30dab55 100644 --- a/llvm/lib/Target/AMDGPU/VOP2Instructions.td +++ b/llvm/lib/Target/AMDGPU/VOP2Instructions.td @@ -1378,8 +1378,7 @@ class ZExt_i16_i1_Pat <SDNode ext> : GCNPat < $src) >; -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in { +let True16Predicate = NotUseRealTrue16Insts in { def : GCNPat < (and i16:$src0, i16:$src1), (V_AND_B32_e64 VSrc_b32:$src0, VSrc_b32:$src1) diff --git a/llvm/lib/Target/AMDGPU/VOP3Instructions.td b/llvm/lib/Target/AMDGPU/VOP3Instructions.td index e6a7c35..4a2b54d 100644 --- a/llvm/lib/Target/AMDGPU/VOP3Instructions.td +++ b/llvm/lib/Target/AMDGPU/VOP3Instructions.td @@ -387,8 +387,7 @@ let SchedRW = [Write64Bit] in { } // End SchedRW = [Write64Bit] } // End isReMaterializable = 1 -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in +let True16Predicate = NotUseRealTrue16Insts in def : GCNPat< (i32 (DivergentUnaryFrag<sext> i16:$src)), (i32 (V_BFE_I32_e64 i16:$src, (i32 0), (i32 0x10))) @@ -501,8 +500,7 @@ def V_INTERP_P1LV_F16 : VOP3Interp <"v_interp_p1lv_f16", VOP3_INTERP16<[f32, f32 } // End SubtargetPredicate = Has16BitInsts, isCommutable = 1 -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in +let True16Predicate = NotUseRealTrue16Insts in def : GCNPat< (i64 (DivergentUnaryFrag<sext> i16:$src)), (REG_SEQUENCE VReg_64, diff --git a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td index 52ee1e8..5daf860 100644 --- a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td +++ b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td @@ -402,8 +402,7 @@ defm V_FMA_MIX_F16_t16 : VOP3_VOP3PInst_t16<"v_fma_mix_f16_t16", VOP3P_Mix_Profi defm : MadFmaMixFP32Pats<fma, V_FMA_MIX_F32>; -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in +let True16Predicate = NotUseRealTrue16Insts in defm : MadFmaMixFP16Pats<fma, V_FMA_MIXLO_F16, V_FMA_MIXHI_F16>; let True16Predicate = UseRealTrue16Insts in defm : MadFmaMixFP16Pats_t16<fma, V_FMA_MIX_F16_t16>; @@ -428,8 +427,7 @@ defm V_FMA_MIX_BF16_t16 : VOP3_VOP3PInst_t16<"v_fma_mix_bf16_t16", VOP3P_Mix_Pro } // End isCommutable = 1 defm : MadFmaMixFP32Pats<fma, V_FMA_MIX_F32_BF16, bf16>; -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in +let True16Predicate = NotUseRealTrue16Insts in defm : MadFmaMixFP16Pats<fma, V_FMA_MIXLO_BF16, V_FMA_MIXHI_BF16, bf16, v2bf16>; let True16Predicate = UseRealTrue16Insts in defm : MadFmaMixFP16Pats_t16<fma, V_FMA_MIX_BF16_t16>; diff --git a/llvm/lib/Target/ARM/ARMBaseRegisterInfo.cpp b/llvm/lib/Target/ARM/ARMBaseRegisterInfo.cpp index e94220a..2e8a676 100644 --- a/llvm/lib/Target/ARM/ARMBaseRegisterInfo.cpp +++ b/llvm/lib/Target/ARM/ARMBaseRegisterInfo.cpp @@ -960,17 +960,3 @@ bool ARMBaseRegisterInfo::shouldCoalesce(MachineInstr *MI, } return false; } - -bool ARMBaseRegisterInfo::shouldRewriteCopySrc(const TargetRegisterClass *DefRC, - unsigned DefSubReg, - const TargetRegisterClass *SrcRC, - unsigned SrcSubReg) const { - // We can't extract an SPR from an arbitary DPR (as opposed to a DPR_VFP2). 
- if (DefRC == &ARM::SPRRegClass && DefSubReg == 0 && - SrcRC == &ARM::DPRRegClass && - (SrcSubReg == ARM::ssub_0 || SrcSubReg == ARM::ssub_1)) - return false; - - return TargetRegisterInfo::shouldRewriteCopySrc(DefRC, DefSubReg, - SrcRC, SrcSubReg); -} diff --git a/llvm/lib/Target/ARM/ARMBaseRegisterInfo.h b/llvm/lib/Target/ARM/ARMBaseRegisterInfo.h index 5b67b34..03b0fa0 100644 --- a/llvm/lib/Target/ARM/ARMBaseRegisterInfo.h +++ b/llvm/lib/Target/ARM/ARMBaseRegisterInfo.h @@ -158,11 +158,6 @@ public: const TargetRegisterClass *NewRC, LiveIntervals &LIS) const override; - bool shouldRewriteCopySrc(const TargetRegisterClass *DefRC, - unsigned DefSubReg, - const TargetRegisterClass *SrcRC, - unsigned SrcSubReg) const override; - int getSEHRegNum(unsigned i) const { return getEncodingValue(i); } }; diff --git a/llvm/lib/Target/PowerPC/PPCInstrFuture.td b/llvm/lib/Target/PowerPC/PPCInstrFuture.td index c3ab965..1aefea1 100644 --- a/llvm/lib/Target/PowerPC/PPCInstrFuture.td +++ b/llvm/lib/Target/PowerPC/PPCInstrFuture.td @@ -182,10 +182,113 @@ class XX3Form_XTAB6<bits<6> opcode, bits<8> xo, dag OOL, dag IOL, string asmstr, let Inst{31} = XT{5}; } +class XX3Form_XTAB6_S<bits<5> xo, dag OOL, dag IOL, string asmstr, + list<dag> pattern> + : I<59, OOL, IOL, asmstr, NoItinerary> { + bits<6> XT; + bits<6> XA; + bits<6> XB; + + let Pattern = pattern; + + let Inst{6...10} = XT{4...0}; + let Inst{11...15} = XA{4...0}; + let Inst{16...20} = XB{4...0}; + let Inst{24...28} = xo; + let Inst{29} = XA{5}; + let Inst{30} = XB{5}; + let Inst{31} = XT{5}; +} + +class XX3Form_XTAB6_S3<bits<5> xo, dag OOL, dag IOL, string asmstr, + list<dag> pattern> + : XX3Form_XTAB6_S<xo, OOL, IOL, asmstr, pattern> { + + bits<3> S; + let Inst{21...23} = S; +} + +class XX3Form_XTAB6_3S1<bits<5> xo, dag OOL, dag IOL, string asmstr, + list<dag> pattern> + : XX3Form_XTAB6_S<xo, OOL, IOL, asmstr, pattern> { + + bits<1> S0; + bits<1> S1; + bits<1> S2; + + let Inst{21} = S0; + let Inst{22} = S1; + let Inst{23} = S2; +} + +class XX3Form_XTAB6_2S1<bits<5> xo, dag OOL, dag IOL, string asmstr, + list<dag> pattern> + : XX3Form_XTAB6_S<xo, OOL, IOL, asmstr, pattern> { + + bits<1> S1; + bits<1> S2; + + let Inst{21} = 0; + let Inst{22} = S1; + let Inst{23} = S2; +} + +class XX3Form_XTAB6_P<bits<7> xo, dag OOL, dag IOL, string asmstr, + list<dag> pattern> + : I<59, OOL, IOL, asmstr, NoItinerary> { + + bits<6> XT; + bits<6> XA; + bits<6> XB; + bits<1> P; + + let Pattern = pattern; + + let Inst{6...10} = XT{4...0}; + let Inst{11...15} = XA{4...0}; + let Inst{16...20} = XB{4...0}; + let Inst{21} = P; + let Inst{22...28} = xo; + let Inst{29} = XA{5}; + let Inst{30} = XB{5}; + let Inst{31} = XT{5}; +} + +// Prefix instruction classes. + +class 8RR_XX4Form_XTABC6_P<bits<6> opcode, dag OOL, dag IOL, string asmstr, + InstrItinClass itin, list<dag> pattern> + : PI<1, opcode, OOL, IOL, asmstr, itin> { + bits<6> XT; + bits<6> XA; + bits<6> XB; + bits<6> XC; + bits<1> P; + + let Pattern = pattern; + + // The prefix. + let Inst{6...7} = 1; + let Inst{8...11} = 0; + + // The instruction. 
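// [Editor's note: prefixed ("8RR") encodings are 64 bits wide; the prefix
// word set above occupies Inst{0...31} and the instruction word occupies
// Inst{32...63}, so instruction-word bits 6...10 land at Inst{38...42} below.]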
+ let Inst{38...42} = XT{4...0}; + let Inst{43...47} = XA{4...0}; + let Inst{48...52} = XB{4...0}; + let Inst{53...57} = XC{4...0}; + let Inst{58} = 1; + let Inst{59} = P; + let Inst{60} = XC{5}; + let Inst{61} = XA{5}; + let Inst{62} = XB{5}; + let Inst{63} = XT{5}; +} + //-------------------------- Instruction definitions -------------------------// // Predicate combinations available: // [IsISAFuture] // [HasVSX, IsISAFuture] +// [HasVSX, PrefixInstrs, IsISAFuture] let Predicates = [IsISAFuture] in { defm SUBFUS : XOForm_RTAB5_L1r<31, 72, (outs g8rc:$RT), @@ -294,6 +397,78 @@ let Predicates = [HasVSX, IsISAFuture] in { "xvmulhuw $XT, $XA, $XB", []>; def XVMULHUH: XX3Form_XTAB6<60, 122, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB), "xvmulhuh $XT, $XA, $XB", []>; + + // Elliptic Curve Cryptography Acceleration Instructions. + def XXMULMUL + : XX3Form_XTAB6_S3<1, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB, u3imm:$S), + "xxmulmul $XT, $XA, $XB, $S", []>; + def XXMULMULHIADD + : XX3Form_XTAB6_3S1<9, (outs vsrc:$XT), + (ins vsrc:$XA, vsrc:$XB, u1imm:$S0, u1imm:$S1, + u1imm:$S2), + "xxmulmulhiadd $XT, $XA, $XB, $S0, $S1, $S2", []>; + def XXMULMULLOADD + : XX3Form_XTAB6_2S1<17, (outs vsrc:$XT), + (ins vsrc:$XA, vsrc:$XB, u1imm:$S1, u1imm:$S2), + "xxmulmulloadd $XT, $XA, $XB, $S1, $S2", []>; + def XXSSUMUDM + : XX3Form_XTAB6_P<25, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB, u1imm:$P), + "xxssumudm $XT, $XA, $XB, $P", []>; + def XXSSUMUDMC + : XX3Form_XTAB6_P<57, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB, u1imm:$P), + "xxssumudmc $XT, $XA, $XB, $P", []>; + def XSADDADDUQM + : XX3Form_XTAB6<59, 96, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB), + "xsaddadduqm $XT, $XA, $XB", []>; + def XSADDADDSUQM + : XX3Form_XTAB6<59, 104, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB), + "xsaddaddsuqm $XT, $XA, $XB", []>; + def XSADDSUBUQM + : XX3Form_XTAB6<59, 112, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB), + "xsaddsubuqm $XT, $XA, $XB", []>; + def XSADDSUBSUQM + : XX3Form_XTAB6<59, 224, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB), + "xsaddsubsuqm $XT, $XA, $XB", []>; + def XSMERGE2T1UQM + : XX3Form_XTAB6<59, 232, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB), + "xsmerge2t1uqm $XT, $XA, $XB", []>; + def XSMERGE2T2UQM + : XX3Form_XTAB6<59, 240, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB), + "xsmerge2t2uqm $XT, $XA, $XB", []>; + def XSMERGE2T3UQM + : XX3Form_XTAB6<59, 89, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB), + "xsmerge2t3uqm $XT, $XA, $XB", []>; + def XSMERGE3T1UQM + : XX3Form_XTAB6<59, 121, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB), + "xsmerge3t1uqm $XT, $XA, $XB", []>; + def XSREBASE2T1UQM + : XX3Form_XTAB6<59, 145, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB), + "xsrebase2t1uqm $XT, $XA, $XB", []>; + def XSREBASE2T2UQM + : XX3Form_XTAB6<59, 177, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB), + "xsrebase2t2uqm $XT, $XA, $XB", []>; + def XSREBASE2T3UQM + : XX3Form_XTAB6<59, 209, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB), + "xsrebase2t3uqm $XT, $XA, $XB", []>; + def XSREBASE2T4UQM + : XX3Form_XTAB6<59, 217, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB), + "xsrebase2t4uqm $XT, $XA, $XB", []>; + def XSREBASE3T1UQM + : XX3Form_XTAB6<59, 241, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB), + "xsrebase3t1uqm $XT, $XA, $XB", []>; + def XSREBASE3T2UQM + : XX3Form_XTAB6<59, 249, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB), + "xsrebase3t2uqm $XT, $XA, $XB", []>; + def XSREBASE3T3UQM + : XX3Form_XTAB6<59, 195, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB), + "xsrebase3t3uqm $XT, $XA, $XB", []>; +} + +let Predicates = [HasVSX, PrefixInstrs, IsISAFuture] in { + 
def XXSSUMUDMCEXT + : 8RR_XX4Form_XTABC6_P< + 34, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB, vsrc:$XC, u1imm:$P), + "xxssumudmcext $XT, $XA, $XB, $XC, $P", IIC_VecGeneral, []>; } //---------------------------- Anonymous Patterns ----------------------------// diff --git a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp index c2a6e51..b765fec 100644 --- a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp @@ -81,6 +81,7 @@ public: void outputExecutionMode(const Module &M); void outputAnnotations(const Module &M); void outputModuleSections(); + void outputFPFastMathDefaultInfo(); bool isHidden() { return MF->getFunction() .getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME) @@ -498,11 +499,27 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) { NamedMDNode *Node = M.getNamedMetadata("spirv.ExecutionMode"); if (Node) { for (unsigned i = 0; i < Node->getNumOperands(); i++) { + // If SPV_KHR_float_controls2 is enabled and we find any of + // FPFastMathDefault, ContractionOff or SignedZeroInfNanPreserve execution + // modes, skip it; it is handled separately. + if (ST->canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) { + const auto EM = + cast<ConstantInt>( + cast<ConstantAsMetadata>((Node->getOperand(i))->getOperand(1)) + ->getValue()) + ->getZExtValue(); + if (EM == SPIRV::ExecutionMode::FPFastMathDefault || + EM == SPIRV::ExecutionMode::ContractionOff || + EM == SPIRV::ExecutionMode::SignedZeroInfNanPreserve) + continue; + } + MCInst Inst; Inst.setOpcode(SPIRV::OpExecutionMode); addOpsFromMDNode(cast<MDNode>(Node->getOperand(i)), Inst, MAI); outputMCInst(Inst); } + outputFPFastMathDefaultInfo(); } for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) { const Function &F = *FI; @@ -552,12 +569,84 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) { } if (ST->isKernel() && !M.getNamedMetadata("spirv.ExecutionMode") && !M.getNamedMetadata("opencl.enable.FP_CONTRACT")) { - MCInst Inst; - Inst.setOpcode(SPIRV::OpExecutionMode); - Inst.addOperand(MCOperand::createReg(FReg)); - unsigned EM = static_cast<unsigned>(SPIRV::ExecutionMode::ContractionOff); - Inst.addOperand(MCOperand::createImm(EM)); - outputMCInst(Inst); + if (ST->canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) { + // When SPV_KHR_float_controls2 is enabled, ContractionOff is + // deprecated. We need to use FPFastMathDefault with the appropriate + // flags instead. Since FPFastMathDefault takes a target type, we need + // to emit it for each floating-point type that exists in the module + // to match the effect of ContractionOff. As of now, there are 3 FP + // types: fp16, fp32 and fp64. + + // We only end up here because there is no "spirv.ExecutionMode" + // metadata, so that means no FPFastMathDefault. Therefore, we only + // need to make sure AllowContract is set to 0, as are the rest of the + // flags. We still need to emit the OpExecutionMode instruction, + // otherwise it's up to the client API to define the flags. Therefore, + // we need to find the constant with 0 value. + + // Collect the SPIRVTypes for fp16, fp32, and fp64 and the constant of + // type int32 with 0 value to represent the FP Fast Math Mode. + std::vector<const MachineInstr *> SPIRVFloatTypes; + const MachineInstr *ConstZero = nullptr; + for (const MachineInstr *MI : + MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) { + // Skip if the instruction is not OpTypeFloat or OpConstant.
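// [Editor's note: an OpConstantNull of 32-bit integer type is what the
// collection below treats as the all-zero FP Fast Math Mode word (ConstZero);
// it later becomes the flags operand of the per-type FPFastMathDefault
// execution mode.]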
+ unsigned OpCode = MI->getOpcode(); + if (OpCode != SPIRV::OpTypeFloat && OpCode != SPIRV::OpConstantNull) + continue; + + // Collect the SPIRV type if it's a float. + if (OpCode == SPIRV::OpTypeFloat) { + // Skip if the target type is not fp16, fp32, fp64. + const unsigned OpTypeFloatSize = MI->getOperand(1).getImm(); + if (OpTypeFloatSize != 16 && OpTypeFloatSize != 32 && + OpTypeFloatSize != 64) { + continue; + } + SPIRVFloatTypes.push_back(MI); + } else { + // Check if the constant is int32, if not skip it. + const MachineRegisterInfo &MRI = MI->getMF()->getRegInfo(); + MachineInstr *TypeMI = MRI.getVRegDef(MI->getOperand(1).getReg()); + if (!TypeMI || TypeMI->getOperand(1).getImm() != 32) + continue; + + ConstZero = MI; + } + } + + // When SPV_KHR_float_controls2 is enabled, ContractionOff is + // deprecated. We need to use FPFastMathDefault with the appropriate + // flags instead. Since FPFastMathDefault takes a target type, we need + // to emit it for each floating-point type that exists in the module + // to match the effect of ContractionOff. As of now, there are 3 FP + // types: fp16, fp32 and fp64. + for (const MachineInstr *MI : SPIRVFloatTypes) { + MCInst Inst; + Inst.setOpcode(SPIRV::OpExecutionModeId); + Inst.addOperand(MCOperand::createReg(FReg)); + unsigned EM = + static_cast<unsigned>(SPIRV::ExecutionMode::FPFastMathDefault); + Inst.addOperand(MCOperand::createImm(EM)); + const MachineFunction *MF = MI->getMF(); + MCRegister TypeReg = + MAI->getRegisterAlias(MF, MI->getOperand(0).getReg()); + Inst.addOperand(MCOperand::createReg(TypeReg)); + assert(ConstZero && "There should be a constant zero."); + MCRegister ConstReg = MAI->getRegisterAlias( + ConstZero->getMF(), ConstZero->getOperand(0).getReg()); + Inst.addOperand(MCOperand::createReg(ConstReg)); + outputMCInst(Inst); + } + } else { + MCInst Inst; + Inst.setOpcode(SPIRV::OpExecutionMode); + Inst.addOperand(MCOperand::createReg(FReg)); + unsigned EM = + static_cast<unsigned>(SPIRV::ExecutionMode::ContractionOff); + Inst.addOperand(MCOperand::createImm(EM)); + outputMCInst(Inst); + } } } } @@ -606,6 +695,101 @@ void SPIRVAsmPrinter::outputAnnotations(const Module &M) { } } +void SPIRVAsmPrinter::outputFPFastMathDefaultInfo() { + // Collect the SPIRVTypes that are OpTypeFloat and the constants of type + // int32, that might be used as FP Fast Math Mode. + std::vector<const MachineInstr *> SPIRVFloatTypes; + // Hashtable to associate immediate values with the constant holding them. + std::unordered_map<int, const MachineInstr *> ConstMap; + for (const MachineInstr *MI : MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) { + // Skip if the instruction is not OpTypeFloat or OpConstant. + unsigned OpCode = MI->getOpcode(); + if (OpCode != SPIRV::OpTypeFloat && OpCode != SPIRV::OpConstantI && + OpCode != SPIRV::OpConstantNull) + continue; + + // Collect the SPIRV type if it's a float. + if (OpCode == SPIRV::OpTypeFloat) { + SPIRVFloatTypes.push_back(MI); + } else { + // Check if the constant is int32, if not skip it. 
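// [Editor's sketch of what this function ultimately emits, assuming the
// standard SPV_KHR_float_controls2 assembly syntax; the %names are
// illustrative:]
//   OpExecutionModeId %entry FPFastMathDefault %half   %uint_flags
//   OpExecutionModeId %entry FPFastMathDefault %float  %uint_flags
//   OpExecutionModeId %entry FPFastMathDefault %double %uint_flags
// One instruction per recorded floating-point width, where %uint_flags is the
// i32 constant looked up in ConstMap for that function's flag word.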
+ const MachineRegisterInfo &MRI = MI->getMF()->getRegInfo(); + MachineInstr *TypeMI = MRI.getVRegDef(MI->getOperand(1).getReg()); + if (!TypeMI || TypeMI->getOpcode() != SPIRV::OpTypeInt || + TypeMI->getOperand(1).getImm() != 32) + continue; + + if (OpCode == SPIRV::OpConstantI) + ConstMap[MI->getOperand(2).getImm()] = MI; + else + ConstMap[0] = MI; + } + } + + for (const auto &[Func, FPFastMathDefaultInfoVec] : + MAI->FPFastMathDefaultInfoMap) { + if (FPFastMathDefaultInfoVec.empty()) + continue; + + for (const MachineInstr *MI : SPIRVFloatTypes) { + unsigned OpTypeFloatSize = MI->getOperand(1).getImm(); + unsigned Index = SPIRV::FPFastMathDefaultInfoVector:: + computeFPFastMathDefaultInfoVecIndex(OpTypeFloatSize); + assert(Index < FPFastMathDefaultInfoVec.size() && + "Index out of bounds for FPFastMathDefaultInfoVec"); + const auto &FPFastMathDefaultInfo = FPFastMathDefaultInfoVec[Index]; + assert(FPFastMathDefaultInfo.Ty && + "Expected target type for FPFastMathDefaultInfo"); + assert(FPFastMathDefaultInfo.Ty->getScalarSizeInBits() == + OpTypeFloatSize && + "Mismatched float type size"); + MCInst Inst; + Inst.setOpcode(SPIRV::OpExecutionModeId); + MCRegister FuncReg = MAI->getFuncReg(Func); + assert(FuncReg.isValid()); + Inst.addOperand(MCOperand::createReg(FuncReg)); + Inst.addOperand( + MCOperand::createImm(SPIRV::ExecutionMode::FPFastMathDefault)); + MCRegister TypeReg = + MAI->getRegisterAlias(MI->getMF(), MI->getOperand(0).getReg()); + Inst.addOperand(MCOperand::createReg(TypeReg)); + unsigned Flags = FPFastMathDefaultInfo.FastMathFlags; + if (FPFastMathDefaultInfo.ContractionOff && + (Flags & SPIRV::FPFastMathMode::AllowContract)) + report_fatal_error( + "Conflicting FPFastMathFlags: ContractionOff and AllowContract"); + + if (FPFastMathDefaultInfo.SignedZeroInfNanPreserve && + !(Flags & + (SPIRV::FPFastMathMode::NotNaN | SPIRV::FPFastMathMode::NotInf | + SPIRV::FPFastMathMode::NSZ))) { + if (FPFastMathDefaultInfo.FPFastMathDefault) + report_fatal_error("Conflicting FPFastMathFlags: " + "SignedZeroInfNanPreserve but at least one of " + "NotNaN/NotInf/NSZ is enabled."); + } + + // Don't emit if none of the execution modes was used. + if (Flags == SPIRV::FPFastMathMode::None && + !FPFastMathDefaultInfo.ContractionOff && + !FPFastMathDefaultInfo.SignedZeroInfNanPreserve && + !FPFastMathDefaultInfo.FPFastMathDefault) + continue; + + // Retrieve the constant instruction for the immediate value. + auto It = ConstMap.find(Flags); + if (It == ConstMap.end()) + report_fatal_error("Expected constant instruction for FP Fast Math " + "Mode operand of FPFastMathDefault execution mode."); + const MachineInstr *ConstMI = It->second; + MCRegister ConstReg = MAI->getRegisterAlias( + ConstMI->getMF(), ConstMI->getOperand(0).getReg()); + Inst.addOperand(MCOperand::createReg(ConstReg)); + outputMCInst(Inst); + } + } +} + void SPIRVAsmPrinter::outputModuleSections() { const Module *M = MMI->getModule(); // Get the global subtarget to output module-level info. @@ -614,7 +798,8 @@ void SPIRVAsmPrinter::outputModuleSections() { MAI = &SPIRVModuleAnalysis::MAI; assert(ST && TII && MAI && M && "Module analysis is required"); // Output instructions according to the Logical Layout of a Module: - // 1,2. All OpCapability instructions, then optional OpExtension instructions. + // 1,2. All OpCapability instructions, then optional OpExtension + // instructions. outputGlobalRequirements(); // 3. Optional OpExtInstImport instructions. 
outputOpExtInstImports(*M); @@ -622,7 +807,8 @@ outputOpMemoryModel(); // 5. All entry point declarations, using OpEntryPoint. outputEntryPoints(); - // 6. Execution-mode declarations, using OpExecutionMode or OpExecutionModeId. + // 6. Execution-mode declarations, using OpExecutionMode or + // OpExecutionModeId. outputExecutionMode(*M); // 7a. Debug: all OpString, OpSourceExtension, OpSource, and // OpSourceContinued, without forward references. diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp index f704d3a..0e0c454 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp @@ -1162,11 +1162,24 @@ static unsigned getNumSizeComponents(SPIRVType *imgType) { static bool generateExtInst(const SPIRV::IncomingCall *Call, MachineIRBuilder &MIRBuilder, - SPIRVGlobalRegistry *GR) { + SPIRVGlobalRegistry *GR, const CallBase &CB) { // Lookup the extended instruction number in the TableGen records. const SPIRV::DemangledBuiltin *Builtin = Call->Builtin; uint32_t Number = SPIRV::lookupExtendedBuiltin(Builtin->Name, Builtin->Set)->Number; + // fmin_common and fmax_common are now deprecated, and we should use fmin and + // fmax with NotInf and NotNaN flags instead. Keep the original number so we + // can add the NoNans and NoInfs flags later. + uint32_t OrigNumber = Number; + const SPIRVSubtarget &ST = + cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget()); + if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2) && + (Number == SPIRV::OpenCLExtInst::fmin_common || + Number == SPIRV::OpenCLExtInst::fmax_common)) { + Number = (Number == SPIRV::OpenCLExtInst::fmin_common) + ? SPIRV::OpenCLExtInst::fmin + : SPIRV::OpenCLExtInst::fmax; + } // Build extended instruction. auto MIB = @@ -1178,6 +1191,13 @@ static bool generateExtInst(const SPIRV::IncomingCall *Call, for (auto Argument : Call->Arguments) MIB.addUse(Argument); + MIB.getInstr()->copyIRFlags(CB); + if (OrigNumber == SPIRV::OpenCLExtInst::fmin_common || + OrigNumber == SPIRV::OpenCLExtInst::fmax_common) { + // Add NoNans and NoInfs flags to the fmin/fmax instruction. + MIB.getInstr()->setFlag(MachineInstr::MIFlag::FmNoNans); + MIB.getInstr()->setFlag(MachineInstr::MIFlag::FmNoInfs); + } return true; } @@ -2908,7 +2928,7 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall, MachineIRBuilder &MIRBuilder, const Register OrigRet, const Type *OrigRetTy, const SmallVectorImpl<Register> &Args, - SPIRVGlobalRegistry *GR) { + SPIRVGlobalRegistry *GR, const CallBase &CB) { LLVM_DEBUG(dbgs() << "Lowering builtin call: " << DemangledCall << "\n"); // Lookup the builtin in the TableGen records. @@ -2931,7 +2951,7 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall, // Match the builtin with implementation based on the grouping.
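// [Editor's note: the CallBase is threaded through so generateExtInst can
// copy the call's IR fast-math flags onto the emitted instruction; e.g. a
// call to fmin_common under SPV_KHR_float_controls2 is lowered as the fmin
// opcode with the FmNoNans and FmNoInfs MI flags set, as in the hunk above.]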
switch (Call->Builtin->Group) { case SPIRV::Extended: - return generateExtInst(Call.get(), MIRBuilder, GR); + return generateExtInst(Call.get(), MIRBuilder, GR, CB); case SPIRV::Relational: return generateRelationalInst(Call.get(), MIRBuilder, GR); case SPIRV::Group: diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h index 1a8641a..f6a5234 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h @@ -39,7 +39,7 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall, MachineIRBuilder &MIRBuilder, const Register OrigRet, const Type *OrigRetTy, const SmallVectorImpl<Register> &Args, - SPIRVGlobalRegistry *GR); + SPIRVGlobalRegistry *GR, const CallBase &CB); /// Helper function for finding a builtin function attributes /// by a demangled function name. Defined in SPIRVBuiltins.cpp. diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp index a412887..1a7c02c 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp @@ -641,9 +641,9 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, GR->getPointerSize())); } } - if (auto Res = - SPIRV::lowerBuiltin(DemangledName, ST->getPreferredInstructionSet(), - MIRBuilder, ResVReg, OrigRetTy, ArgVRegs, GR)) + if (auto Res = SPIRV::lowerBuiltin( + DemangledName, ST->getPreferredInstructionSet(), MIRBuilder, + ResVReg, OrigRetTy, ArgVRegs, GR, *Info.CB)) return *Res; } diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp index 704edd3..9f2e075 100644 --- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp @@ -25,6 +25,7 @@ #include "llvm/IR/TypedPointerType.h" #include "llvm/Transforms/Utils/Local.h" +#include <cassert> #include <queue> #include <unordered_set> @@ -152,6 +153,7 @@ class SPIRVEmitIntrinsics void insertPtrCastOrAssignTypeInstr(Instruction *I, IRBuilder<> &B); bool shouldTryToAddMemAliasingDecoration(Instruction *Inst); void insertSpirvDecorations(Instruction *I, IRBuilder<> &B); + void insertConstantsForFPFastMathDefault(Module &M); void processGlobalValue(GlobalVariable &GV, IRBuilder<> &B); void processParamTypes(Function *F, IRBuilder<> &B); void processParamTypesByFunHeader(Function *F, IRBuilder<> &B); @@ -2249,6 +2251,198 @@ void SPIRVEmitIntrinsics::insertSpirvDecorations(Instruction *I, } } +static SPIRV::FPFastMathDefaultInfoVector &getOrCreateFPFastMathDefaultInfoVec( + const Module &M, + DenseMap<Function *, SPIRV::FPFastMathDefaultInfoVector> + &FPFastMathDefaultInfoMap, + Function *F) { + auto it = FPFastMathDefaultInfoMap.find(F); + if (it != FPFastMathDefaultInfoMap.end()) + return it->second; + + // If the map does not contain the entry, create a new one. Initialize it to + // contain all 3 elements sorted by bit width of target type: {half, float, + // double}. 
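// [Editor's note: computeFPFastMathDefaultInfoVecIndex presumably maps the
// scalar bit width onto this fixed order (16 -> 0, 32 -> 1, 64 -> 2); the
// asserts around its uses only check that the result stays within these
// three elements.]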
+ SPIRV::FPFastMathDefaultInfoVector FPFastMathDefaultInfoVec; + FPFastMathDefaultInfoVec.emplace_back(Type::getHalfTy(M.getContext()), + SPIRV::FPFastMathMode::None); + FPFastMathDefaultInfoVec.emplace_back(Type::getFloatTy(M.getContext()), + SPIRV::FPFastMathMode::None); + FPFastMathDefaultInfoVec.emplace_back(Type::getDoubleTy(M.getContext()), + SPIRV::FPFastMathMode::None); + return FPFastMathDefaultInfoMap[F] = std::move(FPFastMathDefaultInfoVec); +} + +static SPIRV::FPFastMathDefaultInfo &getFPFastMathDefaultInfo( + SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec, + const Type *Ty) { + size_t BitWidth = Ty->getScalarSizeInBits(); + int Index = + SPIRV::FPFastMathDefaultInfoVector::computeFPFastMathDefaultInfoVecIndex( + BitWidth); + assert(Index >= 0 && Index < 3 && + "Expected FPFastMathDefaultInfo for half, float, or double"); + assert(FPFastMathDefaultInfoVec.size() == 3 && + "Expected FPFastMathDefaultInfoVec to have exactly 3 elements"); + return FPFastMathDefaultInfoVec[Index]; +} + +void SPIRVEmitIntrinsics::insertConstantsForFPFastMathDefault(Module &M) { + const SPIRVSubtarget *ST = TM->getSubtargetImpl(); + if (!ST->canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) + return; + + // Store the FPFastMathDefaultInfo in the FPFastMathDefaultInfoMap. + // We need the entry point (function) as the key, and the target + // type and flags as the value. + // We also need to check ContractionOff and SignedZeroInfNanPreserve + // execution modes, as they are now deprecated and must be replaced + // with FPFastMathDefaultInfo. + auto Node = M.getNamedMetadata("spirv.ExecutionMode"); + if (!Node) { + if (!M.getNamedMetadata("opencl.enable.FP_CONTRACT")) { + // This requires emitting ContractionOff. However, because + // ContractionOff is now deprecated, we need to replace it with + // FPFastMathDefaultInfo with FP Fast Math Mode bitmask set to all 0. + // We need to create the constant for that. + + // Create constant instruction with the bitmask flags. + Constant *InitValue = + ConstantInt::get(Type::getInt32Ty(M.getContext()), 0); + // TODO: Reuse constant if there is one already with the required + // value. + [[maybe_unused]] GlobalVariable *GV = + new GlobalVariable(M, // Module + Type::getInt32Ty(M.getContext()), // Type + true, // isConstant + GlobalValue::InternalLinkage, // Linkage + InitValue // Initializer + ); + } + return; + } + + // The table maps function pointers to their default FP fast math info. It + // can be assumed that the SmallVector is sorted by the bit width of the + // type. The first element is the smallest bit width, and the last element + // is the largest bit width, therefore, we will have {half, float, double} + // in the order of their bit widths. 
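// [Editor's sketch of the metadata shape parsed below; the operand layout is
// inferred from the casts in this loop and the constant values are
// illustrative (4461 is, to the editor's knowledge, the SPIR-V value for
// SignedZeroInfNanPreserve; 6028 for FPFastMathDefault appears in the
// SPIRVSymbolicOperands.td hunk later in this patch):]
//   !spirv.ExecutionMode = !{!0, !1}
//   !0 = !{ptr @main, i32 6028, half undef, i32 65536} ; FPFastMathDefault, AllowContract (0x10000)
//   !1 = !{ptr @main, i32 4461, i32 32}                ; SignedZeroInfNanPreserve, width 32
// Operand 0 is the entry point and operand 1 the execution-mode id; note that
// the FPFastMathDefault target type travels as the *type* of operand 2.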
+ DenseMap<Function *, SPIRV::FPFastMathDefaultInfoVector> + FPFastMathDefaultInfoMap; + + for (unsigned i = 0; i < Node->getNumOperands(); i++) { + MDNode *MDN = cast<MDNode>(Node->getOperand(i)); + assert(MDN->getNumOperands() >= 2 && "Expected at least 2 operands"); + Function *F = cast<Function>( + cast<ConstantAsMetadata>(MDN->getOperand(0))->getValue()); + const auto EM = + cast<ConstantInt>( + cast<ConstantAsMetadata>(MDN->getOperand(1))->getValue()) + ->getZExtValue(); + if (EM == SPIRV::ExecutionMode::FPFastMathDefault) { + assert(MDN->getNumOperands() == 4 && + "Expected 4 operands for FPFastMathDefault"); + const Type *T = cast<ValueAsMetadata>(MDN->getOperand(2))->getType(); + unsigned Flags = + cast<ConstantInt>( + cast<ConstantAsMetadata>(MDN->getOperand(3))->getValue()) + ->getZExtValue(); + SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec = + getOrCreateFPFastMathDefaultInfoVec(M, FPFastMathDefaultInfoMap, F); + SPIRV::FPFastMathDefaultInfo &Info = + getFPFastMathDefaultInfo(FPFastMathDefaultInfoVec, T); + Info.FastMathFlags = Flags; + Info.FPFastMathDefault = true; + } else if (EM == SPIRV::ExecutionMode::ContractionOff) { + assert(MDN->getNumOperands() == 2 && + "Expected no operands for ContractionOff"); + + // We need to save this info for every possible FP type, i.e. {half, + // float, double, fp128}. + SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec = + getOrCreateFPFastMathDefaultInfoVec(M, FPFastMathDefaultInfoMap, F); + for (SPIRV::FPFastMathDefaultInfo &Info : FPFastMathDefaultInfoVec) { + Info.ContractionOff = true; + } + } else if (EM == SPIRV::ExecutionMode::SignedZeroInfNanPreserve) { + assert(MDN->getNumOperands() == 3 && + "Expected 1 operand for SignedZeroInfNanPreserve"); + unsigned TargetWidth = + cast<ConstantInt>( + cast<ConstantAsMetadata>(MDN->getOperand(2))->getValue()) + ->getZExtValue(); + // We need to save this info only for the FP type with TargetWidth. + SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec = + getOrCreateFPFastMathDefaultInfoVec(M, FPFastMathDefaultInfoMap, F); + int Index = SPIRV::FPFastMathDefaultInfoVector:: + computeFPFastMathDefaultInfoVecIndex(TargetWidth); + assert(Index >= 0 && Index < 3 && + "Expected FPFastMathDefaultInfo for half, float, or double"); + assert(FPFastMathDefaultInfoVec.size() == 3 && + "Expected FPFastMathDefaultInfoVec to have exactly 3 elements"); + FPFastMathDefaultInfoVec[Index].SignedZeroInfNanPreserve = true; + } + } + + std::unordered_map<unsigned, GlobalVariable *> GlobalVars; + for (auto &[Func, FPFastMathDefaultInfoVec] : FPFastMathDefaultInfoMap) { + if (FPFastMathDefaultInfoVec.empty()) + continue; + + for (const SPIRV::FPFastMathDefaultInfo &Info : FPFastMathDefaultInfoVec) { + assert(Info.Ty && "Expected target type for FPFastMathDefaultInfo"); + // Skip if none of the execution modes was used. + unsigned Flags = Info.FastMathFlags; + if (Flags == SPIRV::FPFastMathMode::None && !Info.ContractionOff && + !Info.SignedZeroInfNanPreserve && !Info.FPFastMathDefault) + continue; + + // Check if flags are compatible. 
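// [Editor's summary of the consistency rules enforced below:]
//   1. ContractionOff cannot be combined with AllowContract.
//   2. SignedZeroInfNanPreserve combined with an explicit FPFastMathDefault is
//      rejected unless at least one of NotNaN/NotInf/NSZ is set.
//   3. AllowTransform requires both AllowReassoc and AllowContract.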
+ if (Info.ContractionOff && (Flags & SPIRV::FPFastMathMode::AllowContract)) + report_fatal_error("Conflicting FPFastMathFlags: ContractionOff " + "and AllowContract"); + + if (Info.SignedZeroInfNanPreserve && + !(Flags & + (SPIRV::FPFastMathMode::NotNaN | SPIRV::FPFastMathMode::NotInf | + SPIRV::FPFastMathMode::NSZ))) { + if (Info.FPFastMathDefault) + report_fatal_error("Conflicting FPFastMathFlags: " + "SignedZeroInfNanPreserve but at least one of " + "NotNaN/NotInf/NSZ is enabled."); + } + + if ((Flags & SPIRV::FPFastMathMode::AllowTransform) && + !((Flags & SPIRV::FPFastMathMode::AllowReassoc) && + (Flags & SPIRV::FPFastMathMode::AllowContract))) { + report_fatal_error("Conflicting FPFastMathFlags: " + "AllowTransform requires AllowReassoc and " + "AllowContract to be set."); + } + + auto it = GlobalVars.find(Flags); + GlobalVariable *GV = nullptr; + if (it != GlobalVars.end()) { + // Reuse existing global variable. + GV = it->second; + } else { + // Create constant instruction with the bitmask flags. + Constant *InitValue = + ConstantInt::get(Type::getInt32Ty(M.getContext()), Flags); + // TODO: Reuse constant if there is one already with the required + // value. + GV = new GlobalVariable(M, // Module + Type::getInt32Ty(M.getContext()), // Type + true, // isConstant + GlobalValue::InternalLinkage, // Linkage + InitValue // Initializer + ); + GlobalVars[Flags] = GV; + } + } + } +} + void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I, IRBuilder<> &B) { auto *II = dyn_cast<IntrinsicInst>(I); @@ -2569,9 +2763,9 @@ GetElementPtrInst * SPIRVEmitIntrinsics::simplifyZeroLengthArrayGepInst(GetElementPtrInst *GEP) { // getelementptr [0 x T], P, 0 (zero), I -> getelementptr T, P, I. // If type is 0-length array and first index is 0 (zero), drop both the - // 0-length array type and the first index. This is a common pattern in the - // IR, e.g. when using a zero-length array as a placeholder for a flexible - // array such as unbound arrays. + // 0-length array type and the first index. This is a common pattern in + // the IR, e.g. when using a zero-length array as a placeholder for a + // flexible array such as unbound arrays. assert(GEP && "GEP is null"); Type *SrcTy = GEP->getSourceElementType(); SmallVector<Value *, 8> Indices(GEP->indices()); @@ -2633,8 +2827,9 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) { processParamTypesByFunHeader(CurrF, B); - // StoreInst's operand type can be changed during the next transformations, - // so we need to store it in the set. Also store already transformed types. + // StoreInst's operand type can be changed during the next + // transformations, so we need to store it in the set. Also store already + // transformed types. for (auto &I : instructions(Func)) { StoreInst *SI = dyn_cast<StoreInst>(&I); if (!SI) @@ -2681,8 +2876,8 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) { for (auto &I : llvm::reverse(instructions(Func))) deduceOperandElementType(&I, &IncompleteRets); - // Pass forward for PHIs only, their operands are not preceed the instruction - // in meaning of `instructions(Func)`. + // Pass forward for PHIs only, their operands do not precede the + // instruction in the iteration order of `instructions(Func)`.
for (BasicBlock &BB : Func) for (PHINode &Phi : BB.phis()) if (isPointerTy(Phi.getType())) @@ -2692,8 +2887,8 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) { TrackConstants = true; if (!I->getType()->isVoidTy() || isa<StoreInst>(I)) setInsertPointAfterDef(B, I); - // Visitors return either the original/newly created instruction for further - // processing, nullptr otherwise. + // Visitors return either the original/newly created instruction for + // further processing, nullptr otherwise. I = visit(*I); if (!I) continue; @@ -2816,6 +3011,7 @@ bool SPIRVEmitIntrinsics::runOnModule(Module &M) { bool Changed = false; parseFunDeclarations(M); + insertConstantsForFPFastMathDefault(M); TodoType.clear(); for (auto &F : M) diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp index 115766c..6fd1c7e 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -806,7 +806,7 @@ Register SPIRVGlobalRegistry::buildGlobalVariable( // arguments. MDNode *GVarMD = nullptr; if (GVar && (GVarMD = GVar->getMetadata("spirv.Decorations")) != nullptr) - buildOpSpirvDecorations(Reg, MIRBuilder, GVarMD); + buildOpSpirvDecorations(Reg, MIRBuilder, GVarMD, ST); return Reg; } diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp index 45e88fc..ba95ad8 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp @@ -132,7 +132,8 @@ bool SPIRVInstrInfo::isHeaderInstr(const MachineInstr &MI) const { } } -bool SPIRVInstrInfo::canUseFastMathFlags(const MachineInstr &MI) const { +bool SPIRVInstrInfo::canUseFastMathFlags(const MachineInstr &MI, + bool KHRFloatControls2) const { switch (MI.getOpcode()) { case SPIRV::OpFAddS: case SPIRV::OpFSubS: @@ -146,6 +147,24 @@ bool SPIRVInstrInfo::canUseFastMathFlags(const MachineInstr &MI) const { case SPIRV::OpFRemV: case SPIRV::OpFMod: return true; + case SPIRV::OpFNegateV: + case SPIRV::OpFNegate: + case SPIRV::OpOrdered: + case SPIRV::OpUnordered: + case SPIRV::OpFOrdEqual: + case SPIRV::OpFOrdNotEqual: + case SPIRV::OpFOrdLessThan: + case SPIRV::OpFOrdLessThanEqual: + case SPIRV::OpFOrdGreaterThan: + case SPIRV::OpFOrdGreaterThanEqual: + case SPIRV::OpFUnordEqual: + case SPIRV::OpFUnordNotEqual: + case SPIRV::OpFUnordLessThan: + case SPIRV::OpFUnordLessThanEqual: + case SPIRV::OpFUnordGreaterThan: + case SPIRV::OpFUnordGreaterThanEqual: + case SPIRV::OpExtInst: + return KHRFloatControls2 ? 
true : false; default: return false; } } diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.h b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.h index 72d2243..4de9d6a 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.h +++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.h @@ -36,7 +36,8 @@ public: bool isTypeDeclInstr(const MachineInstr &MI) const; bool isDecorationInstr(const MachineInstr &MI) const; bool isAliasingInstr(const MachineInstr &MI) const; - bool canUseFastMathFlags(const MachineInstr &MI) const; + bool canUseFastMathFlags(const MachineInstr &MI, + bool KHRFloatControls2) const; bool canUseNSW(const MachineInstr &MI) const; bool canUseNUW(const MachineInstr &MI) const; diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 1aadd9d..273edf3 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -1073,7 +1073,8 @@ bool SPIRVInstructionSelector::selectExtInst(Register ResVReg, .addDef(ResVReg) .addUse(GR.getSPIRVTypeID(ResType)) .addImm(static_cast<uint32_t>(Set)) - .addImm(Opcode); + .addImm(Opcode) + .setMIFlags(I.getFlags()); const unsigned NumOps = I.getNumOperands(); unsigned Index = 1; if (Index < NumOps && @@ -2629,6 +2630,7 @@ bool SPIRVInstructionSelector::selectCmp(Register ResVReg, .addUse(GR.getSPIRVTypeID(ResType)) .addUse(Cmp0) .addUse(Cmp1) + .setMIFlags(I.getFlags()) .constrainAllUses(TII, TRI, RBI); } diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp index bc159d5..dc717a6 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -248,6 +248,22 @@ static InstrSignature instrToSignature(const MachineInstr &MI, Register DefReg; InstrSignature Signature{MI.getOpcode()}; for (unsigned i = 0; i < MI.getNumOperands(); ++i) { + // The only decorations that can be applied more than once to a given <id> + // or structure member are UserSemantic (5635), CacheControlLoadINTEL (6442), + // and CacheControlStoreINTEL (6443). For the rest of the decorations, we + // will only add to the signature the Opcode, the id to which it applies, + // and the decoration id, disregarding any decoration flags. This will + // ensure that any subsequent decoration with the same id will be deemed a + // duplicate. Then, at the call site, we will be able to handle duplicates + // appropriately.
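// [Editor's illustration: under this rule, `OpDecorate %v FPFastMathMode 0x2`
// and `OpDecorate %v FPFastMathMode 0x4` produce the same signature
// {OpDecorate, %v, FPFastMathMode}, so the second one is flagged as a
// duplicate and the call site can merge the two flag words (0x2 | 0x4 = 0x6)
// instead of emitting both decorations.]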
+ unsigned Opcode = MI.getOpcode(); + if ((Opcode == SPIRV::OpDecorate) && i >= 2) { + unsigned DecorationID = MI.getOperand(1).getImm(); + if (DecorationID != SPIRV::Decoration::UserSemantic && + DecorationID != SPIRV::Decoration::CacheControlLoadINTEL && + DecorationID != SPIRV::Decoration::CacheControlStoreINTEL) + continue; + } const MachineOperand &MO = MI.getOperand(i); size_t h; if (MO.isReg()) { @@ -559,8 +575,54 @@ static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI, MAI.setSkipEmission(&MI); InstrSignature MISign = instrToSignature(MI, MAI, true); auto FoundMI = IS.insert(std::move(MISign)); - if (!FoundMI.second) + if (!FoundMI.second) { + if (MI.getOpcode() == SPIRV::OpDecorate) { + assert(MI.getNumOperands() >= 2 && + "Decoration instructions must have at least 2 operands"); + assert(MSType == SPIRV::MB_Annotations && + "Only OpDecorate instructions can be duplicates"); + // For FPFastMathMode decoration, we need to merge the flags of the + // duplicate decoration with the original one, so we need to find the + // original instruction that has the same signature. For the rest of + // instructions, we will simply skip the duplicate. + if (MI.getOperand(1).getImm() != SPIRV::Decoration::FPFastMathMode) + return; // Skip duplicates of other decorations. + + const SPIRV::InstrList &Decorations = MAI.MS[MSType]; + for (const MachineInstr *OrigMI : Decorations) { + if (instrToSignature(*OrigMI, MAI, true) == MISign) { + assert(OrigMI->getNumOperands() == MI.getNumOperands() && + "Original instruction must have the same number of operands"); + assert( + OrigMI->getNumOperands() == 3 && + "FPFastMathMode decoration must have 3 operands for OpDecorate"); + unsigned OrigFlags = OrigMI->getOperand(2).getImm(); + unsigned NewFlags = MI.getOperand(2).getImm(); + if (OrigFlags == NewFlags) + return; // No need to merge, the flags are the same. + + // Emit warning about possible conflict between flags. + unsigned FinalFlags = OrigFlags | NewFlags; + llvm::errs() + << "Warning: Conflicting FPFastMathMode decoration flags " + "in instruction: " + << *OrigMI << "Original flags: " << OrigFlags + << ", new flags: " << NewFlags + << ". They will be merged on a best effort basis, but not " + "validated. Final flags: " + << FinalFlags << "\n"; + MachineInstr *OrigMINonConst = const_cast<MachineInstr *>(OrigMI); + MachineOperand &OrigFlagsOp = OrigMINonConst->getOperand(2); + OrigFlagsOp = + MachineOperand::CreateImm(static_cast<unsigned>(FinalFlags)); + return; // Merge done, so we found a duplicate; don't add it to MAI.MS + } + } + assert(false && "No original instruction found for the duplicate " + "OpDecorate, but we found one in IS."); + } return; // insert failed, so we found a duplicate; don't add it to MAI.MS + } // No duplicates, so add it. if (Append) MAI.MS[MSType].push_back(&MI); @@ -934,6 +996,11 @@ static void addOpDecorateReqs(const MachineInstr &MI, unsigned DecIndex, } else if (Dec == SPIRV::Decoration::FPMaxErrorDecorationINTEL) { Reqs.addRequirements(SPIRV::Capability::FPMaxErrorINTEL); Reqs.addExtension(SPIRV::Extension::SPV_INTEL_fp_max_error); + } else if (Dec == SPIRV::Decoration::FPFastMathMode) { + if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) { + Reqs.addRequirements(SPIRV::Capability::FloatControls2); + Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls2); + } } } @@ -1994,10 +2061,13 @@ static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI, // Collect requirements for OpExecutionMode instructions. 
auto Node = M.getNamedMetadata("spirv.ExecutionMode"); if (Node) { - bool RequireFloatControls = false, RequireFloatControls2 = false, + bool RequireFloatControls = false, RequireIntelFloatControls2 = false, + RequireKHRFloatControls2 = false, VerLower14 = !ST.isAtLeastSPIRVVer(VersionTuple(1, 4)); - bool HasFloatControls2 = + bool HasIntelFloatControls2 = ST.canUseExtension(SPIRV::Extension::SPV_INTEL_float_controls2); + bool HasKHRFloatControls2 = + ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2); for (unsigned i = 0; i < Node->getNumOperands(); i++) { MDNode *MDN = cast<MDNode>(Node->getOperand(i)); const MDOperand &MDOp = MDN->getOperand(1); @@ -2010,7 +2080,6 @@ static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI, switch (EM) { case SPIRV::ExecutionMode::DenormPreserve: case SPIRV::ExecutionMode::DenormFlushToZero: - case SPIRV::ExecutionMode::SignedZeroInfNanPreserve: case SPIRV::ExecutionMode::RoundingModeRTE: case SPIRV::ExecutionMode::RoundingModeRTZ: RequireFloatControls = VerLower14; @@ -2021,8 +2090,28 @@ static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI, case SPIRV::ExecutionMode::RoundingModeRTNINTEL: case SPIRV::ExecutionMode::FloatingPointModeALTINTEL: case SPIRV::ExecutionMode::FloatingPointModeIEEEINTEL: - if (HasFloatControls2) { - RequireFloatControls2 = true; + if (HasIntelFloatControls2) { + RequireIntelFloatControls2 = true; + MAI.Reqs.getAndAddRequirements( + SPIRV::OperandCategory::ExecutionModeOperand, EM, ST); + } + break; + case SPIRV::ExecutionMode::FPFastMathDefault: { + if (HasKHRFloatControls2) { + RequireKHRFloatControls2 = true; + MAI.Reqs.getAndAddRequirements( + SPIRV::OperandCategory::ExecutionModeOperand, EM, ST); + } + break; + } + case SPIRV::ExecutionMode::ContractionOff: + case SPIRV::ExecutionMode::SignedZeroInfNanPreserve: + if (HasKHRFloatControls2) { + RequireKHRFloatControls2 = true; + MAI.Reqs.getAndAddRequirements( + SPIRV::OperandCategory::ExecutionModeOperand, + SPIRV::ExecutionMode::FPFastMathDefault, ST); + } else { MAI.Reqs.getAndAddRequirements( SPIRV::OperandCategory::ExecutionModeOperand, EM, ST); } @@ -2037,8 +2126,10 @@ static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI, if (RequireFloatControls && ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls)) MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls); - if (RequireFloatControls2) + if (RequireIntelFloatControls2) MAI.Reqs.addExtension(SPIRV::Extension::SPV_INTEL_float_controls2); + if (RequireKHRFloatControls2) + MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls2); } for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) { const Function &F = *FI; @@ -2078,8 +2169,11 @@ static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI, } } -static unsigned getFastMathFlags(const MachineInstr &I) { +static unsigned getFastMathFlags(const MachineInstr &I, + const SPIRVSubtarget &ST) { unsigned Flags = SPIRV::FPFastMathMode::None; + bool CanUseKHRFloatControls2 = + ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2); if (I.getFlag(MachineInstr::MIFlag::FmNoNans)) Flags |= SPIRV::FPFastMathMode::NotNaN; if (I.getFlag(MachineInstr::MIFlag::FmNoInfs)) @@ -2088,12 +2182,45 @@ static unsigned getFastMathFlags(const MachineInstr &I) { Flags |= SPIRV::FPFastMathMode::NSZ; if (I.getFlag(MachineInstr::MIFlag::FmArcp)) Flags |= SPIRV::FPFastMathMode::AllowRecip; - if (I.getFlag(MachineInstr::MIFlag::FmReassoc)) - Flags |= SPIRV::FPFastMathMode::Fast; + if 
(I.getFlag(MachineInstr::MIFlag::FmContract) && CanUseKHRFloatControls2) + Flags |= SPIRV::FPFastMathMode::AllowContract; + if (I.getFlag(MachineInstr::MIFlag::FmReassoc)) { + if (CanUseKHRFloatControls2) + // LLVM reassoc maps to SPIRV transform, see + // https://github.com/KhronosGroup/SPIRV-Registry/issues/326 for details. + // Because we are enabling AllowTransform, we must enable AllowReassoc and + // AllowContract too, as required by the SPIRV spec. Also, we used to map + // MIFlag::FmReassoc to FPFastMathMode::Fast, which should now instead be + // expressed by setting all the other bits. Therefore, we're + // enabling every bit here except None and Fast. + Flags |= SPIRV::FPFastMathMode::NotNaN | SPIRV::FPFastMathMode::NotInf | + SPIRV::FPFastMathMode::NSZ | SPIRV::FPFastMathMode::AllowRecip | + SPIRV::FPFastMathMode::AllowTransform | + SPIRV::FPFastMathMode::AllowReassoc | + SPIRV::FPFastMathMode::AllowContract; + else + Flags |= SPIRV::FPFastMathMode::Fast; + } + + if (CanUseKHRFloatControls2) { + // Error out if SPIRV::FPFastMathMode::Fast is enabled. + assert(!(Flags & SPIRV::FPFastMathMode::Fast) && + "SPIRV::FPFastMathMode::Fast is deprecated and should not be used " + "anymore."); + + // Error out if AllowTransform is enabled without AllowReassoc and + // AllowContract. + assert((!(Flags & SPIRV::FPFastMathMode::AllowTransform) || + ((Flags & SPIRV::FPFastMathMode::AllowReassoc && + Flags & SPIRV::FPFastMathMode::AllowContract))) && + "SPIRV::FPFastMathMode::AllowTransform requires AllowReassoc and " + "AllowContract flags to be enabled as well."); + } + return Flags; } -static bool isFastMathMathModeAvailable(const SPIRVSubtarget &ST) { +static bool isFastMathModeAvailable(const SPIRVSubtarget &ST) { if (ST.isKernel()) return true; if (ST.getSPIRVVersion() < VersionTuple(1, 2)) @@ -2101,9 +2228,10 @@ return ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2); } -static void handleMIFlagDecoration(MachineInstr &I, const SPIRVSubtarget &ST, - const SPIRVInstrInfo &TII, - SPIRV::RequirementHandler &Reqs) { +static void handleMIFlagDecoration( + MachineInstr &I, const SPIRVSubtarget &ST, const SPIRVInstrInfo &TII, + SPIRV::RequirementHandler &Reqs, const SPIRVGlobalRegistry *GR, + SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec) { if (I.getFlag(MachineInstr::MIFlag::NoSWrap) && TII.canUseNSW(I) && getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand, SPIRV::Decoration::NoSignedWrap, ST, Reqs) @@ -2119,13 +2247,53 @@ static void handleMIFlagDecoration(MachineInstr &I, const SPIRVSubtarget &ST, buildOpDecorate(I.getOperand(0).getReg(), I, TII, SPIRV::Decoration::NoUnsignedWrap, {}); } - if (!TII.canUseFastMathFlags(I)) - return; - unsigned FMFlags = getFastMathFlags(I); - if (FMFlags == SPIRV::FPFastMathMode::None) + if (!TII.canUseFastMathFlags( + I, ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2))) return; - if (isFastMathMathModeAvailable(ST)) { + unsigned FMFlags = getFastMathFlags(I, ST); + if (FMFlags == SPIRV::FPFastMathMode::None) { + // We also need to check if any FPFastMathDefault info was set for the + // types used in this instruction. + if (FPFastMathDefaultInfoVec.empty()) + return; + + // There are three types of instructions that can use fast math flags: + // 1. Arithmetic instructions (FAdd, FMul, FSub, FDiv, FRem, etc.) + // 2. Relational instructions (FCmp, FOrd, FUnord, etc.) + // 3. 
Extended instructions (ExtInst) + // For arithmetic instructions, the floating point type can be in the + // result type or in the operands, but they all must be the same. + // For the relational and logical instructions, the floating point type + // can only be in the operands 1 and 2, not the result type. Also, the + // operands must have the same type. For the extended instructions, the + // floating point type can be in the result type or in the operands. It's + // unclear if the operands and the result type must be the same. Let's + // assume they must be. Therefore, for 1. and 2., we can check the first + // operand type, and for 3. we can check the result type. + assert(I.getNumOperands() >= 3 && "Expected at least 3 operands"); + Register ResReg = I.getOpcode() == SPIRV::OpExtInst + ? I.getOperand(1).getReg() + : I.getOperand(2).getReg(); + SPIRVType *ResType = GR->getSPIRVTypeForVReg(ResReg, I.getMF()); + const Type *Ty = GR->getTypeForSPIRVType(ResType); + Ty = Ty->isVectorTy() ? cast<VectorType>(Ty)->getElementType() : Ty; + + // Match instruction type with the FPFastMathDefaultInfoVec. + bool Emit = false; + for (SPIRV::FPFastMathDefaultInfo &Elem : FPFastMathDefaultInfoVec) { + if (Ty == Elem.Ty) { + FMFlags = Elem.FastMathFlags; + Emit = Elem.ContractionOff || Elem.SignedZeroInfNanPreserve || + Elem.FPFastMathDefault; + break; + } + } + + if (FMFlags == SPIRV::FPFastMathMode::None && !Emit) + return; + } + if (isFastMathModeAvailable(ST)) { Register DstReg = I.getOperand(0).getReg(); buildOpDecorate(DstReg, I, TII, SPIRV::Decoration::FPFastMathMode, {FMFlags}); @@ -2135,14 +2303,17 @@ static void handleMIFlagDecoration(MachineInstr &I, const SPIRVSubtarget &ST, // Walk all functions and add decorations related to MI flags. static void addDecorations(const Module &M, const SPIRVInstrInfo &TII, MachineModuleInfo *MMI, const SPIRVSubtarget &ST, - SPIRV::ModuleAnalysisInfo &MAI) { + SPIRV::ModuleAnalysisInfo &MAI, + const SPIRVGlobalRegistry *GR) { for (auto F = M.begin(), E = M.end(); F != E; ++F) { MachineFunction *MF = MMI->getMachineFunction(*F); if (!MF) continue; + for (auto &MBB : *MF) for (auto &MI : MBB) - handleMIFlagDecoration(MI, ST, TII, MAI.Reqs); + handleMIFlagDecoration(MI, ST, TII, MAI.Reqs, GR, + MAI.FPFastMathDefaultInfoMap[&(*F)]); } } @@ -2188,6 +2359,111 @@ static void patchPhis(const Module &M, SPIRVGlobalRegistry *GR, } } +static SPIRV::FPFastMathDefaultInfoVector &getOrCreateFPFastMathDefaultInfoVec( + const Module &M, SPIRV::ModuleAnalysisInfo &MAI, const Function *F) { + auto it = MAI.FPFastMathDefaultInfoMap.find(F); + if (it != MAI.FPFastMathDefaultInfoMap.end()) + return it->second; + + // If the map does not contain the entry, create a new one. Initialize it to + // contain all 3 elements sorted by bit width of target type: {half, float, + // double}. 
+ SPIRV::FPFastMathDefaultInfoVector FPFastMathDefaultInfoVec; + FPFastMathDefaultInfoVec.emplace_back(Type::getHalfTy(M.getContext()), + SPIRV::FPFastMathMode::None); + FPFastMathDefaultInfoVec.emplace_back(Type::getFloatTy(M.getContext()), + SPIRV::FPFastMathMode::None); + FPFastMathDefaultInfoVec.emplace_back(Type::getDoubleTy(M.getContext()), + SPIRV::FPFastMathMode::None); + return MAI.FPFastMathDefaultInfoMap[F] = std::move(FPFastMathDefaultInfoVec); +} + +static SPIRV::FPFastMathDefaultInfo &getFPFastMathDefaultInfo( + SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec, + const Type *Ty) { + size_t BitWidth = Ty->getScalarSizeInBits(); + int Index = + SPIRV::FPFastMathDefaultInfoVector::computeFPFastMathDefaultInfoVecIndex( + BitWidth); + assert(Index >= 0 && Index < 3 && + "Expected FPFastMathDefaultInfo for half, float, or double"); + assert(FPFastMathDefaultInfoVec.size() == 3 && + "Expected FPFastMathDefaultInfoVec to have exactly 3 elements"); + return FPFastMathDefaultInfoVec[Index]; +} + +static void collectFPFastMathDefaults(const Module &M, + SPIRV::ModuleAnalysisInfo &MAI, + const SPIRVSubtarget &ST) { + if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) + return; + + // Store the FPFastMathDefaultInfo in the FPFastMathDefaultInfoMap. + // We need the entry point (function) as the key, and the target + // type and flags as the value. + // We also need to check ContractionOff and SignedZeroInfNanPreserve + // execution modes, as they are now deprecated and must be replaced + // with FPFastMathDefaultInfo. + auto Node = M.getNamedMetadata("spirv.ExecutionMode"); + if (!Node) + return; + + for (unsigned i = 0; i < Node->getNumOperands(); i++) { + MDNode *MDN = cast<MDNode>(Node->getOperand(i)); + assert(MDN->getNumOperands() >= 2 && "Expected at least 2 operands"); + const Function *F = cast<Function>( + cast<ConstantAsMetadata>(MDN->getOperand(0))->getValue()); + const auto EM = + cast<ConstantInt>( + cast<ConstantAsMetadata>(MDN->getOperand(1))->getValue()) + ->getZExtValue(); + if (EM == SPIRV::ExecutionMode::FPFastMathDefault) { + assert(MDN->getNumOperands() == 4 && + "Expected 4 operands for FPFastMathDefault"); + + const Type *T = cast<ValueAsMetadata>(MDN->getOperand(2))->getType(); + unsigned Flags = + cast<ConstantInt>( + cast<ConstantAsMetadata>(MDN->getOperand(3))->getValue()) + ->getZExtValue(); + SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec = + getOrCreateFPFastMathDefaultInfoVec(M, MAI, F); + SPIRV::FPFastMathDefaultInfo &Info = + getFPFastMathDefaultInfo(FPFastMathDefaultInfoVec, T); + Info.FastMathFlags = Flags; + Info.FPFastMathDefault = true; + } else if (EM == SPIRV::ExecutionMode::ContractionOff) { + assert(MDN->getNumOperands() == 2 && + "Expected no operands for ContractionOff"); + + // We need to save this info for every possible FP type, i.e. {half, + // float, double, fp128}. + SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec = + getOrCreateFPFastMathDefaultInfoVec(M, MAI, F); + for (SPIRV::FPFastMathDefaultInfo &Info : FPFastMathDefaultInfoVec) { + Info.ContractionOff = true; + } + } else if (EM == SPIRV::ExecutionMode::SignedZeroInfNanPreserve) { + assert(MDN->getNumOperands() == 3 && + "Expected 1 operand for SignedZeroInfNanPreserve"); + unsigned TargetWidth = + cast<ConstantInt>( + cast<ConstantAsMetadata>(MDN->getOperand(2))->getValue()) + ->getZExtValue(); + // We need to save this info only for the FP type with TargetWidth. 
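// [Editor's note: for example, metadata !{ptr @f, i32 4461, i32 32} -- where
// 4461 is, to the editor's knowledge, the SPIR-V enum value for
// SignedZeroInfNanPreserve and is not spelled out in this patch -- marks only
// the float entry of the per-function vector; the half and double entries
// keep their defaults.]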
+ SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec = + getOrCreateFPFastMathDefaultInfoVec(M, MAI, F); + int Index = SPIRV::FPFastMathDefaultInfoVector:: + computeFPFastMathDefaultInfoVecIndex(TargetWidth); + assert(Index >= 0 && Index < 3 && + "Expected FPFastMathDefaultInfo for half, float, or double"); + assert(FPFastMathDefaultInfoVec.size() == 3 && + "Expected FPFastMathDefaultInfoVec to have exactly 3 elements"); + FPFastMathDefaultInfoVec[Index].SignedZeroInfNanPreserve = true; + } + } +} + struct SPIRV::ModuleAnalysisInfo SPIRVModuleAnalysis::MAI; void SPIRVModuleAnalysis::getAnalysisUsage(AnalysisUsage &AU) const { @@ -2209,7 +2485,8 @@ bool SPIRVModuleAnalysis::runOnModule(Module &M) { patchPhis(M, GR, *TII, MMI); addMBBNames(M, *TII, MMI, *ST, MAI); - addDecorations(M, *TII, MMI, *ST, MAI); + collectFPFastMathDefaults(M, MAI, *ST); + addDecorations(M, *TII, MMI, *ST, MAI, GR); collectReqs(M, MAI, MMI, *ST); diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h index 41c792a..d8376cd 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h @@ -159,6 +159,13 @@ struct ModuleAnalysisInfo { InstrList MS[NUM_MODULE_SECTIONS]; // The table maps MBB number to SPIR-V unique ID register. DenseMap<std::pair<const MachineFunction *, int>, MCRegister> BBNumToRegMap; + // The table maps functions to their default FP fast math info. The + // SmallVector is sorted by the bit width of the type, smallest first, so + // its entries are always {half, float, double}. + DenseMap<const Function *, SPIRV::FPFastMathDefaultInfoVector> + FPFastMathDefaultInfoMap; MCRegister getFuncReg(const Function *F) { assert(F && "Function is null"); diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp index 1a08c6a..db6f2d6 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp @@ -839,6 +839,7 @@ static uint32_t convertFloatToSPIRVWord(float F) { static void insertSpirvDecorations(MachineFunction &MF, SPIRVGlobalRegistry *GR, MachineIRBuilder MIB) { + const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(MIB.getMF().getSubtarget()); SmallVector<MachineInstr *, 10> ToErase; for (MachineBasicBlock &MBB : MF) { for (MachineInstr &MI : MBB) { @@ -849,7 +850,7 @@ static void insertSpirvDecorations(MachineFunction &MF, SPIRVGlobalRegistry *GR, MIB.setInsertPt(*MI.getParent(), MI.getNextNode()); if (isSpvIntrinsic(MI, Intrinsic::spv_assign_decoration)) { buildOpSpirvDecorations(MI.getOperand(1).getReg(), MIB, - MI.getOperand(2).getMetadata()); + MI.getOperand(2).getMetadata(), ST); } else if (isSpvIntrinsic(MI, Intrinsic::spv_assign_fpmaxerror_decoration)) { ConstantFP *OpV = mdconst::dyn_extract<ConstantFP>( diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td index 66ce5a2..6a32dba 100644 --- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td +++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td @@ -802,6 +802,7 @@ defm RoundingModeRTPINTEL : ExecutionModeOperand<5620, [RoundToInfinityINTEL]>; defm RoundingModeRTNINTEL : ExecutionModeOperand<5621, [RoundToInfinityINTEL]>; defm FloatingPointModeALTINTEL : ExecutionModeOperand<5622, [FloatingPointModeINTEL]>; defm FloatingPointModeIEEEINTEL : ExecutionModeOperand<5623,
[FloatingPointModeINTEL]>; +defm FPFastMathDefault : ExecutionModeOperand<6028, [FloatControls2]>; //===----------------------------------------------------------------------===// // Multiclass used to define StorageClass enum values and at the same time @@ -1153,6 +1154,9 @@ defm NotInf : FPFastMathModeOperand<0x2, [Kernel]>; defm NSZ : FPFastMathModeOperand<0x4, [Kernel]>; defm AllowRecip : FPFastMathModeOperand<0x8, [Kernel]>; defm Fast : FPFastMathModeOperand<0x10, [Kernel]>; +defm AllowContract : FPFastMathModeOperand<0x10000, [FloatControls2]>; +defm AllowReassoc : FPFastMathModeOperand<0x20000, [FloatControls2]>; +defm AllowTransform : FPFastMathModeOperand<0x40000, [FloatControls2]>; //===----------------------------------------------------------------------===// // Multiclass used to define FPRoundingMode enum values and at the same time diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp index 820e56b..327c011 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp @@ -181,7 +181,7 @@ void buildOpMemberDecorate(Register Reg, MachineInstr &I, } void buildOpSpirvDecorations(Register Reg, MachineIRBuilder &MIRBuilder, - const MDNode *GVarMD) { + const MDNode *GVarMD, const SPIRVSubtarget &ST) { for (unsigned I = 0, E = GVarMD->getNumOperands(); I != E; ++I) { auto *OpMD = dyn_cast<MDNode>(GVarMD->getOperand(I)); if (!OpMD) @@ -193,6 +193,20 @@ void buildOpSpirvDecorations(Register Reg, MachineIRBuilder &MIRBuilder, if (!DecorationId) report_fatal_error("Expect SPIR-V <Decoration> operand to be the first " "element of the decoration"); + + // The goal of `spirv.Decorations` metadata is to provide a way to + // represent SPIR-V entities that do not map to LLVM in an obvious way. + // FP flags, however, do have obvious matches between LLVM IR and SPIR-V. + // Additionally, we have no guarantee at this point that the flags passed + // through the decoration were not already violated by earlier optimizer + // passes. Therefore, we simply ignore FP flags, including NoContraction + // and FPFastMathMode. + if (DecorationId->getZExtValue() == + static_cast<uint32_t>(SPIRV::Decoration::NoContraction) || + DecorationId->getZExtValue() == + static_cast<uint32_t>(SPIRV::Decoration::FPFastMathMode)) { + continue; // Ignored. + } auto MIB = MIRBuilder.buildInstr(SPIRV::OpDecorate) .addUse(Reg) .addImm(static_cast<uint32_t>(DecorationId->getZExtValue())); diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h index 45c520a..409a0fd 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.h +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h @@ -113,6 +113,54 @@ public: std::function<bool(BasicBlock *)> Op); }; +namespace SPIRV { +struct FPFastMathDefaultInfo { + const Type *Ty = nullptr; + unsigned FastMathFlags = 0; + // With SPV_KHR_float_controls2, ContractionOff and SignedZeroInfNanPreserve + // are deprecated, and we replace them with FPFastMathDefault carrying the + // appropriate flags instead. However, we have no guarantee about the order + // in which we will process execution modes. Therefore it could happen that + // we first process ContractionOff, setting the AllowContraction bit to 0, + // and then we process FPFastMathDefault enabling the AllowContraction bit, + // effectively invalidating ContractionOff. Because of that, it's best to + // keep separate bits for the different execution modes, and we will try to + // combine them later when we emit OpExecutionMode instructions.
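+ // A plausible combination at emission time (a sketch, not the verbatim + // emission code): start from FastMathFlags, clear the AllowContract bit + // when ContractionOff is set, and clear the NotNaN/NotInf/NSZ bits when + // SignedZeroInfNanPreserve is set.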
+ bool ContractionOff = false; + bool SignedZeroInfNanPreserve = false; + bool FPFastMathDefault = false; + + FPFastMathDefaultInfo() = default; + FPFastMathDefaultInfo(const Type *Ty, unsigned FastMathFlags) + : Ty(Ty), FastMathFlags(FastMathFlags) {} + bool operator==(const FPFastMathDefaultInfo &Other) const { + return Ty == Other.Ty && FastMathFlags == Other.FastMathFlags && + ContractionOff == Other.ContractionOff && + SignedZeroInfNanPreserve == Other.SignedZeroInfNanPreserve && + FPFastMathDefault == Other.FPFastMathDefault; + } +}; + +struct FPFastMathDefaultInfoVector + : public SmallVector<SPIRV::FPFastMathDefaultInfo, 3> { + static size_t computeFPFastMathDefaultInfoVecIndex(size_t BitWidth) { + switch (BitWidth) { + case 16: // half + return 0; + case 32: // float + return 1; + case 64: // double + return 2; + default: + report_fatal_error("Expected BitWidth to be 16, 32, or 64", false); + } + llvm_unreachable( + "Unreachable code in computeFPFastMathDefaultInfoVecIndex"); + } +}; + +} // namespace SPIRV + // Add the given string as a series of integer operands, inserting null terminators and padding to make sure the operands all have 32-bit little-endian words. @@ -161,7 +209,7 @@ void buildOpMemberDecorate(Register Reg, MachineInstr &I, // Add an OpDecorate instruction by "spirv.Decorations" metadata node. void buildOpSpirvDecorations(Register Reg, MachineIRBuilder &MIRBuilder, - const MDNode *GVarMD); + const MDNode *GVarMD, const SPIRVSubtarget &ST); // Return a valid position for the OpVariable instruction inside a function, // i.e., at the beginning of the first block of the function. @@ -508,6 +556,5 @@ unsigned getArrayComponentCount(const MachineRegisterInfo *MRI, const MachineInstr *ResType); MachineBasicBlock::iterator getFirstValidInstructionInsertPoint(MachineBasicBlock &BB); - } // namespace llvm #endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index ab5c9c9..12fb46d 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -1762,9 +1762,10 @@ public: GeneratedRTChecks(PredicatedScalarEvolution &PSE, DominatorTree *DT, LoopInfo *LI, TargetTransformInfo *TTI, const DataLayout &DL, TTI::TargetCostKind CostKind) - : DT(DT), LI(LI), TTI(TTI), SCEVExp(*PSE.getSE(), DL, "scev.check"), - MemCheckExp(*PSE.getSE(), DL, "scev.check"), PSE(PSE), - CostKind(CostKind) {} + : DT(DT), LI(LI), TTI(TTI), + SCEVExp(*PSE.getSE(), DL, "scev.check", /*PreserveLCSSA=*/false), + MemCheckExp(*PSE.getSE(), DL, "scev.check", /*PreserveLCSSA=*/false), + PSE(PSE), CostKind(CostKind) {} /// Generate runtime checks in SCEVCheckBlock and MemCheckBlock, so we can /// accurately estimate the cost of the runtime checks. The blocks are @@ -3902,8 +3903,7 @@ void LoopVectorizationPlanner::emitInvalidCostRemarks( if (VF.isScalar()) continue; - VPCostContext CostCtx(CM.TTI, *CM.TLI, *Plan, CM, CM.CostKind, - *CM.PSE.getSE()); + VPCostContext CostCtx(CM.TTI, *CM.TLI, *Plan, CM, CM.CostKind); precomputeCosts(*Plan, VF, CostCtx); auto Iter = vp_depth_first_deep(Plan->getVectorLoopRegion()->getEntry()); for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(Iter)) { @@ -4160,8 +4160,7 @@ VectorizationFactor LoopVectorizationPlanner::selectVectorizationFactor() { // Add on other costs that are modelled in VPlan, but not in the legacy // cost model.
- VPCostContext CostCtx(CM.TTI, *CM.TLI, *P, CM, CM.CostKind, - *CM.PSE.getSE()); + VPCostContext CostCtx(CM.TTI, *CM.TLI, *P, CM, CM.CostKind); VPRegionBlock *VectorRegion = P->getVectorLoopRegion(); assert(VectorRegion && "Expected to have a vector region!"); for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>( @@ -6836,7 +6835,7 @@ LoopVectorizationPlanner::precomputeCosts(VPlan &Plan, ElementCount VF, InstructionCost LoopVectorizationPlanner::cost(VPlan &Plan, ElementCount VF) const { - VPCostContext CostCtx(CM.TTI, *CM.TLI, Plan, CM, CM.CostKind, *PSE.getSE()); + VPCostContext CostCtx(CM.TTI, *CM.TLI, Plan, CM, CM.CostKind); InstructionCost Cost = precomputeCosts(Plan, VF, CostCtx); // Now compute and add the VPlan-based cost. @@ -7069,8 +7068,7 @@ VectorizationFactor LoopVectorizationPlanner::computeBestVF() { // simplifications not accounted for in the legacy cost model. If that's the // case, don't trigger the assertion, as the extra simplifications may cause a // different VF to be picked by the VPlan-based cost model. - VPCostContext CostCtx(CM.TTI, *CM.TLI, BestPlan, CM, CM.CostKind, - *CM.PSE.getSE()); + VPCostContext CostCtx(CM.TTI, *CM.TLI, BestPlan, CM, CM.CostKind); precomputeCosts(BestPlan, BestFactor.Width, CostCtx); // Verify that the VPlan-based and legacy cost models agree, except for VPlans // with early exits and plans with additional VPlan simplifications. The @@ -7486,12 +7484,13 @@ VPRecipeBuilder::tryToWidenMemory(Instruction *I, ArrayRef<VPValue *> Operands, VPSingleDefRecipe *VectorPtr; if (Reverse) { // When folding the tail, we may compute an address that is not computed in the - // original scalar loop and it may not be inbounds. Drop Inbounds in that - // case. + // original scalar loop: drop the GEP no-wrap flags in this case. + // Otherwise, preserve the existing flags without no-unsigned-wrap, as we + // will emit negative indices. GEPNoWrapFlags Flags = - (CM.foldTailByMasking() || !GEP || !GEP->isInBounds()) + CM.foldTailByMasking() || !GEP ? GEPNoWrapFlags::none() - : GEPNoWrapFlags::inBounds(); + : GEP->getNoWrapFlags().withoutNoUnsignedWrap(); VectorPtr = new VPVectorEndPointerRecipe(Ptr, &Plan.getVF(), getLoadStoreType(I), /*Stride*/ -1, Flags, I->getDebugLoc()); @@ -8163,14 +8162,12 @@ void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF, VFRange SubRange = {VF, MaxVFTimes2}; if (auto Plan = tryToBuildVPlanWithVPRecipes( std::unique_ptr<VPlan>(VPlan0->duplicate()), SubRange, &LVer)) { - bool HasScalarVF = Plan->hasScalarVFOnly(); // Now optimize the initial VPlan. - if (!HasScalarVF) - VPlanTransforms::runPass(VPlanTransforms::truncateToMinimalBitwidths, - *Plan, CM.getMinimalBitwidths()); + VPlanTransforms::runPass(VPlanTransforms::truncateToMinimalBitwidths, + *Plan, CM.getMinimalBitwidths()); VPlanTransforms::runPass(VPlanTransforms::optimize, *Plan); // TODO: try to put it close to addActiveLaneMask(). - if (CM.foldTailWithEVL() && !HasScalarVF) + if (CM.foldTailWithEVL()) VPlanTransforms::runPass(VPlanTransforms::addExplicitVectorLength, *Plan, CM.getMaxSafeElements()); assert(verifyVPlanIsValid(*Plan) && "VPlan is invalid"); @@ -8600,8 +8597,7 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes( // TODO: Enable the following transform when the EVL versions of extended-reduction // and mulacc-reduction are implemented.
if (!CM.foldTailWithEVL()) { - VPCostContext CostCtx(CM.TTI, *CM.TLI, *Plan, CM, CM.CostKind, - *CM.PSE.getSE()); + VPCostContext CostCtx(CM.TTI, *CM.TLI, *Plan, CM, CM.CostKind); VPlanTransforms::runPass(VPlanTransforms::convertToAbstractRecipes, *Plan, CostCtx, Range); } @@ -10058,7 +10054,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { bool ForceVectorization = Hints.getForce() == LoopVectorizeHints::FK_Enabled; VPCostContext CostCtx(CM.TTI, *CM.TLI, LVP.getPlanFor(VF.Width), CM, - CM.CostKind, *CM.PSE.getSE()); + CM.CostKind); if (!ForceVectorization && !isOutsideLoopWorkProfitable(Checks, VF, L, PSE, CostCtx, LVP.getPlanFor(VF.Width), SEL, diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index c547662..f77d587 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -2105,6 +2105,7 @@ public: UserIgnoreList = nullptr; PostponedGathers.clear(); ValueToGatherNodes.clear(); + TreeEntryToStridedPtrInfoMap.clear(); } unsigned getTreeSize() const { return VectorizableTree.size(); } @@ -8948,6 +8949,8 @@ BoUpSLP::findExternalStoreUsersReorderIndices(TreeEntry *TE) const { void BoUpSLP::buildTree(ArrayRef<Value *> Roots, const SmallDenseSet<Value *> &UserIgnoreLst) { deleteTree(); + assert(TreeEntryToStridedPtrInfoMap.empty() && + "TreeEntryToStridedPtrInfoMap is not cleared"); UserIgnoreList = &UserIgnoreLst; if (!allSameType(Roots)) return; @@ -8956,6 +8959,8 @@ void BoUpSLP::buildTree(ArrayRef<Value *> Roots, void BoUpSLP::buildTree(ArrayRef<Value *> Roots) { deleteTree(); + assert(TreeEntryToStridedPtrInfoMap.empty() && + "TreeEntryToStridedPtrInfoMap is not cleared"); if (!allSameType(Roots)) return; buildTreeRec(Roots, 0, EdgeInfo()); diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp index 728d291..81f1956 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp @@ -1750,8 +1750,7 @@ VPCostContext::getOperandInfo(VPValue *V) const { } InstructionCost VPCostContext::getScalarizationOverhead( - Type *ResultTy, ArrayRef<const VPValue *> Operands, ElementCount VF, - bool AlwaysIncludeReplicatingR) { + Type *ResultTy, ArrayRef<const VPValue *> Operands, ElementCount VF) { if (VF.isScalar()) return 0; @@ -1771,9 +1770,7 @@ InstructionCost VPCostContext::getScalarizationOverhead( SmallPtrSet<const VPValue *, 4> UniqueOperands; SmallVector<Type *> Tys; for (auto *Op : Operands) { - if (Op->isLiveIn() || - (!AlwaysIncludeReplicatingR && - isa<VPReplicateRecipe, VPPredInstPHIRecipe>(Op)) || + if (Op->isLiveIn() || isa<VPReplicateRecipe, VPPredInstPHIRecipe>(Op) || !UniqueOperands.insert(Op).second) continue; Tys.push_back(toVectorizedTy(Types.inferScalarType(Op), VF)); diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h index 4c7a083..10d704d 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/llvm/lib/Transforms/Vectorize/VPlan.h @@ -3033,7 +3033,7 @@ public: assert(Red->getRecurrenceKind() == RecurKind::Add && "Expected an add reduction"); assert(getNumOperands() >= 3 && "Expected at least three operands"); - auto *SubConst = dyn_cast<ConstantInt>(getOperand(2)->getLiveInIRValue()); + [[maybe_unused]] auto *SubConst = dyn_cast<ConstantInt>(getOperand(2)->getLiveInIRValue()); assert(SubConst && SubConst->getValue() == 0 && Sub->getOpcode() == Instruction::Sub && "Expected a negating sub"); } diff --git 
a/llvm/lib/Transforms/Vectorize/VPlanHelpers.h b/llvm/lib/Transforms/Vectorize/VPlanHelpers.h index 2a8baec..fe59774 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanHelpers.h +++ b/llvm/lib/Transforms/Vectorize/VPlanHelpers.h @@ -349,14 +349,12 @@ struct VPCostContext { LoopVectorizationCostModel &CM; SmallPtrSet<Instruction *, 8> SkipCostComputation; TargetTransformInfo::TargetCostKind CostKind; - ScalarEvolution &SE; VPCostContext(const TargetTransformInfo &TTI, const TargetLibraryInfo &TLI, const VPlan &Plan, LoopVectorizationCostModel &CM, - TargetTransformInfo::TargetCostKind CostKind, - ScalarEvolution &SE) + TargetTransformInfo::TargetCostKind CostKind) : TTI(TTI), TLI(TLI), Types(Plan), LLVMCtx(Plan.getContext()), CM(CM), - CostKind(CostKind), SE(SE) {} + CostKind(CostKind) {} /// Return the cost for \p UI with \p VF using the legacy cost model as /// fallback until computing the cost of all recipes migrates to VPlan. @@ -376,12 +374,10 @@ struct VPCostContext { /// Estimate the overhead of scalarizing a recipe with result type \p ResultTy /// and \p Operands with \p VF. This is a convenience wrapper for the - /// type-based getScalarizationOverhead API. If \p AlwaysIncludeReplicatingR - /// is true, always compute the cost of scalarizing replicating operands. - InstructionCost - getScalarizationOverhead(Type *ResultTy, ArrayRef<const VPValue *> Operands, - ElementCount VF, - bool AlwaysIncludeReplicatingR = false); + /// type-based getScalarizationOverhead API. + InstructionCost getScalarizationOverhead(Type *ResultTy, + ArrayRef<const VPValue *> Operands, + ElementCount VF); }; /// This class can be used to assign names to VPValues. For VPValues without diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index ee03729..3a55710 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -3098,61 +3098,6 @@ bool VPReplicateRecipe::shouldPack() const { }); } -/// Returns true if \p Ptr is a pointer computation for which the legacy cost -/// model computes a SCEV expression when computing the address cost. -static bool shouldUseAddressAccessSCEV(const VPValue *Ptr) { - auto *PtrR = Ptr->getDefiningRecipe(); - if (!PtrR || !((isa<VPReplicateRecipe>(PtrR) && - cast<VPReplicateRecipe>(PtrR)->getOpcode() == - Instruction::GetElementPtr) || - isa<VPWidenGEPRecipe>(PtrR))) - return false; - - // We are looking for a GEP where all indices are either loop invariant or - // inductions. - for (VPValue *Opd : drop_begin(PtrR->operands())) { - if (!Opd->isDefinedOutsideLoopRegions() && - !isa<VPScalarIVStepsRecipe, VPWidenIntOrFpInductionRecipe>(Opd)) - return false; - } - - return true; -} - -/// Returns true if \p V is used as part of the address of another load or -/// store. 
-static bool isUsedByLoadStoreAddress(const VPUser *V) { - SmallPtrSet<const VPUser *, 4> Seen; - SmallVector<const VPUser *> WorkList = {V}; - - while (!WorkList.empty()) { - auto *Cur = dyn_cast<VPSingleDefRecipe>(WorkList.pop_back_val()); - if (!Cur || !Seen.insert(Cur).second) - continue; - - for (VPUser *U : Cur->users()) { - if (auto *InterleaveR = dyn_cast<VPInterleaveBase>(U)) - if (InterleaveR->getAddr() == Cur) - return true; - if (auto *RepR = dyn_cast<VPReplicateRecipe>(U)) { - if (RepR->getOpcode() == Instruction::Load && - RepR->getOperand(0) == Cur) - return true; - if (RepR->getOpcode() == Instruction::Store && - RepR->getOperand(1) == Cur) - return true; - } - if (auto *MemR = dyn_cast<VPWidenMemoryRecipe>(U)) { - if (MemR->getAddr() == Cur && MemR->isConsecutive()) - return true; - } - } - - append_range(WorkList, cast<VPSingleDefRecipe>(Cur)->users()); - } - return false; -} - InstructionCost VPReplicateRecipe::computeCost(ElementCount VF, VPCostContext &Ctx) const { Instruction *UI = cast<Instruction>(getUnderlyingValue()); @@ -3260,58 +3205,21 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF, } case Instruction::Load: case Instruction::Store: { - if (VF.isScalable() && !isSingleScalar()) - return InstructionCost::getInvalid(); - + if (isSingleScalar()) { + bool IsLoad = UI->getOpcode() == Instruction::Load; + Type *ValTy = Ctx.Types.inferScalarType(IsLoad ? this : getOperand(0)); + Type *ScalarPtrTy = Ctx.Types.inferScalarType(getOperand(IsLoad ? 0 : 1)); + const Align Alignment = getLoadStoreAlignment(UI); + unsigned AS = getLoadStoreAddressSpace(UI); + TTI::OperandValueInfo OpInfo = TTI::getOperandInfo(UI->getOperand(0)); + InstructionCost ScalarMemOpCost = Ctx.TTI.getMemoryOpCost( + UI->getOpcode(), ValTy, Alignment, AS, Ctx.CostKind, OpInfo, UI); + return ScalarMemOpCost + Ctx.TTI.getAddressComputationCost( + ScalarPtrTy, nullptr, nullptr, Ctx.CostKind); + } // TODO: See getMemInstScalarizationCost for how to handle replicating and // predicated cases. - const VPRegionBlock *ParentRegion = getParent()->getParent(); - if (ParentRegion && ParentRegion->isReplicator()) - break; - - bool IsLoad = UI->getOpcode() == Instruction::Load; - const VPValue *PtrOp = getOperand(!IsLoad); - // TODO: Handle cases where we need to pass a SCEV to - // getAddressComputationCost. - if (shouldUseAddressAccessSCEV(PtrOp)) - break; - - Type *ValTy = Ctx.Types.inferScalarType(IsLoad ? this : getOperand(0)); - Type *ScalarPtrTy = Ctx.Types.inferScalarType(PtrOp); - const Align Alignment = getLoadStoreAlignment(UI); - unsigned AS = getLoadStoreAddressSpace(UI); - TTI::OperandValueInfo OpInfo = TTI::getOperandInfo(UI->getOperand(0)); - InstructionCost ScalarMemOpCost = Ctx.TTI.getMemoryOpCost( - UI->getOpcode(), ValTy, Alignment, AS, Ctx.CostKind, OpInfo); - - Type *PtrTy = isSingleScalar() ? ScalarPtrTy : toVectorTy(ScalarPtrTy, VF); - - InstructionCost ScalarCost = - ScalarMemOpCost + Ctx.TTI.getAddressComputationCost( - PtrTy, &Ctx.SE, nullptr, Ctx.CostKind); - if (isSingleScalar()) - return ScalarCost; - - SmallVector<const VPValue *> OpsToScalarize; - Type *ResultTy = Type::getVoidTy(PtrTy->getContext()); - // Set ResultTy and OpsToScalarize, if scalarization is needed. Currently we - // don't assign scalarization overhead in general, if the target prefers - // vectorized addressing or the loaded value is used as part of an address - // of another load or store. 
- bool PreferVectorizedAddressing = Ctx.TTI.prefersVectorizedAddressing(); - if (PreferVectorizedAddressing || !isUsedByLoadStoreAddress(this)) { - bool EfficientVectorLoadStore = - Ctx.TTI.supportsEfficientVectorElementLoadStore(); - if (!(IsLoad && !PreferVectorizedAddressing) && - !(!IsLoad && EfficientVectorLoadStore)) - append_range(OpsToScalarize, operands()); - - if (!EfficientVectorLoadStore) - ResultTy = Ctx.Types.inferScalarType(this); - } - - return (ScalarCost * VF.getFixedValue()) + - Ctx.getScalarizationOverhead(ResultTy, OpsToScalarize, VF, true); + break; } } diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp index 969dce4..a73b083 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -2124,6 +2124,8 @@ static void licm(VPlan &Plan) { void VPlanTransforms::truncateToMinimalBitwidths( VPlan &Plan, const MapVector<Instruction *, uint64_t> &MinBWs) { + if (Plan.hasScalarVFOnly()) + return; // Keep track of created truncates, so they can be re-used. Note that we // cannot use RAUW after creating a new truncate, as this could make // other uses have different types for their operands, making them invalidly @@ -2704,6 +2706,8 @@ static void transformRecipestoEVLRecipes(VPlan &Plan, VPValue &EVL) { /// void VPlanTransforms::addExplicitVectorLength( VPlan &Plan, const std::optional<unsigned> &MaxSafeElements) { + if (Plan.hasScalarVFOnly()) + return; VPBasicBlock *Header = Plan.getVectorLoopRegion()->getEntryBasicBlock(); auto *CanonicalIVPHI = Plan.getCanonicalIV();