diff options
Diffstat (limited to 'llvm/lib/Target')
75 files changed, 791 insertions, 1512 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index fede586..47c1ac4 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -1032,6 +1032,13 @@ AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA, } break; } + case Intrinsic::experimental_vector_extract_last_active: + if (ST->isSVEorStreamingSVEAvailable()) { + auto [LegalCost, _] = getTypeLegalizationCost(ICA.getArgTypes()[0]); + // This should turn into chained clastb instructions. + return LegalCost; + } + break; default: break; } diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp index 9ce1224..0c97741 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp @@ -221,12 +221,21 @@ bool AMDGPUInstructionSelector::selectCOPY(MachineInstr &I) const { bool AMDGPUInstructionSelector::selectCOPY_SCC_VCC(MachineInstr &I) const { const DebugLoc &DL = I.getDebugLoc(); MachineBasicBlock *BB = I.getParent(); + Register VCCReg = I.getOperand(1).getReg(); + MachineInstr *Cmp; + + // Set SCC as a side effect with S_CMP or S_OR. + if (STI.hasScalarCompareEq64()) { + unsigned CmpOpc = + STI.isWave64() ? AMDGPU::S_CMP_LG_U64 : AMDGPU::S_CMP_LG_U32; + Cmp = BuildMI(*BB, &I, DL, TII.get(CmpOpc)).addReg(VCCReg).addImm(0); + } else { + Register DeadDst = MRI->createVirtualRegister(&AMDGPU::SReg_64RegClass); + Cmp = BuildMI(*BB, &I, DL, TII.get(AMDGPU::S_OR_B64), DeadDst) + .addReg(VCCReg) + .addReg(VCCReg); + } - unsigned CmpOpc = - STI.isWave64() ? AMDGPU::S_CMP_LG_U64 : AMDGPU::S_CMP_LG_U32; - MachineInstr *Cmp = BuildMI(*BB, &I, DL, TII.get(CmpOpc)) - .addReg(I.getOperand(1).getReg()) - .addImm(0); if (!constrainSelectedInstRegOperands(*Cmp, TII, TRI, RBI)) return false; diff --git a/llvm/lib/Target/AMDGPU/AMDGPULowerVGPREncoding.cpp b/llvm/lib/Target/AMDGPU/AMDGPULowerVGPREncoding.cpp index 1e6589e..d7d0292 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPULowerVGPREncoding.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPULowerVGPREncoding.cpp @@ -58,6 +58,8 @@ class AMDGPULowerVGPREncoding { static constexpr unsigned BitsPerField = 2; static constexpr unsigned NumFields = 4; static constexpr unsigned FieldMask = (1 << BitsPerField) - 1; + static constexpr unsigned ModeWidth = NumFields * BitsPerField; + static constexpr unsigned ModeMask = (1 << ModeWidth) - 1; using ModeType = PackedVector<unsigned, BitsPerField, std::bitset<BitsPerField * NumFields>>; @@ -82,12 +84,12 @@ private: const SIInstrInfo *TII; const SIRegisterInfo *TRI; + // Current basic block. + MachineBasicBlock *MBB; + /// Most recent s_set_* instruction. MachineInstr *MostRecentModeSet; - /// Whether the current mode is known. - bool CurrentModeKnown; - /// Current mode bits. ModeTy CurrentMode; @@ -108,10 +110,13 @@ private: MachineInstr *Clause; /// Insert mode change before \p I. \returns true if mode was changed. - bool setMode(ModeTy NewMode, ModeTy Mask, MachineInstr *I); + bool setMode(ModeTy NewMode, ModeTy Mask, + MachineBasicBlock::instr_iterator I); /// Reset mode to default. - void resetMode(MachineInstr *I) { setMode(ModeTy(), ModeTy::fullMask(), I); } + void resetMode(MachineBasicBlock::instr_iterator I) { + setMode(ModeTy(), ModeTy::fullMask(), I); + } /// If \p MO references VGPRs, return the MSBs. Otherwise, return nullopt. std::optional<unsigned> getMSBs(const MachineOperand &MO) const; @@ -130,38 +135,43 @@ private: /// Check if an instruction \p I is within a clause and returns a suitable /// iterator to insert mode change. It may also modify the S_CLAUSE /// instruction to extend it or drop the clause if it cannot be adjusted. - MachineInstr *handleClause(MachineInstr *I); + MachineBasicBlock::instr_iterator + handleClause(MachineBasicBlock::instr_iterator I); }; bool AMDGPULowerVGPREncoding::setMode(ModeTy NewMode, ModeTy Mask, - MachineInstr *I) { + MachineBasicBlock::instr_iterator I) { assert((NewMode.raw_bits() & ~Mask.raw_bits()).none()); - if (CurrentModeKnown) { - auto Delta = NewMode.raw_bits() ^ CurrentMode.raw_bits(); + auto Delta = NewMode.raw_bits() ^ CurrentMode.raw_bits(); - if ((Delta & Mask.raw_bits()).none()) { - CurrentMask |= Mask; - return false; - } + if ((Delta & Mask.raw_bits()).none()) { + CurrentMask |= Mask; + return false; + } - if (MostRecentModeSet && (Delta & CurrentMask.raw_bits()).none()) { - CurrentMode |= NewMode; - CurrentMask |= Mask; + if (MostRecentModeSet && (Delta & CurrentMask.raw_bits()).none()) { + CurrentMode |= NewMode; + CurrentMask |= Mask; - MostRecentModeSet->getOperand(0).setImm(CurrentMode); - return true; - } + MachineOperand &Op = MostRecentModeSet->getOperand(0); + + // Carry old mode bits from the existing instruction. + int64_t OldModeBits = Op.getImm() & (ModeMask << ModeWidth); + + Op.setImm(CurrentMode | OldModeBits); + return true; } + // Record previous mode into high 8 bits of the immediate. + int64_t OldModeBits = CurrentMode << ModeWidth; + I = handleClause(I); - MostRecentModeSet = - BuildMI(*I->getParent(), I, {}, TII->get(AMDGPU::S_SET_VGPR_MSB)) - .addImm(NewMode); + MostRecentModeSet = BuildMI(*MBB, I, {}, TII->get(AMDGPU::S_SET_VGPR_MSB)) + .addImm(NewMode | OldModeBits); CurrentMode = NewMode; CurrentMask = Mask; - CurrentModeKnown = true; return true; } @@ -233,21 +243,22 @@ bool AMDGPULowerVGPREncoding::runOnMachineInstr(MachineInstr &MI) { if (Ops.first) { ModeTy NewMode, Mask; computeMode(NewMode, Mask, MI, Ops.first, Ops.second); - return setMode(NewMode, Mask, &MI); + return setMode(NewMode, Mask, MI.getIterator()); } assert(!TII->hasVGPRUses(MI) || MI.isMetaInstruction() || MI.isPseudo()); return false; } -MachineInstr *AMDGPULowerVGPREncoding::handleClause(MachineInstr *I) { +MachineBasicBlock::instr_iterator +AMDGPULowerVGPREncoding::handleClause(MachineBasicBlock::instr_iterator I) { if (!ClauseRemaining) return I; // A clause cannot start with a special instruction, place it right before // the clause. if (ClauseRemaining == ClauseLen) { - I = Clause->getPrevNode(); + I = Clause->getPrevNode()->getIterator(); assert(I->isBundle()); return I; } @@ -284,9 +295,9 @@ bool AMDGPULowerVGPREncoding::run(MachineFunction &MF) { ClauseLen = ClauseRemaining = 0; CurrentMode.reset(); CurrentMask.reset(); - CurrentModeKnown = true; for (auto &MBB : MF) { MostRecentModeSet = nullptr; + this->MBB = &MBB; for (auto &MI : llvm::make_early_inc_range(MBB.instrs())) { if (MI.isMetaInstruction()) @@ -294,17 +305,16 @@ bool AMDGPULowerVGPREncoding::run(MachineFunction &MF) { if (MI.isTerminator() || MI.isCall()) { if (MI.getOpcode() == AMDGPU::S_ENDPGM || - MI.getOpcode() == AMDGPU::S_ENDPGM_SAVED) { + MI.getOpcode() == AMDGPU::S_ENDPGM_SAVED) CurrentMode.reset(); - CurrentModeKnown = true; - } else - resetMode(&MI); + else + resetMode(MI.getIterator()); continue; } if (MI.isInlineAsm()) { if (TII->hasVGPRUses(MI)) - resetMode(&MI); + resetMode(MI.getIterator()); continue; } @@ -323,14 +333,8 @@ bool AMDGPULowerVGPREncoding::run(MachineFunction &MF) { --ClauseRemaining; } - // If we're falling through to a block that has at least one other - // predecessor, we no longer know the mode. - MachineBasicBlock *Next = MBB.getNextNode(); - if (Next && Next->pred_size() >= 2 && - llvm::is_contained(Next->predecessors(), &MBB)) { - if (CurrentMode.raw_bits().any()) - CurrentModeKnown = false; - } + // Reset the mode if we are falling through. + resetMode(MBB.instr_end()); } return Changed; diff --git a/llvm/lib/Target/AMDGPU/AMDGPUMCInstLower.cpp b/llvm/lib/Target/AMDGPU/AMDGPUMCInstLower.cpp index 680e7eb..844649ebb 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUMCInstLower.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUMCInstLower.cpp @@ -412,7 +412,7 @@ void AMDGPUAsmPrinter::emitInstruction(const MachineInstr *MI) { *OutStreamer); if (isVerbose() && MI->getOpcode() == AMDGPU::S_SET_VGPR_MSB) { - unsigned V = MI->getOperand(0).getImm(); + unsigned V = MI->getOperand(0).getImm() & 0xff; OutStreamer->AddComment( " msbs: dst=" + Twine(V >> 6) + " src0=" + Twine(V & 3) + " src1=" + Twine((V >> 2) & 3) + " src2=" + Twine((V >> 4) & 3)); diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalize.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalize.cpp index e187959..907f830 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalize.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalize.cpp @@ -24,6 +24,7 @@ #include "llvm/CodeGen/GlobalISel/CSEInfo.h" #include "llvm/CodeGen/GlobalISel/CSEMIRBuilder.h" #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h" +#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h" #include "llvm/CodeGen/GlobalISel/Utils.h" #include "llvm/CodeGen/MachineFunctionPass.h" #include "llvm/CodeGen/MachineUniformityAnalysis.h" @@ -34,9 +35,17 @@ using namespace llvm; using namespace AMDGPU; +using namespace llvm::MIPatternMatch; namespace { +// AMDGPU-specific pattern matchers +template <typename SrcTy> +inline UnaryOp_match<SrcTy, AMDGPU::G_AMDGPU_READANYLANE> +m_GAMDGPUReadAnyLane(const SrcTy &Src) { + return UnaryOp_match<SrcTy, AMDGPU::G_AMDGPU_READANYLANE>(Src); +} + class AMDGPURegBankLegalize : public MachineFunctionPass { public: static char ID; @@ -160,10 +169,18 @@ AMDGPURegBankLegalizeCombiner::tryMatchRALFromUnmerge(Register Src) { Register AMDGPURegBankLegalizeCombiner::getReadAnyLaneSrc(Register Src) { // Src = G_AMDGPU_READANYLANE RALSrc - auto [RAL, RALSrc] = tryMatch(Src, AMDGPU::G_AMDGPU_READANYLANE); - if (RAL) + Register RALSrc; + if (mi_match(Src, MRI, m_GAMDGPUReadAnyLane(m_Reg(RALSrc)))) return RALSrc; + // TruncSrc = G_AMDGPU_READANYLANE RALSrc + // AextSrc = G_TRUNC TruncSrc + // Src = G_ANYEXT AextSrc + if (mi_match(Src, MRI, + m_GAnyExt(m_GTrunc(m_GAMDGPUReadAnyLane(m_Reg(RALSrc)))))) { + return RALSrc; + } + // LoVgpr, HiVgpr = G_UNMERGE_VALUES UnmergeSrc // LoSgpr = G_AMDGPU_READANYLANE LoVgpr // HiSgpr = G_AMDGPU_READANYLANE HiVgpr diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp index 5407566..dc8fa7f 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp @@ -500,6 +500,16 @@ void RegBankLegalizeHelper::lowerUnpackMinMax(MachineInstr &MI) { MI.eraseFromParent(); } +void RegBankLegalizeHelper::lowerUnpackAExt(MachineInstr &MI) { + auto [Op1Lo, Op1Hi] = unpackAExt(MI.getOperand(1).getReg()); + auto [Op2Lo, Op2Hi] = unpackAExt(MI.getOperand(2).getReg()); + auto ResLo = B.buildInstr(MI.getOpcode(), {SgprRB_S32}, {Op1Lo, Op2Lo}); + auto ResHi = B.buildInstr(MI.getOpcode(), {SgprRB_S32}, {Op1Hi, Op2Hi}); + B.buildBuildVectorTrunc(MI.getOperand(0).getReg(), + {ResLo.getReg(0), ResHi.getReg(0)}); + MI.eraseFromParent(); +} + static bool isSignedBFE(MachineInstr &MI) { if (GIntrinsic *GI = dyn_cast<GIntrinsic>(&MI)) return (GI->is(Intrinsic::amdgcn_sbfe)); @@ -616,6 +626,23 @@ void RegBankLegalizeHelper::lowerSplitTo32(MachineInstr &MI) { MI.eraseFromParent(); } +void RegBankLegalizeHelper::lowerSplitTo16(MachineInstr &MI) { + Register Dst = MI.getOperand(0).getReg(); + assert(MRI.getType(Dst) == V2S16); + auto [Op1Lo32, Op1Hi32] = unpackAExt(MI.getOperand(1).getReg()); + auto [Op2Lo32, Op2Hi32] = unpackAExt(MI.getOperand(2).getReg()); + unsigned Opc = MI.getOpcode(); + auto Flags = MI.getFlags(); + auto Op1Lo = B.buildTrunc(SgprRB_S16, Op1Lo32); + auto Op1Hi = B.buildTrunc(SgprRB_S16, Op1Hi32); + auto Op2Lo = B.buildTrunc(SgprRB_S16, Op2Lo32); + auto Op2Hi = B.buildTrunc(SgprRB_S16, Op2Hi32); + auto Lo = B.buildInstr(Opc, {SgprRB_S16}, {Op1Lo, Op2Lo}, Flags); + auto Hi = B.buildInstr(Opc, {SgprRB_S16}, {Op1Hi, Op2Hi}, Flags); + B.buildMergeLikeInstr(Dst, {Lo, Hi}); + MI.eraseFromParent(); +} + void RegBankLegalizeHelper::lowerSplitTo32Select(MachineInstr &MI) { Register Dst = MI.getOperand(0).getReg(); LLT DstTy = MRI.getType(Dst); @@ -688,6 +715,8 @@ void RegBankLegalizeHelper::lower(MachineInstr &MI, return lowerUnpackBitShift(MI); case UnpackMinMax: return lowerUnpackMinMax(MI); + case ScalarizeToS16: + return lowerSplitTo16(MI); case Ext32To64: { const RegisterBank *RB = MRI.getRegBank(MI.getOperand(0).getReg()); MachineInstrBuilder Hi; @@ -804,6 +833,8 @@ void RegBankLegalizeHelper::lower(MachineInstr &MI, } break; } + case UnpackAExt: + return lowerUnpackAExt(MI); case WidenMMOToS32: return widenMMOToS32(cast<GAnyLoad>(MI)); } @@ -837,6 +868,7 @@ LLT RegBankLegalizeHelper::getTyFromID(RegBankLLTMappingApplyID ID) { return LLT::scalar(32); case Sgpr64: case Vgpr64: + case UniInVgprS64: return LLT::scalar(64); case Sgpr128: case Vgpr128: @@ -960,6 +992,7 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) { case UniInVcc: case UniInVgprS16: case UniInVgprS32: + case UniInVgprS64: case UniInVgprV2S16: case UniInVgprV4S32: case UniInVgprB32: @@ -1092,6 +1125,7 @@ void RegBankLegalizeHelper::applyMappingDst( break; } case UniInVgprS32: + case UniInVgprS64: case UniInVgprV2S16: case UniInVgprV4S32: { assert(Ty == getTyFromID(MethodIDs[OpIdx])); @@ -1120,7 +1154,8 @@ void RegBankLegalizeHelper::applyMappingDst( assert(RB == SgprRB); Register NewDst = MRI.createVirtualRegister(SgprRB_S32); Op.setReg(NewDst); - B.buildTrunc(Reg, NewDst); + if (!MRI.use_empty(Reg)) + B.buildTrunc(Reg, NewDst); break; } case InvalidMapping: { diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.h b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.h index d937815..e7598f8 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.h +++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.h @@ -72,6 +72,7 @@ class RegBankLegalizeHelper { static constexpr LLT P6 = LLT::pointer(6, 32); MachineRegisterInfo::VRegAttrs SgprRB_S32 = {SgprRB, S32}; + MachineRegisterInfo::VRegAttrs SgprRB_S16 = {SgprRB, S16}; MachineRegisterInfo::VRegAttrs VgprRB_S32 = {VgprRB, S32}; MachineRegisterInfo::VRegAttrs VccRB_S1 = {VccRB, S1}; @@ -121,9 +122,11 @@ private: void lowerV_BFE(MachineInstr &MI); void lowerS_BFE(MachineInstr &MI); void lowerSplitTo32(MachineInstr &MI); + void lowerSplitTo16(MachineInstr &MI); void lowerSplitTo32Select(MachineInstr &MI); void lowerSplitTo32SExtInReg(MachineInstr &MI); void lowerUnpackMinMax(MachineInstr &MI); + void lowerUnpackAExt(MachineInstr &MI); }; } // end namespace AMDGPU diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp index a67b12a..103cdec 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp @@ -470,7 +470,19 @@ RegBankLegalizeRules::RegBankLegalizeRules(const GCNSubtarget &_ST, .Uni(S16, {{Sgpr32Trunc}, {Sgpr32AExt, Sgpr32AExt}}) .Div(S16, {{Vgpr16}, {Vgpr16, Vgpr16}}) .Uni(S32, {{Sgpr32}, {Sgpr32, Sgpr32}}) - .Div(S32, {{Vgpr32}, {Vgpr32, Vgpr32}}); + .Div(S32, {{Vgpr32}, {Vgpr32, Vgpr32}}) + .Uni(V2S16, {{SgprV2S16}, {SgprV2S16, SgprV2S16}, UnpackAExt}) + .Div(V2S16, {{VgprV2S16}, {VgprV2S16, VgprV2S16}}) + .Uni(S64, {{Sgpr64}, {Sgpr64, Sgpr64}}) + .Div(S64, {{Vgpr64}, {Vgpr64, Vgpr64}}); + + addRulesForGOpcs({G_UADDO, G_USUBO}, Standard) + .Uni(S32, {{Sgpr32, Sgpr32Trunc}, {Sgpr32, Sgpr32}}) + .Div(S32, {{Vgpr32, Vcc}, {Vgpr32, Vgpr32}}); + + addRulesForGOpcs({G_UADDE, G_USUBE}, Standard) + .Uni(S32, {{Sgpr32, Sgpr32Trunc}, {Sgpr32, Sgpr32, Sgpr32AExtBoolInReg}}) + .Div(S32, {{Vgpr32, Vcc}, {Vgpr32, Vgpr32, Vcc}}); addRulesForGOpcs({G_MUL}, Standard).Div(S32, {{Vgpr32}, {Vgpr32, Vgpr32}}); @@ -901,14 +913,26 @@ RegBankLegalizeRules::RegBankLegalizeRules(const GCNSubtarget &_ST, addRulesForGOpcs({G_ABS}, Standard).Uni(S16, {{Sgpr32Trunc}, {Sgpr32SExt}}); - addRulesForGOpcs({G_READSTEADYCOUNTER}, Standard).Uni(S64, {{Sgpr64}, {}}); + addRulesForGOpcs({G_READSTEADYCOUNTER, G_READCYCLECOUNTER}, Standard) + .Uni(S64, {{Sgpr64}, {}}); bool hasSALUFloat = ST->hasSALUFloatInsts(); addRulesForGOpcs({G_FADD}, Standard) + .Uni(S16, {{UniInVgprS16}, {Vgpr16, Vgpr16}}, !hasSALUFloat) + .Uni(S16, {{Sgpr16}, {Sgpr16, Sgpr16}}, hasSALUFloat) + .Div(S16, {{Vgpr16}, {Vgpr16, Vgpr16}}) .Uni(S32, {{Sgpr32}, {Sgpr32, Sgpr32}}, hasSALUFloat) .Uni(S32, {{UniInVgprS32}, {Vgpr32, Vgpr32}}, !hasSALUFloat) - .Div(S32, {{Vgpr32}, {Vgpr32, Vgpr32}}); + .Div(S32, {{Vgpr32}, {Vgpr32, Vgpr32}}) + .Uni(S64, {{UniInVgprS64}, {Vgpr64, Vgpr64}}) + .Div(S64, {{Vgpr64}, {Vgpr64, Vgpr64}}) + .Uni(V2S16, {{UniInVgprV2S16}, {VgprV2S16, VgprV2S16}}, !hasSALUFloat) + .Uni(V2S16, {{SgprV2S16}, {SgprV2S16, SgprV2S16}, ScalarizeToS16}, + hasSALUFloat) + .Div(V2S16, {{VgprV2S16}, {VgprV2S16, VgprV2S16}}) + .Any({{UniV2S32}, {{UniInVgprV2S32}, {VgprV2S32, VgprV2S32}}}) + .Any({{DivV2S32}, {{VgprV2S32}, {VgprV2S32, VgprV2S32}}}); addRulesForGOpcs({G_FPTOUI}) .Any({{UniS32, S32}, {{Sgpr32}, {Sgpr32}}}, hasSALUFloat) diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h index 93e0efd..e6df5d8 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h +++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h @@ -92,8 +92,10 @@ enum UniformityLLTOpPredicateID { V4S32, UniV2S16, + UniV2S32, DivV2S16, + DivV2S32, // B types B32, @@ -178,7 +180,9 @@ enum RegBankLLTMappingApplyID { UniInVcc, UniInVgprS16, UniInVgprS32, + UniInVgprS64, UniInVgprV2S16, + UniInVgprV2S32, UniInVgprV4S32, UniInVgprB32, UniInVgprB64, @@ -217,13 +221,15 @@ enum LoweringMethodID { V_BFE, VgprToVccCopy, SplitTo32, + ScalarizeToS16, SplitTo32Select, SplitTo32SExtInReg, Ext32To64, UniCstExt, SplitLoad, WidenLoad, - WidenMMOToS32 + WidenMMOToS32, + UnpackAExt }; enum FastRulesTypes { diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCTargetDesc.cpp b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCTargetDesc.cpp index 013cfeb..28b4da8 100644 --- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCTargetDesc.cpp +++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCTargetDesc.cpp @@ -168,7 +168,7 @@ bool AMDGPUMCInstrAnalysis::evaluateBranch(const MCInst &Inst, uint64_t Addr, void AMDGPUMCInstrAnalysis::updateState(const MCInst &Inst, uint64_t Addr) { if (Inst.getOpcode() == AMDGPU::S_SET_VGPR_MSB_gfx12) - VgprMSBs = Inst.getOperand(0).getImm(); + VgprMSBs = Inst.getOperand(0).getImm() & 0xff; else if (isTerminator(Inst)) VgprMSBs = 0; } diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp index b34ab2a..8bb2808 100644 --- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp +++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp @@ -7035,9 +7035,15 @@ static SDValue lowerBALLOTIntrinsic(const SITargetLowering &TLI, SDNode *N, SDLoc SL(N); if (Src.getOpcode() == ISD::SETCC) { + SDValue Op0 = Src.getOperand(0); + SDValue Op1 = Src.getOperand(1); + // Need to expand bfloat to float for comparison (setcc). + if (Op0.getValueType() == MVT::bf16) { + Op0 = DAG.getNode(ISD::FP_EXTEND, SL, MVT::f32, Op0); + Op1 = DAG.getNode(ISD::FP_EXTEND, SL, MVT::f32, Op1); + } // (ballot (ISD::SETCC ...)) -> (AMDGPUISD::SETCC ...) - return DAG.getNode(AMDGPUISD::SETCC, SL, VT, Src.getOperand(0), - Src.getOperand(1), Src.getOperand(2)); + return DAG.getNode(AMDGPUISD::SETCC, SL, VT, Op0, Op1, Src.getOperand(2)); } if (const ConstantSDNode *Arg = dyn_cast<ConstantSDNode>(Src)) { // (ballot 0) -> 0 diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp index a4d3d62..6b06534 100644 --- a/llvm/lib/Target/ARM/ARMISelLowering.cpp +++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp @@ -22109,6 +22109,11 @@ bool ARMTargetLowering::isComplexDeinterleavingOperationSupported( ScalarTy->isIntegerTy(32)); } +ArrayRef<MCPhysReg> ARMTargetLowering::getRoundingControlRegisters() const { + static const MCPhysReg RCRegs[] = {ARM::FPSCR_RM}; + return RCRegs; +} + Value *ARMTargetLowering::createComplexDeinterleavingIR( IRBuilderBase &B, ComplexDeinterleavingOperation OperationType, ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB, diff --git a/llvm/lib/Target/ARM/ARMISelLowering.h b/llvm/lib/Target/ARM/ARMISelLowering.h index 357d2c5..bf3438b 100644 --- a/llvm/lib/Target/ARM/ARMISelLowering.h +++ b/llvm/lib/Target/ARM/ARMISelLowering.h @@ -1009,6 +1009,8 @@ class VectorType; bool isUnsupportedFloatingType(EVT VT) const; + ArrayRef<MCPhysReg> getRoundingControlRegisters() const override; + SDValue getCMOV(const SDLoc &dl, EVT VT, SDValue FalseVal, SDValue TrueVal, SDValue ARMcc, SDValue Flags, SelectionDAG &DAG) const; SDValue getARMCmp(SDValue LHS, SDValue RHS, ISD::CondCode CC, diff --git a/llvm/lib/Target/ARM/ARMSelectionDAGInfo.cpp b/llvm/lib/Target/ARM/ARMSelectionDAGInfo.cpp index ebfa593..bf7c962f 100644 --- a/llvm/lib/Target/ARM/ARMSelectionDAGInfo.cpp +++ b/llvm/lib/Target/ARM/ARMSelectionDAGInfo.cpp @@ -47,9 +47,7 @@ SDValue ARMSelectionDAGInfo::EmitSpecializedLibcall( // Only use a specialized AEABI function if the default version of this // Libcall is an AEABI function. - if (std::strncmp(TLI->getLibcallName(LC), "__aeabi", 7) != 0) - return SDValue(); - + // // Translate RTLIB::Libcall to AEABILibcall. We only do this in order to be // able to translate memset to memclr and use the value to index the function // name array. @@ -61,12 +59,21 @@ SDValue ARMSelectionDAGInfo::EmitSpecializedLibcall( } AEABILibcall; switch (LC) { case RTLIB::MEMCPY: + if (TLI->getLibcallImpl(LC) != RTLIB::impl___aeabi_memcpy) + return SDValue(); + AEABILibcall = AEABI_MEMCPY; break; case RTLIB::MEMMOVE: + if (TLI->getLibcallImpl(LC) != RTLIB::impl___aeabi_memmove) + return SDValue(); + AEABILibcall = AEABI_MEMMOVE; break; case RTLIB::MEMSET: + if (TLI->getLibcallImpl(LC) != RTLIB::impl___aeabi_memset) + return SDValue(); + AEABILibcall = AEABI_MEMSET; if (isNullConstant(Src)) AEABILibcall = AEABI_MEMCLR; diff --git a/llvm/lib/Target/ARM/AsmParser/ARMAsmParser.cpp b/llvm/lib/Target/ARM/AsmParser/ARMAsmParser.cpp index f60660b..1bb670d 100644 --- a/llvm/lib/Target/ARM/AsmParser/ARMAsmParser.cpp +++ b/llvm/lib/Target/ARM/AsmParser/ARMAsmParser.cpp @@ -426,15 +426,15 @@ class ARMAsmParser : public MCTargetAsmParser { VPTState.CurPosition = ~0U; } - void Note(SMLoc L, const Twine &Msg, SMRange Range = std::nullopt) { + void Note(SMLoc L, const Twine &Msg, SMRange Range = {}) { return getParser().Note(L, Msg, Range); } - bool Warning(SMLoc L, const Twine &Msg, SMRange Range = std::nullopt) { + bool Warning(SMLoc L, const Twine &Msg, SMRange Range = {}) { return getParser().Warning(L, Msg, Range); } - bool Error(SMLoc L, const Twine &Msg, SMRange Range = std::nullopt) { + bool Error(SMLoc L, const Twine &Msg, SMRange Range = {}) { return getParser().Error(L, Msg, Range); } diff --git a/llvm/lib/Target/DirectX/DirectXAsmPrinter.cpp b/llvm/lib/Target/DirectX/DirectXAsmPrinter.cpp index 15def36..b6bbb20 100644 --- a/llvm/lib/Target/DirectX/DirectXAsmPrinter.cpp +++ b/llvm/lib/Target/DirectX/DirectXAsmPrinter.cpp @@ -52,6 +52,7 @@ void DXILAsmPrinter::emitGlobalVariable(const GlobalVariable *GV) { emitGlobalConstant(GV->getDataLayout(), GV->getInitializer()); } -extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXAsmPrinter() { +extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void +LLVMInitializeDirectXAsmPrinter() { RegisterAsmPrinter<DXILAsmPrinter> X(getTheDirectXTarget()); } diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp index bcf8440..84b1a31 100644 --- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp +++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp @@ -53,7 +53,8 @@ using namespace llvm; -extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() { +extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void +LLVMInitializeDirectXTarget() { RegisterTargetMachine<DirectXTargetMachine> X(getTheDirectXTarget()); auto *PR = PassRegistry::getPassRegistry(); initializeDXILIntrinsicExpansionLegacyPass(*PR); diff --git a/llvm/lib/Target/DirectX/MCTargetDesc/DirectXMCTargetDesc.cpp b/llvm/lib/Target/DirectX/MCTargetDesc/DirectXMCTargetDesc.cpp index 9a14c01..62ad014 100644 --- a/llvm/lib/Target/DirectX/MCTargetDesc/DirectXMCTargetDesc.cpp +++ b/llvm/lib/Target/DirectX/MCTargetDesc/DirectXMCTargetDesc.cpp @@ -132,7 +132,8 @@ static MCRegisterInfo *createDirectXMCRegisterInfo(const Triple &Triple) { static MCInstrInfo *createDirectXMCInstrInfo() { return new MCInstrInfo(); } -extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTargetMC() { +extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void +LLVMInitializeDirectXTargetMC() { Target &T = getTheDirectXTarget(); RegisterMCAsmInfo<DirectXMCAsmInfo> X(T); TargetRegistry::RegisterMCInstrInfo(T, createDirectXMCInstrInfo); diff --git a/llvm/lib/Target/DirectX/TargetInfo/DirectXTargetInfo.cpp b/llvm/lib/Target/DirectX/TargetInfo/DirectXTargetInfo.cpp index ae01626..934bd1b 100644 --- a/llvm/lib/Target/DirectX/TargetInfo/DirectXTargetInfo.cpp +++ b/llvm/lib/Target/DirectX/TargetInfo/DirectXTargetInfo.cpp @@ -24,7 +24,8 @@ Target &getTheDirectXTarget() { using namespace llvm; -extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTargetInfo() { +extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void +LLVMInitializeDirectXTargetInfo() { RegisterTarget<Triple::dxil, /*HasJIT=*/false> X( getTheDirectXTarget(), "dxil", "DirectX Intermediate Language", "DXIL"); } diff --git a/llvm/lib/Target/Hexagon/HexagonCopyHoisting.cpp b/llvm/lib/Target/Hexagon/HexagonCopyHoisting.cpp index 3b810d0..79863e1 100644 --- a/llvm/lib/Target/Hexagon/HexagonCopyHoisting.cpp +++ b/llvm/lib/Target/Hexagon/HexagonCopyHoisting.cpp @@ -34,7 +34,7 @@ class HexagonCopyHoisting : public MachineFunctionPass { public: static char ID; - HexagonCopyHoisting() : MachineFunctionPass(ID), MFN(nullptr), MRI(nullptr) {} + HexagonCopyHoisting() : MachineFunctionPass(ID) {} StringRef getPassName() const override { return "Hexagon Copy Hoisting"; } @@ -56,8 +56,8 @@ public: void moveCopyInstr(MachineBasicBlock *DestBB, std::pair<Register, Register> Key, MachineInstr *MI); - MachineFunction *MFN; - MachineRegisterInfo *MRI; + MachineFunction *MFN = nullptr; + MachineRegisterInfo *MRI = nullptr; std::vector<DenseMap<std::pair<Register, Register>, MachineInstr *>> CopyMIList; }; diff --git a/llvm/lib/Target/Hexagon/HexagonGenMemAbsolute.cpp b/llvm/lib/Target/Hexagon/HexagonGenMemAbsolute.cpp index 93418f7..a10c937 100644 --- a/llvm/lib/Target/Hexagon/HexagonGenMemAbsolute.cpp +++ b/llvm/lib/Target/Hexagon/HexagonGenMemAbsolute.cpp @@ -34,13 +34,13 @@ STATISTIC(HexagonNumStoreAbsConversions, namespace { class HexagonGenMemAbsolute : public MachineFunctionPass { - const HexagonInstrInfo *TII; - MachineRegisterInfo *MRI; - const TargetRegisterInfo *TRI; + const HexagonInstrInfo *TII = nullptr; + MachineRegisterInfo *MRI = nullptr; + const TargetRegisterInfo *TRI = nullptr; public: static char ID; - HexagonGenMemAbsolute() : MachineFunctionPass(ID), TII(0), MRI(0), TRI(0) {} + HexagonGenMemAbsolute() : MachineFunctionPass(ID) {} StringRef getPassName() const override { return "Hexagon Generate Load/Store Set Absolute Address Instruction"; diff --git a/llvm/lib/Target/Hexagon/HexagonPatterns.td b/llvm/lib/Target/Hexagon/HexagonPatterns.td index 85ce944..e40dbd2 100644 --- a/llvm/lib/Target/Hexagon/HexagonPatterns.td +++ b/llvm/lib/Target/Hexagon/HexagonPatterns.td @@ -3434,6 +3434,19 @@ let AddedComplexity = 100 in { (C2_not (S4_stored_locked I32:$Rs, I64:$Rt))>; } +multiclass FloatClass<SDPatternOperator IntOp, InstHexagon MI, + PatFrag RegPred> { + let AddedComplexity = 100 in { + def: Pat<(i1 (seteq (IntOp RegPred:$Rs, u5_0ImmPred_timm:$u5), 0)), + (C2_not (MI RegPred:$Rs, u5_0ImmPred_timm:$u5))>; + def: Pat<(i1 (setne (IntOp RegPred:$Rs, u5_0ImmPred_timm:$u5), 0)), + (MI RegPred:$Rs, u5_0ImmPred_timm:$u5)>; + } +} + +defm : FloatClass<int_hexagon_F2_sfclass, F2_sfclass, F32>; +defm : FloatClass<int_hexagon_F2_dfclass, F2_dfclass, F64>; + def: Pat<(int_hexagon_instrprof_custom (HexagonAtPcrel tglobaladdr:$addr), u32_0ImmPred:$I), (PS_call_instrprof_custom tglobaladdr:$addr, imm:$I)>; diff --git a/llvm/lib/Target/Hexagon/HexagonPatternsHVX.td b/llvm/lib/Target/Hexagon/HexagonPatternsHVX.td index 1637b91..d19920c 100644 --- a/llvm/lib/Target/Hexagon/HexagonPatternsHVX.td +++ b/llvm/lib/Target/Hexagon/HexagonPatternsHVX.td @@ -612,6 +612,9 @@ let Predicates = [UseHVX] in { (V6_vandvrt HvxVR:$Vs, (ToI32 0x01010101))>; def: Pat<(VecQ32 (trunc HVI32:$Vs)), (V6_vandvrt HvxVR:$Vs, (ToI32 0x01010101))>; + def: Pat<(VecQ16 (trunc HWI32:$Vss)), + (Combineq(VecQ32(V6_vandvrt (HiVec $Vss), (ToI32 0x01010101))), + (VecQ32 (V6_vandvrt (LoVec $Vss), (ToI32 0x01010101))))>; } let Predicates = [UseHVX] in { diff --git a/llvm/lib/Target/Hexagon/HexagonSubtarget.cpp b/llvm/lib/Target/Hexagon/HexagonSubtarget.cpp index b9cdd6a..ce2de75 100644 --- a/llvm/lib/Target/Hexagon/HexagonSubtarget.cpp +++ b/llvm/lib/Target/Hexagon/HexagonSubtarget.cpp @@ -544,7 +544,7 @@ int HexagonSubtarget::updateLatency(MachineInstr &SrcInst, if (!hasV60Ops()) return Latency; - auto &QII = static_cast<const HexagonInstrInfo &>(*getInstrInfo()); + const HexagonInstrInfo &QII = *getInstrInfo(); // BSB scheduling. if (QII.isHVXVec(SrcInst) || useBSBScheduling()) Latency = (Latency + 1) >> 1; diff --git a/llvm/lib/Target/Hexagon/HexagonTfrCleanup.cpp b/llvm/lib/Target/Hexagon/HexagonTfrCleanup.cpp index 71bdfc66..5a85f34 100644 --- a/llvm/lib/Target/Hexagon/HexagonTfrCleanup.cpp +++ b/llvm/lib/Target/Hexagon/HexagonTfrCleanup.cpp @@ -43,7 +43,7 @@ namespace { class HexagonTfrCleanup : public MachineFunctionPass { public: static char ID; - HexagonTfrCleanup() : MachineFunctionPass(ID), HII(0), TRI(0) {} + HexagonTfrCleanup() : MachineFunctionPass(ID) {} StringRef getPassName() const override { return "Hexagon TFR Cleanup"; } void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesAll(); @@ -52,8 +52,8 @@ public: bool runOnMachineFunction(MachineFunction &MF) override; private: - const HexagonInstrInfo *HII; - const TargetRegisterInfo *TRI; + const HexagonInstrInfo *HII = nullptr; + const TargetRegisterInfo *TRI = nullptr; typedef DenseMap<unsigned, uint64_t> ImmediateMap; diff --git a/llvm/lib/Target/LoongArch/LoongArchFloat32InstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchFloat32InstrInfo.td index 690dd73..e86b21c 100644 --- a/llvm/lib/Target/LoongArch/LoongArchFloat32InstrInfo.td +++ b/llvm/lib/Target/LoongArch/LoongArchFloat32InstrInfo.td @@ -365,6 +365,7 @@ def : Pat<(f32 (uint_to_fp (i64 (sexti32 (i64 GPR:$src))))), // FP Rounding let Predicates = [HasBasicF, IsLA64] in { def : PatFpr<frint, FRINT_S, FPR32>; +def : PatFpr<flog2, FLOGB_S, FPR32>; } // Predicates = [HasBasicF, IsLA64] let Predicates = [HasBasicF, IsLA32] in { diff --git a/llvm/lib/Target/LoongArch/LoongArchFloat64InstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchFloat64InstrInfo.td index daefbaa..2e88254 100644 --- a/llvm/lib/Target/LoongArch/LoongArchFloat64InstrInfo.td +++ b/llvm/lib/Target/LoongArch/LoongArchFloat64InstrInfo.td @@ -348,6 +348,7 @@ def : Pat<(bitconvert FPR64:$src), (MOVFR2GR_D FPR64:$src)>; // FP Rounding let Predicates = [HasBasicD, IsLA64] in { def : PatFpr<frint, FRINT_D, FPR64>; +def : PatFpr<flog2, FLOGB_D, FPR64>; } // Predicates = [HasBasicD, IsLA64] /// Pseudo-instructions needed for the soft-float ABI with LA32D diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp index 80c96c6..a6de839 100644 --- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp +++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp @@ -244,8 +244,10 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM, setOperationAction(ISD::FP_TO_BF16, MVT::f32, Subtarget.isSoftFPABI() ? LibCall : Custom); - if (Subtarget.is64Bit()) + if (Subtarget.is64Bit()) { setOperationAction(ISD::FRINT, MVT::f32, Legal); + setOperationAction(ISD::FLOG2, MVT::f32, Legal); + } if (!Subtarget.hasBasicD()) { setOperationAction(ISD::FP_TO_UINT, MVT::i32, Custom); @@ -291,8 +293,10 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM, setOperationAction(ISD::FP_TO_BF16, MVT::f64, Subtarget.isSoftFPABI() ? LibCall : Custom); - if (Subtarget.is64Bit()) + if (Subtarget.is64Bit()) { setOperationAction(ISD::FRINT, MVT::f64, Legal); + setOperationAction(ISD::FLOG2, MVT::f64, Legal); + } } // Set operations for 'LSX' feature. @@ -362,6 +366,7 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM, setOperationAction(ISD::FMA, VT, Legal); setOperationAction(ISD::FSQRT, VT, Legal); setOperationAction(ISD::FNEG, VT, Legal); + setOperationAction(ISD::FLOG2, VT, Legal); setCondCodeAction({ISD::SETGE, ISD::SETGT, ISD::SETOGE, ISD::SETOGT, ISD::SETUGE, ISD::SETUGT}, VT, Expand); @@ -443,6 +448,7 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM, setOperationAction(ISD::FMA, VT, Legal); setOperationAction(ISD::FSQRT, VT, Legal); setOperationAction(ISD::FNEG, VT, Legal); + setOperationAction(ISD::FLOG2, VT, Legal); setCondCodeAction({ISD::SETGE, ISD::SETGT, ISD::SETOGE, ISD::SETOGT, ISD::SETUGE, ISD::SETUGT}, VT, Expand); diff --git a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td index 613dea6..ca4ee5f 100644 --- a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td +++ b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td @@ -1593,6 +1593,9 @@ def : Pat<(fma_nsz (fneg v4f64:$xj), v4f64:$xk, v4f64:$xa), // XVFSQRT_{S/D} defm : PatXrF<fsqrt, "XVFSQRT">; +// XVFLOGB_{S/D} +defm : PatXrF<flog2, "XVFLOGB">; + // XVRECIP_{S/D} def : Pat<(fdiv vsplatf32_fpimm_eq_1, v8f32:$xj), (XVFRECIP_S v8f32:$xj)>; @@ -2024,6 +2027,24 @@ def : Pat<(v4i32(fp_to_uint v4f64:$vj)), (XVFTINTRZ_LU_D v4f64:$vj)), sub_128)>; +// XVAVG_{B/H/W/D/BU/HU/WU/DU}, XVAVGR_{B/H/W/D/BU/HU/WU/DU} +defm : VAvgPat<sra, "XVAVG_B", v32i8>; +defm : VAvgPat<sra, "XVAVG_H", v16i16>; +defm : VAvgPat<sra, "XVAVG_W", v8i32>; +defm : VAvgPat<sra, "XVAVG_D", v4i64>; +defm : VAvgPat<srl, "XVAVG_BU", v32i8>; +defm : VAvgPat<srl, "XVAVG_HU", v16i16>; +defm : VAvgPat<srl, "XVAVG_WU", v8i32>; +defm : VAvgPat<srl, "XVAVG_DU", v4i64>; +defm : VAvgrPat<sra, "XVAVGR_B", v32i8>; +defm : VAvgrPat<sra, "XVAVGR_H", v16i16>; +defm : VAvgrPat<sra, "XVAVGR_W", v8i32>; +defm : VAvgrPat<sra, "XVAVGR_D", v4i64>; +defm : VAvgrPat<srl, "XVAVGR_BU", v32i8>; +defm : VAvgrPat<srl, "XVAVGR_HU", v16i16>; +defm : VAvgrPat<srl, "XVAVGR_WU", v8i32>; +defm : VAvgrPat<srl, "XVAVGR_DU", v4i64>; + // abs def : Pat<(abs v32i8:$xj), (XVSIGNCOV_B v32i8:$xj, v32i8:$xj)>; def : Pat<(abs v16i16:$xj), (XVSIGNCOV_H v16i16:$xj, v16i16:$xj)>; diff --git a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td index 4619c6b..92402ba 100644 --- a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td +++ b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td @@ -1518,6 +1518,18 @@ multiclass InsertExtractPatV2<ValueType vecty, ValueType elemty> { } } +multiclass VAvgPat<SDPatternOperator OpNode, string Inst, ValueType vt> { + def : Pat<(OpNode (vt (add vt:$vj, vt:$vk)), (vt (vsplat_imm_eq_1))), + (!cast<LAInst>(Inst) vt:$vj, vt:$vk)>; +} + +multiclass VAvgrPat<SDPatternOperator OpNode, string Inst, ValueType vt> { + def : Pat<(OpNode (vt (add (vt (add vt:$vj, vt:$vk)), + (vt (vsplat_imm_eq_1)))), + (vt (vsplat_imm_eq_1))), + (!cast<LAInst>(Inst) vt:$vj, vt:$vk)>; +} + let Predicates = [HasExtLSX] in { // VADD_{B/H/W/D} @@ -1783,6 +1795,9 @@ def : Pat<(fma_nsz (fneg v2f64:$vj), v2f64:$vk, v2f64:$va), // VFSQRT_{S/D} defm : PatVrF<fsqrt, "VFSQRT">; +// VFLOGB_{S/D} +defm : PatVrF<flog2, "VFLOGB">; + // VFRECIP_{S/D} def : Pat<(fdiv vsplatf32_fpimm_eq_1, v4f32:$vj), (VFRECIP_S v4f32:$vj)>; @@ -2154,6 +2169,24 @@ def : Pat<(f32 f32imm_vldi:$in), def : Pat<(f64 f64imm_vldi:$in), (f64 (EXTRACT_SUBREG (VLDI (to_f64imm_vldi f64imm_vldi:$in)), sub_64))>; +// VAVG_{B/H/W/D/BU/HU/WU/DU}, VAVGR_{B/H/W/D/BU/HU/WU/DU} +defm : VAvgPat<sra, "VAVG_B", v16i8>; +defm : VAvgPat<sra, "VAVG_H", v8i16>; +defm : VAvgPat<sra, "VAVG_W", v4i32>; +defm : VAvgPat<sra, "VAVG_D", v2i64>; +defm : VAvgPat<srl, "VAVG_BU", v16i8>; +defm : VAvgPat<srl, "VAVG_HU", v8i16>; +defm : VAvgPat<srl, "VAVG_WU", v4i32>; +defm : VAvgPat<srl, "VAVG_DU", v2i64>; +defm : VAvgrPat<sra, "VAVGR_B", v16i8>; +defm : VAvgrPat<sra, "VAVGR_H", v8i16>; +defm : VAvgrPat<sra, "VAVGR_W", v4i32>; +defm : VAvgrPat<sra, "VAVGR_D", v2i64>; +defm : VAvgrPat<srl, "VAVGR_BU", v16i8>; +defm : VAvgrPat<srl, "VAVGR_HU", v8i16>; +defm : VAvgrPat<srl, "VAVGR_WU", v4i32>; +defm : VAvgrPat<srl, "VAVGR_DU", v2i64>; + // abs def : Pat<(abs v16i8:$vj), (VSIGNCOV_B v16i8:$vj, v16i8:$vj)>; def : Pat<(abs v8i16:$vj), (VSIGNCOV_H v8i16:$vj, v8i16:$vj)>; diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp index 7e7ee75..c667a09 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp @@ -1871,17 +1871,6 @@ bool NVPTXScopes::empty() const { return Scopes.size() == 0; } (is_ch ? (CP_ASYNC_BULK_TENSOR_OPCODE(RED, dim, mode, is_s32, _CH)) \ : (CP_ASYNC_BULK_TENSOR_OPCODE(RED, dim, mode, is_s32, ))) -#define GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(dim, mode, is_mc, is_ch, is_s32) \ - [&]() -> auto { \ - if (is_mc && is_ch) \ - return CP_ASYNC_BULK_TENSOR_OPCODE(G2S, dim, mode, is_s32, _MC_CH); \ - if (is_ch) \ - return CP_ASYNC_BULK_TENSOR_OPCODE(G2S, dim, mode, is_s32, _CH); \ - if (is_mc) \ - return CP_ASYNC_BULK_TENSOR_OPCODE(G2S, dim, mode, is_s32, _MC); \ - return CP_ASYNC_BULK_TENSOR_OPCODE(G2S, dim, mode, is_s32, ); \ - }() - static unsigned GetCpAsyncBulkTensorS2GReductionOpcode(size_t Dim, bool IsShared32, bool IsCacheHint, @@ -1925,112 +1914,6 @@ static unsigned GetCpAsyncBulkTensorS2GReductionOpcode(size_t Dim, } } -static unsigned GetCpAsyncBulkTensorG2SOpcode(size_t Dim, bool IsShared32, - bool IsMultiCast, - bool IsCacheHint, bool IsIm2Col) { - if (IsIm2Col) { - switch (Dim) { - case 3: - return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(3D, IM2COL, IsMultiCast, - IsCacheHint, IsShared32); - case 4: - return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(4D, IM2COL, IsMultiCast, - IsCacheHint, IsShared32); - case 5: - return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(5D, IM2COL, IsMultiCast, - IsCacheHint, IsShared32); - default: - llvm_unreachable("Invalid Dimension in im2col mode for " - "GetCpAsyncBulkTensorG2SOpcode."); - } - } else { - switch (Dim) { - case 1: - return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(1D, TILE, IsMultiCast, - IsCacheHint, IsShared32); - case 2: - return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(2D, TILE, IsMultiCast, - IsCacheHint, IsShared32); - case 3: - return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(3D, TILE, IsMultiCast, - IsCacheHint, IsShared32); - case 4: - return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(4D, TILE, IsMultiCast, - IsCacheHint, IsShared32); - case 5: - return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(5D, TILE, IsMultiCast, - IsCacheHint, IsShared32); - default: - llvm_unreachable( - "Invalid Dimension in tile mode for GetCpAsyncBulkTensorG2SOpcode."); - } - } -} - -static size_t GetDimsFromIntrinsic(unsigned IID) { - switch (IID) { - case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d: - case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d: - return 3; - case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d: - case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d: - return 4; - case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d: - case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d: - return 5; - default: - llvm_unreachable("Invalid im2col intrinsic in GetDimsFromIntrinsic."); - } -} - -void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorG2SCommon(SDNode *N, - bool IsIm2Col) { - // We have {Chain, Intrinsic-ID} followed by the actual intrisic args: - // {dst, mbar, src, dims{d0...dN}, im2col_offsets{dims-2} - // multicast, cache_hint, - // multicast_flag, cache_hint_flag, cta_group_flag} - // NumOperands = {Chain, IID} + {Actual intrinsic args} - // = {2} + {8 + dims + im2col_offsets} - size_t NumOps = N->getNumOperands(); - size_t NumDims = IsIm2Col ? GetDimsFromIntrinsic(N->getConstantOperandVal(1)) - : (NumOps - 10); - // Offsets is always 'NumDims - 2' and only for im2col mode - size_t NumOffsets = IsIm2Col ? (NumDims - 2) : 0; - bool IsCacheHint = N->getConstantOperandVal(NumOps - 2) == 1; - bool IsMultiCast = N->getConstantOperandVal(NumOps - 3) == 1; - size_t NumBaseArgs = NumDims + NumOffsets + 3; // for {dst, mbar, src} - size_t MultiCastIdx = NumBaseArgs + 2; // for Chain and IID - - unsigned CTAGroupVal = N->getConstantOperandVal(NumOps - 1); - if ((CTAGroupVal > 0) && !Subtarget->hasCpAsyncBulkTensorCTAGroupSupport()) - report_fatal_error( - formatv("CpAsyncBulkTensorG2S cta_group::1/2 is not supported on sm_{}", - Subtarget->getSmVersion())); - - SDLoc DL(N); - SmallVector<SDValue, 8> Ops(N->ops().slice(2, NumBaseArgs)); - - // Push MultiCast operand, if available - if (IsMultiCast) - Ops.push_back(N->getOperand(MultiCastIdx)); - - // Push CacheHint operand, if available - if (IsCacheHint) - Ops.push_back(N->getOperand(MultiCastIdx + 1)); - - // Flag for CTA Group - Ops.push_back(getI32Imm(CTAGroupVal, DL)); - - // Finally, the chain operand - Ops.push_back(N->getOperand(0)); - - bool IsShared32 = - CurDAG->getDataLayout().getPointerSizeInBits(ADDRESS_SPACE_SHARED) == 32; - unsigned Opcode = GetCpAsyncBulkTensorG2SOpcode( - NumDims, IsShared32, IsMultiCast, IsCacheHint, IsIm2Col); - ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops)); -} - void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorReduceCommon(SDNode *N, unsigned RedOp, bool IsIm2Col) { @@ -2175,18 +2058,6 @@ bool NVPTXDAGToDAGISel::tryIntrinsicVoid(SDNode *N) { switch (IID) { default: return false; - case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d: - case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d: - case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d: - case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d: - case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d: - SelectCpAsyncBulkTensorG2SCommon(N); - return true; - case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d: - case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d: - case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d: - SelectCpAsyncBulkTensorG2SCommon(N, /*IsIm2Col=*/true); - return true; case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d: case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d: case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d: diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h index c912e70..1cb579b 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h +++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h @@ -86,7 +86,6 @@ private: bool tryEXTRACT_VECTOR_ELEMENT(SDNode *N); void SelectV2I64toI128(SDNode *N); void SelectI128toV2I64(SDNode *N); - void SelectCpAsyncBulkTensorG2SCommon(SDNode *N, bool IsIm2Col = false); void SelectCpAsyncBulkTensorReduceCommon(SDNode *N, unsigned RedOp, bool IsIm2Col = false); void SelectTcgen05Ld(SDNode *N, bool hasOffset = false); diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index dfde0cc..b260221 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -139,7 +139,6 @@ def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">; def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">; def hasTcgen05Instructions : Predicate<"Subtarget->hasTcgen05Instructions()">; def hasTcgen05MMAScaleInputDImm : Predicate<"Subtarget->hasTcgen05MMAScaleInputDImm()">; -def hasTMACTAGroupSupport : Predicate<"Subtarget->hasCpAsyncBulkTensorCTAGroupSupport()">; def hasF32x2Instructions : Predicate<"Subtarget->hasF32x2Instructions()">; class hasPTX<int version>: Predicate<"Subtarget->getPTXVersion() >= " # version>; diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index c923f0e..e8758aa 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -599,75 +599,15 @@ class TMA_IM2COL_UTIL<int dim, string mode> { string base_str = !interleave(!foreach(i, !range(offsets), "$im2col" # i), ", "); } -// From Global to Shared memory (G2S) -class G2S_STRINGS<int dim, string mode, bit mc, bit ch, bit is_shared32 = 0> { - string prefix = "cp.async.bulk.tensor"; - string dir = "shared::cluster.global"; - string completion = "mbarrier::complete_tx::bytes"; - string inst_name = prefix - # "." # dim # "d" - # "." # dir - # "." # mode - # "." # completion - # !if(mc, ".multicast::cluster", "") - # !if(ch, ".L2::cache_hint", ""); - string intr_name = "CP_ASYNC_BULK_TENSOR_G2S_" - # dim # "D" - # !if(is_shared32, "_SHARED32", "") - # !if(!eq(mode, "tile"), "_TILE", "_IM2COL"); -} - def CTAGroupFlags : Operand<i32> { let PrintMethod = "printCTAGroup"; } -multiclass CP_ASYNC_BULK_TENSOR_G2S_INTR<int dim, bit is_shared32, string mode> { - defvar dims_dag = TMA_DIMS_UTIL<dim>.ins_dag; - defvar dims_str = TMA_DIMS_UTIL<dim>.base_str; - defvar asm_str_default = "$cg [$dst], [$tmap, {{" # dims_str # "}}], [$mbar]"; - defvar rc = !if(is_shared32, B32, B64); - - defvar num_im2col = !if(!ge(dim, 3), !add(dim, -2), 0); - defvar im2col_dag = !if(!eq(mode, "im2col"), - !dag(ins, !listsplat(B16, num_im2col), !foreach(i, !range(num_im2col), "im2col" # i)), - (ins)); - defvar im2col_str = !interleave(!foreach(i, !range(num_im2col), "$im2col" # i), ", "); - defvar im2col_asm_str = ", {{" # im2col_str # "}}"; - - defvar asm_str = !if(!eq(mode, "im2col"), - !strconcat(asm_str_default, im2col_asm_str), asm_str_default); +def tma_cta_group_imm0 : TImmLeaf<i32, [{return Imm == 0;}]>; +def tma_cta_group_imm_any : TImmLeaf<i32, [{return Imm >= 0;}]>; - def "" : NVPTXInst<(outs), - !con((ins rc:$dst, rc:$mbar, B64:$tmap), dims_dag, im2col_dag, (ins CTAGroupFlags:$cg)), - !strconcat(G2S_STRINGS<dim, mode, 0, 0>.inst_name, asm_str, ";")>, - Requires<[hasPTX<80>, hasSM<90>]>; - def _MC : NVPTXInst<(outs), - !con((ins rc:$dst, rc:$mbar, B64:$tmap), dims_dag, im2col_dag, - (ins B16:$mc, CTAGroupFlags:$cg)), - !strconcat(G2S_STRINGS<dim, mode, 1, 0>.inst_name, asm_str, ", $mc;")>, - Requires<[hasPTX<80>, hasSM<90>]>; - def _CH : NVPTXInst<(outs), - !con((ins rc:$dst, rc:$mbar, B64:$tmap), dims_dag, im2col_dag, - (ins B64:$ch, CTAGroupFlags:$cg)), - !strconcat(G2S_STRINGS<dim, mode, 0, 1>.inst_name, asm_str, ", $ch;")>, - Requires<[hasPTX<80>, hasSM<90>]>; - def _MC_CH : NVPTXInst<(outs), - !con((ins rc:$dst, rc:$mbar, B64:$tmap), dims_dag, im2col_dag, - (ins B16:$mc, B64:$ch, CTAGroupFlags:$cg)), - !strconcat(G2S_STRINGS<dim, mode, 1, 1>.inst_name, asm_str, ", $mc, $ch;")>, - Requires<[hasPTX<80>, hasSM<90>]>; -} - -foreach dim = [1, 2, 3, 4, 5] in { - foreach shared32 = [true, false] in { - foreach mode = !if(!ge(dim, 3), ["tile", "im2col"], ["tile"]) in { - defm G2S_STRINGS<dim, mode, 0, 0, shared32>.intr_name : - CP_ASYNC_BULK_TENSOR_G2S_INTR<dim, shared32, mode>; - } - } -} - -multiclass TMA_TENSOR_G2S_INTR<int dim, string mode, list<Predicate> pred = []> { +multiclass TMA_TENSOR_G2S_INTR<int dim, string mode, list<Predicate> pred, + TImmLeaf cta_group_type = tma_cta_group_imm_any> { defvar dims_dag = TMA_DIMS_UTIL<dim>.ins_dag; defvar dims_str = TMA_DIMS_UTIL<dim>.base_str; defvar asm_str_base = "$cg [$dst], [$tmap, {{" # dims_str # "}}], [$mbar]"; @@ -697,10 +637,10 @@ multiclass TMA_TENSOR_G2S_INTR<int dim, string mode, list<Predicate> pred = []> !setdagop(dims_dag, intr), !setdagop(im2col_dag, intr), (intr B16:$mc, B64:$ch)); - defvar intr_dag_no_hints = !con(intr_dag_base, (intr 0, 0, timm:$cg)); - defvar intr_dag_with_mc = !con(intr_dag_base, (intr -1, 0, timm:$cg)); - defvar intr_dag_with_ch = !con(intr_dag_base, (intr 0, -1, timm:$cg)); - defvar intr_dag_with_mc_ch = !con(intr_dag_base, (intr -1, -1, timm:$cg)); + defvar intr_dag_no_hints = !con(intr_dag_base, (intr 0, 0, cta_group_type:$cg)); + defvar intr_dag_with_mc = !con(intr_dag_base, (intr -1, 0, cta_group_type:$cg)); + defvar intr_dag_with_ch = !con(intr_dag_base, (intr 0, -1, cta_group_type:$cg)); + defvar intr_dag_with_mc_ch = !con(intr_dag_base, (intr -1, -1, cta_group_type:$cg)); def "" : NVPTXInst<(outs), ins_dag, inst_name # asm_str # ";", @@ -719,14 +659,30 @@ multiclass TMA_TENSOR_G2S_INTR<int dim, string mode, list<Predicate> pred = []> [intr_dag_with_mc_ch]>, Requires<pred>; } + +foreach dim = 1...5 in { + defm TMA_G2S_TILE_CG0_ # dim # "D" + : TMA_TENSOR_G2S_INTR<dim, "tile", [hasPTX<80>, hasSM<90>], + tma_cta_group_imm0>; + defm TMA_G2S_TILE_ # dim # "D" + : TMA_TENSOR_G2S_INTR<dim, "tile", + [callSubtarget<"hasTMABlackwellSupport">]>; +} foreach dim = 3...5 in { + defm TMA_G2S_IM2COL_CG0_ # dim # "D" + : TMA_TENSOR_G2S_INTR<dim, "im2col", [hasPTX<80>, hasSM<90>], + tma_cta_group_imm0>; + defm TMA_G2S_IM2COL_ # dim # "D" + : TMA_TENSOR_G2S_INTR<dim, "im2col", + [callSubtarget<"hasTMABlackwellSupport">]>; foreach mode = ["im2col_w", "im2col_w_128"] in { defm TMA_G2S_ # !toupper(mode) # "_" # dim # "D" - : TMA_TENSOR_G2S_INTR<dim, mode, [hasTMACTAGroupSupport]>; + : TMA_TENSOR_G2S_INTR<dim, mode, + [callSubtarget<"hasTMABlackwellSupport">]>; } } defm TMA_G2S_TILE_GATHER4_2D : TMA_TENSOR_G2S_INTR<5, "tile_gather4", - [hasTMACTAGroupSupport]>; + [callSubtarget<"hasTMABlackwellSupport">]>; multiclass TMA_TENSOR_G2S_CTA_INTR<int dim, string mode, list<Predicate> pred = []> { defvar dims_dag = TMA_DIMS_UTIL<dim>.ins_dag; @@ -784,7 +740,8 @@ foreach dim = 3...5 in { : TMA_TENSOR_G2S_CTA_INTR<dim, "im2col_w", [hasPTX<86>, hasSM<100>]>; defm TMA_G2S_CTA_IM2COL_W_128_ # dim # "D" - : TMA_TENSOR_G2S_CTA_INTR<dim, "im2col_w_128", [hasTMACTAGroupSupport]>; + : TMA_TENSOR_G2S_CTA_INTR<dim, "im2col_w_128", + [callSubtarget<"hasTMABlackwellSupport">]>; } defm TMA_G2S_CTA_TILE_GATHER4_2D : TMA_TENSOR_G2S_CTA_INTR<5, "tile_gather4", [hasPTX<86>, hasSM<100>]>; @@ -835,7 +792,7 @@ foreach dim = 1...5 in { } } defm TMA_S2G_TILE_SCATTER4_2D : TMA_TENSOR_S2G_INTR<5, "tile_scatter4", - [hasTMACTAGroupSupport]>; + [callSubtarget<"hasTMABlackwellSupport">]>; def TMAReductionFlags : Operand<i32> { let PrintMethod = "printTmaReductionMode"; @@ -930,11 +887,11 @@ foreach dim = 3...5 in { foreach mode = ["im2col_w", "im2col_w_128"] in { defvar suffix = !toupper(mode) # "_" # dim # "D"; defm TMA_TENSOR_PF_ # suffix : TMA_TENSOR_PREFETCH_INTR<dim, mode, - [hasTMACTAGroupSupport]>; + [callSubtarget<"hasTMABlackwellSupport">]>; } } defm TMA_TENSOR_PF_TILE_GATHER4_2D : TMA_TENSOR_PREFETCH_INTR<5, "tile_gather4", - [hasTMACTAGroupSupport]>; + [callSubtarget<"hasTMABlackwellSupport">]>; //Prefetchu and Prefetch diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h index 194dbdc..021b1f6 100644 --- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h +++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h @@ -166,18 +166,15 @@ public: // f32x2 instructions in Blackwell family bool hasF32x2Instructions() const; - // TMA G2S copy with cta_group::1/2 support - bool hasCpAsyncBulkTensorCTAGroupSupport() const { - // TODO: Update/tidy-up after the family-conditional support arrives - switch (FullSmVersion) { - case 1003: - case 1013: - return PTXVersion >= 86; - case 1033: - return PTXVersion >= 88; - default: - return false; - } + // Checks support for following in TMA: + // - cta_group::1/2 support + // - im2col_w/w_128 mode support + // - tile_gather4 mode support + // - tile_scatter4 mode support + bool hasTMABlackwellSupport() const { + return hasPTXWithFamilySMs(90, {100, 110}) || + hasPTXWithFamilySMs(88, {100, 101}) || + hasPTXWithAccelSMs(86, {100, 101}); } // Prior to CUDA 12.3 ptxas did not recognize that the trap instruction diff --git a/llvm/lib/Target/PowerPC/PPCLoopInstrFormPrep.cpp b/llvm/lib/Target/PowerPC/PPCLoopInstrFormPrep.cpp index 3640d25..70df59d 100644 --- a/llvm/lib/Target/PowerPC/PPCLoopInstrFormPrep.cpp +++ b/llvm/lib/Target/PowerPC/PPCLoopInstrFormPrep.cpp @@ -1316,7 +1316,7 @@ bool PPCLoopInstrFormPrep::runOnLoop(Loop *L) { // useless and possible to break some original well-form addressing mode // to make this pre-inc prep for it. if (PointerElementType->isIntegerTy(64)) { - const SCEV *LSCEV = SE->getSCEVAtScope(const_cast<Value *>(PtrValue), L); + const SCEV *LSCEV = SE->getSCEVAtScope(PtrValue, L); const SCEVAddRecExpr *LARSCEV = dyn_cast<SCEVAddRecExpr>(LSCEV); if (!LARSCEV || LARSCEV->getLoop() != L) return false; diff --git a/llvm/lib/Target/PowerPC/PPCTargetMachine.cpp b/llvm/lib/Target/PowerPC/PPCTargetMachine.cpp index 000d296..4ff489d 100644 --- a/llvm/lib/Target/PowerPC/PPCTargetMachine.cpp +++ b/llvm/lib/Target/PowerPC/PPCTargetMachine.cpp @@ -296,8 +296,9 @@ PPCTargetMachine::PPCTargetMachine(const Target &T, const Triple &TT, std::optional<Reloc::Model> RM, std::optional<CodeModel::Model> CM, CodeGenOptLevel OL, bool JIT) - : CodeGenTargetMachineImpl(T, TT.computeDataLayout(), TT, CPU, - computeFSAdditions(FS, OL, TT), Options, + : CodeGenTargetMachineImpl(T, + TT.computeDataLayout(Options.MCOptions.ABIName), + TT, CPU, computeFSAdditions(FS, OL, TT), Options, getEffectiveRelocModel(TT, RM), getEffectivePPCCodeModel(TT, CM, JIT), OL), TLOF(createTLOF(getTargetTriple())), diff --git a/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp b/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp index 8198173..282cf5d 100644 --- a/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp +++ b/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp @@ -92,6 +92,10 @@ private: void emitFence(AtomicOrdering FenceOrdering, SyncScope::ID FenceSSID, MachineIRBuilder &MIB) const; bool selectUnmergeValues(MachineInstr &MI, MachineIRBuilder &MIB) const; + void addVectorLoadStoreOperands(MachineInstr &I, + SmallVectorImpl<SrcOp> &SrcOps, + unsigned &CurOp, bool IsMasked, + bool IsStrided) const; bool selectIntrinsicWithSideEffects(MachineInstr &I, MachineIRBuilder &MIB) const; @@ -716,6 +720,26 @@ static unsigned selectRegImmLoadStoreOp(unsigned GenericOpc, unsigned OpSize) { return GenericOpc; } +void RISCVInstructionSelector::addVectorLoadStoreOperands( + MachineInstr &I, SmallVectorImpl<SrcOp> &SrcOps, unsigned &CurOp, + bool IsMasked, bool IsStrided) const { + // Base Pointer + auto PtrReg = I.getOperand(CurOp++).getReg(); + SrcOps.push_back(PtrReg); + + // Stride + if (IsStrided) { + auto StrideReg = I.getOperand(CurOp++).getReg(); + SrcOps.push_back(StrideReg); + } + + // Mask + if (IsMasked) { + auto MaskReg = I.getOperand(CurOp++).getReg(); + SrcOps.push_back(MaskReg); + } +} + bool RISCVInstructionSelector::selectIntrinsicWithSideEffects( MachineInstr &I, MachineIRBuilder &MIB) const { // Find the intrinsic ID. @@ -752,21 +776,7 @@ bool RISCVInstructionSelector::selectIntrinsicWithSideEffects( SrcOps.push_back(Register(RISCV::NoRegister)); } - // Base Pointer - auto PtrReg = I.getOperand(CurOp++).getReg(); - SrcOps.push_back(PtrReg); - - // Stride - if (IsStrided) { - auto StrideReg = I.getOperand(CurOp++).getReg(); - SrcOps.push_back(StrideReg); - } - - // Mask - if (IsMasked) { - auto MaskReg = I.getOperand(CurOp++).getReg(); - SrcOps.push_back(MaskReg); - } + addVectorLoadStoreOperands(I, SrcOps, CurOp, IsMasked, IsStrided); RISCVVType::VLMUL LMUL = RISCVTargetLowering::getLMUL(getMVTForLLT(VT)); const RISCV::VLEPseudo *P = @@ -795,6 +805,48 @@ bool RISCVInstructionSelector::selectIntrinsicWithSideEffects( I.eraseFromParent(); return constrainSelectedInstRegOperands(*PseudoMI, TII, TRI, RBI); } + case Intrinsic::riscv_vsm: + case Intrinsic::riscv_vse: + case Intrinsic::riscv_vse_mask: + case Intrinsic::riscv_vsse: + case Intrinsic::riscv_vsse_mask: { + bool IsMasked = IntrinID == Intrinsic::riscv_vse_mask || + IntrinID == Intrinsic::riscv_vsse_mask; + bool IsStrided = IntrinID == Intrinsic::riscv_vsse || + IntrinID == Intrinsic::riscv_vsse_mask; + LLT VT = MRI->getType(I.getOperand(1).getReg()); + unsigned Log2SEW = Log2_32(VT.getScalarSizeInBits()); + + // Sources + unsigned CurOp = 1; + SmallVector<SrcOp, 4> SrcOps; // Source registers. + + // Store value + auto PassthruReg = I.getOperand(CurOp++).getReg(); + SrcOps.push_back(PassthruReg); + + addVectorLoadStoreOperands(I, SrcOps, CurOp, IsMasked, IsStrided); + + RISCVVType::VLMUL LMUL = RISCVTargetLowering::getLMUL(getMVTForLLT(VT)); + const RISCV::VSEPseudo *P = RISCV::getVSEPseudo( + IsMasked, IsStrided, Log2SEW, static_cast<unsigned>(LMUL)); + + auto PseudoMI = MIB.buildInstr(P->Pseudo, {}, SrcOps); + + // Select VL + auto VLOpFn = renderVLOp(I.getOperand(CurOp++)); + for (auto &RenderFn : *VLOpFn) + RenderFn(PseudoMI); + + // SEW + PseudoMI.addImm(Log2SEW); + + // Memref + PseudoMI.cloneMemRefs(I); + + I.eraseFromParent(); + return constrainSelectedInstRegOperands(*PseudoMI, TII, TRI, RBI); + } } } diff --git a/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp b/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp index 4105618..526675a 100644 --- a/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp +++ b/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp @@ -127,6 +127,10 @@ bool RISCVExpandPseudo::expandMI(MachineBasicBlock &MBB, case RISCV::PseudoCCAND: case RISCV::PseudoCCOR: case RISCV::PseudoCCXOR: + case RISCV::PseudoCCMAX: + case RISCV::PseudoCCMAXU: + case RISCV::PseudoCCMIN: + case RISCV::PseudoCCMINU: case RISCV::PseudoCCADDW: case RISCV::PseudoCCSUBW: case RISCV::PseudoCCSLL: @@ -217,6 +221,7 @@ bool RISCVExpandPseudo::expandCCOp(MachineBasicBlock &MBB, .addImm(0); } else { unsigned NewOpc; + // clang-format off switch (MI.getOpcode()) { default: llvm_unreachable("Unexpected opcode!"); @@ -228,6 +233,10 @@ bool RISCVExpandPseudo::expandCCOp(MachineBasicBlock &MBB, case RISCV::PseudoCCAND: NewOpc = RISCV::AND; break; case RISCV::PseudoCCOR: NewOpc = RISCV::OR; break; case RISCV::PseudoCCXOR: NewOpc = RISCV::XOR; break; + case RISCV::PseudoCCMAX: NewOpc = RISCV::MAX; break; + case RISCV::PseudoCCMIN: NewOpc = RISCV::MIN; break; + case RISCV::PseudoCCMAXU: NewOpc = RISCV::MAXU; break; + case RISCV::PseudoCCMINU: NewOpc = RISCV::MINU; break; case RISCV::PseudoCCADDI: NewOpc = RISCV::ADDI; break; case RISCV::PseudoCCSLLI: NewOpc = RISCV::SLLI; break; case RISCV::PseudoCCSRLI: NewOpc = RISCV::SRLI; break; @@ -250,6 +259,7 @@ bool RISCVExpandPseudo::expandCCOp(MachineBasicBlock &MBB, case RISCV::PseudoCCNDS_BFOS: NewOpc = RISCV::NDS_BFOS; break; case RISCV::PseudoCCNDS_BFOZ: NewOpc = RISCV::NDS_BFOZ; break; } + // clang-format on if (NewOpc == RISCV::NDS_BFOZ || NewOpc == RISCV::NDS_BFOS) { BuildMI(TrueBB, DL, TII->get(NewOpc), DestReg) diff --git a/llvm/lib/Target/RISCV/RISCVFeatures.td b/llvm/lib/Target/RISCV/RISCVFeatures.td index b4556f6..cfee6ab 100644 --- a/llvm/lib/Target/RISCV/RISCVFeatures.td +++ b/llvm/lib/Target/RISCV/RISCVFeatures.td @@ -1851,6 +1851,11 @@ def TuneShortForwardBranchOpt def HasShortForwardBranchOpt : Predicate<"Subtarget->hasShortForwardBranchOpt()">; def NoShortForwardBranchOpt : Predicate<"!Subtarget->hasShortForwardBranchOpt()">; +def TuneShortForwardBranchIMinMax + : SubtargetFeature<"short-forward-branch-i-minmax", "HasShortForwardBranchIMinMax", + "true", "Enable short forward branch optimization for min,max instructions in Zbb", + [TuneShortForwardBranchOpt]>; + // Some subtargets require a S2V transfer buffer to move scalars into vectors. // FIXME: Forming .vx/.vf/.wx/.wf can reduce register pressure. def TuneNoSinkSplatOperands diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp index 9a6afa1..b25a054 100644 --- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp @@ -3995,6 +3995,7 @@ bool RISCVDAGToDAGISel::hasAllNBitUsers(SDNode *Node, unsigned Bits, case RISCV::CTZW: case RISCV::CPOPW: case RISCV::SLLI_UW: + case RISCV::ABSW: case RISCV::FMV_W_X: case RISCV::FCVT_H_W: case RISCV::FCVT_H_W_INX: diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 1c930ac..c6a8b84 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -433,6 +433,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, if (Subtarget.hasStdExtP() || (Subtarget.hasVendorXCValu() && !Subtarget.is64Bit())) { setOperationAction(ISD::ABS, XLenVT, Legal); + if (Subtarget.is64Bit()) + setOperationAction(ISD::ABS, MVT::i32, Custom); } else if (Subtarget.hasShortForwardBranchOpt()) { // We can use PseudoCCSUB to implement ABS. setOperationAction(ISD::ABS, XLenVT, Legal); @@ -14816,8 +14818,16 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N, assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() && "Unexpected custom legalisation"); + if (Subtarget.hasStdExtP()) { + SDValue Src = + DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(0)); + SDValue Abs = DAG.getNode(RISCVISD::ABSW, DL, MVT::i64, Src); + Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Abs)); + return; + } + if (Subtarget.hasStdExtZbb()) { - // Emit a special ABSW node that will be expanded to NEGW+MAX at isel. + // Emit a special node that will be expanded to NEGW+MAX at isel. // This allows us to remember that the result is sign extended. Expanding // to NEGW+MAX here requires a Freeze which breaks ComputeNumSignBits. SDValue Src = DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, @@ -19784,7 +19794,9 @@ legalizeScatterGatherIndexType(SDLoc DL, SDValue &Index, // LLVM's legalization take care of the splitting. // FIXME: LLVM can't split VP_GATHER or VP_SCATTER yet. Index = DAG.getNode(ISD::SIGN_EXTEND, DL, - IndexVT.changeVectorElementType(XLenVT), Index); + EVT::getVectorVT(*DAG.getContext(), XLenVT, + IndexVT.getVectorElementCount()), + Index); } IndexType = ISD::UNSIGNED_SCALED; return true; @@ -20290,6 +20302,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, break; } + case RISCVISD::ABSW: case RISCVISD::CLZW: case RISCVISD::CTZW: { // Only the lower 32 bits of the first operand are read @@ -21862,6 +21875,7 @@ unsigned RISCVTargetLowering::ComputeNumSignBitsForTargetNode( case RISCVISD::REMUW: case RISCVISD::ROLW: case RISCVISD::RORW: + case RISCVISD::ABSW: case RISCVISD::FCVT_W_RV64: case RISCVISD::FCVT_WU_RV64: case RISCVISD::STRICT_FCVT_W_RV64: diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp index 912b82d..c9df787 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp @@ -869,7 +869,7 @@ std::optional<unsigned> getFoldedOpcode(MachineFunction &MF, MachineInstr &MI, } } -// This is the version used during inline spilling +// This is the version used during InlineSpiller::spillAroundUses MachineInstr *RISCVInstrInfo::foldMemoryOperandImpl( MachineFunction &MF, MachineInstr &MI, ArrayRef<unsigned> Ops, MachineBasicBlock::iterator InsertPt, int FrameIndex, LiveIntervals *LIS, @@ -1699,6 +1699,10 @@ unsigned getPredicatedOpcode(unsigned Opcode) { case RISCV::AND: return RISCV::PseudoCCAND; case RISCV::OR: return RISCV::PseudoCCOR; case RISCV::XOR: return RISCV::PseudoCCXOR; + case RISCV::MAX: return RISCV::PseudoCCMAX; + case RISCV::MAXU: return RISCV::PseudoCCMAXU; + case RISCV::MIN: return RISCV::PseudoCCMIN; + case RISCV::MINU: return RISCV::PseudoCCMINU; case RISCV::ADDI: return RISCV::PseudoCCADDI; case RISCV::SLLI: return RISCV::PseudoCCSLLI; @@ -1735,7 +1739,8 @@ unsigned getPredicatedOpcode(unsigned Opcode) { /// return the defining instruction. static MachineInstr *canFoldAsPredicatedOp(Register Reg, const MachineRegisterInfo &MRI, - const TargetInstrInfo *TII) { + const TargetInstrInfo *TII, + const RISCVSubtarget &STI) { if (!Reg.isVirtual()) return nullptr; if (!MRI.hasOneNonDBGUse(Reg)) @@ -1743,6 +1748,12 @@ static MachineInstr *canFoldAsPredicatedOp(Register Reg, MachineInstr *MI = MRI.getVRegDef(Reg); if (!MI) return nullptr; + + if (!STI.hasShortForwardBranchIMinMax() && + (MI->getOpcode() == RISCV::MAX || MI->getOpcode() == RISCV::MIN || + MI->getOpcode() == RISCV::MINU || MI->getOpcode() == RISCV::MAXU)) + return nullptr; + // Check if MI can be predicated and folded into the CCMOV. if (getPredicatedOpcode(MI->getOpcode()) == RISCV::INSTRUCTION_LIST_END) return nullptr; @@ -1806,10 +1817,10 @@ RISCVInstrInfo::optimizeSelect(MachineInstr &MI, MachineRegisterInfo &MRI = MI.getParent()->getParent()->getRegInfo(); MachineInstr *DefMI = - canFoldAsPredicatedOp(MI.getOperand(5).getReg(), MRI, this); + canFoldAsPredicatedOp(MI.getOperand(5).getReg(), MRI, this, STI); bool Invert = !DefMI; if (!DefMI) - DefMI = canFoldAsPredicatedOp(MI.getOperand(4).getReg(), MRI, this); + DefMI = canFoldAsPredicatedOp(MI.getOperand(4).getReg(), MRI, this, STI); if (!DefMI) return nullptr; diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.td b/llvm/lib/Target/RISCV/RISCVInstrInfo.td index 7c89686..9cb53fb 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.td @@ -768,7 +768,7 @@ def BGE : BranchCC_rri<0b101, "bge">; def BLTU : BranchCC_rri<0b110, "bltu">; def BGEU : BranchCC_rri<0b111, "bgeu">; -let IsSignExtendingOpW = 1 in { +let IsSignExtendingOpW = 1, canFoldAsLoad = 1 in { def LB : Load_ri<0b000, "lb">, Sched<[WriteLDB, ReadMemBase]>; def LH : Load_ri<0b001, "lh">, Sched<[WriteLDH, ReadMemBase]>; def LW : Load_ri<0b010, "lw">, Sched<[WriteLDW, ReadMemBase]>; @@ -889,8 +889,10 @@ def CSRRCI : CSR_ii<0b111, "csrrci">; /// RV64I instructions let Predicates = [IsRV64] in { +let canFoldAsLoad = 1 in { def LWU : Load_ri<0b110, "lwu">, Sched<[WriteLDW, ReadMemBase]>; def LD : Load_ri<0b011, "ld">, Sched<[WriteLDD, ReadMemBase]>; +} def SD : Store_rri<0b011, "sd">, Sched<[WriteSTD, ReadStoreData, ReadMemBase]>; let IsSignExtendingOpW = 1 in { diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoD.td b/llvm/lib/Target/RISCV/RISCVInstrInfoD.td index afac37d..4ffe3e6 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoD.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoD.td @@ -71,6 +71,7 @@ defvar DExtsRV64 = [DExt, ZdinxExt]; //===----------------------------------------------------------------------===// let Predicates = [HasStdExtD] in { +let canFoldAsLoad = 1 in def FLD : FPLoad_r<0b011, "fld", FPR64, WriteFLD64>; // Operands for stores are in the order srcreg, base, offset rather than diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoF.td b/llvm/lib/Target/RISCV/RISCVInstrInfoF.td index 6571d99..b30f8ec 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoF.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoF.td @@ -330,6 +330,7 @@ class PseudoFROUND<DAGOperand Ty, ValueType vt, ValueType intvt = XLenVT> //===----------------------------------------------------------------------===// let Predicates = [HasStdExtF] in { +let canFoldAsLoad = 1 in def FLW : FPLoad_r<0b010, "flw", FPR32, WriteFLD32>; // Operands for stores are in the order srcreg, base, offset rather than diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td index cc085bb..4cbbba3 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td @@ -1461,5 +1461,10 @@ let Predicates = [HasStdExtP, IsRV32] in { // Codegen patterns //===----------------------------------------------------------------------===// +def riscv_absw : RVSDNode<"ABSW", SDTIntUnaryOp>; + let Predicates = [HasStdExtP] in def : PatGpr<abs, ABS>; + +let Predicates = [HasStdExtP, IsRV64] in +def : PatGpr<riscv_absw, ABSW>; diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoSFB.td b/llvm/lib/Target/RISCV/RISCVInstrInfoSFB.td index 0114fbd..5a67a5a 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoSFB.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoSFB.td @@ -106,6 +106,10 @@ def PseudoCCSRA : SFBALU_rr; def PseudoCCAND : SFBALU_rr; def PseudoCCOR : SFBALU_rr; def PseudoCCXOR : SFBALU_rr; +def PseudoCCMAX : SFBALU_rr; +def PseudoCCMIN : SFBALU_rr; +def PseudoCCMAXU : SFBALU_rr; +def PseudoCCMINU : SFBALU_rr; def PseudoCCADDI : SFBALU_ri; def PseudoCCANDI : SFBALU_ri; diff --git a/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp b/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp index d08115b..ea98cdb 100644 --- a/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp +++ b/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp @@ -172,6 +172,7 @@ static bool hasAllNBitUsers(const MachineInstr &OrigMI, case RISCV::CTZW: case RISCV::CPOPW: case RISCV::SLLI_UW: + case RISCV::ABSW: case RISCV::FMV_W_X: case RISCV::FCVT_H_W: case RISCV::FCVT_H_W_INX: diff --git a/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp b/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp index e9f43b9..84bb294 100644 --- a/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp @@ -438,18 +438,19 @@ void RISCVRegisterInfo::lowerSegmentSpillReload(MachineBasicBlock::iterator II, TypeSize VRegSize = OldLoc.getValue().divideCoefficientBy(NumRegs); Register VLENB = 0; - unsigned PreHandledNum = 0; + unsigned VLENBShift = 0; + unsigned PrevHandledNum = 0; unsigned I = 0; while (I != NumRegs) { auto [LMulHandled, RegClass, Opcode] = getSpillReloadInfo(NumRegs - I, RegEncoding, IsSpill); auto [RegNumHandled, _] = RISCVVType::decodeVLMUL(LMulHandled); bool IsLast = I + RegNumHandled == NumRegs; - if (PreHandledNum) { + if (PrevHandledNum) { Register Step; // Optimize for constant VLEN. if (auto VLEN = STI.getRealVLen()) { - int64_t Offset = *VLEN / 8 * PreHandledNum; + int64_t Offset = *VLEN / 8 * PrevHandledNum; Step = MRI.createVirtualRegister(&RISCV::GPRRegClass); STI.getInstrInfo()->movImm(MBB, II, DL, Step, Offset); } else { @@ -457,15 +458,21 @@ void RISCVRegisterInfo::lowerSegmentSpillReload(MachineBasicBlock::iterator II, VLENB = MRI.createVirtualRegister(&RISCV::GPRRegClass); BuildMI(MBB, II, DL, TII->get(RISCV::PseudoReadVLENB), VLENB); } - uint32_t ShiftAmount = Log2_32(PreHandledNum); - if (ShiftAmount == 0) - Step = VLENB; - else { - Step = MRI.createVirtualRegister(&RISCV::GPRRegClass); - BuildMI(MBB, II, DL, TII->get(RISCV::SLLI), Step) - .addReg(VLENB, getKillRegState(IsLast)) - .addImm(ShiftAmount); + uint32_t ShiftAmount = Log2_32(PrevHandledNum); + // To avoid using an extra register, we shift the VLENB register and + // remember how much it has been shifted. We can then use relative + // shifts to adjust to the desired shift amount. + if (VLENBShift > ShiftAmount) { + BuildMI(MBB, II, DL, TII->get(RISCV::SRLI), VLENB) + .addReg(VLENB, RegState::Kill) + .addImm(VLENBShift - ShiftAmount); + } else if (VLENBShift < ShiftAmount) { + BuildMI(MBB, II, DL, TII->get(RISCV::SLLI), VLENB) + .addReg(VLENB, RegState::Kill) + .addImm(ShiftAmount - VLENBShift); } + VLENBShift = ShiftAmount; + Step = VLENB; } BuildMI(MBB, II, DL, TII->get(RISCV::ADD), NewBase) @@ -489,7 +496,7 @@ void RISCVRegisterInfo::lowerSegmentSpillReload(MachineBasicBlock::iterator II, if (IsSpill) MIB.addReg(Reg, RegState::Implicit); - PreHandledNum = RegNumHandled; + PrevHandledNum = RegNumHandled; RegEncoding += RegNumHandled; I += RegNumHandled; } diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 3fea21e..3f0424f 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -3151,6 +3151,14 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, return selectInsertElt(ResVReg, ResType, I); case Intrinsic::spv_gep: return selectGEP(ResVReg, ResType, I); + case Intrinsic::spv_bitcast: { + Register OpReg = I.getOperand(2).getReg(); + SPIRVType *OpType = + OpReg.isValid() ? GR.getSPIRVTypeForVReg(OpReg) : nullptr; + if (!GR.isBitcastCompatible(ResType, OpType)) + report_fatal_error("incompatible result and operand types in a bitcast"); + return selectOpWithSrcs(ResVReg, ResType, I, {OpReg}, SPIRV::OpBitcast); + } case Intrinsic::spv_unref_global: case Intrinsic::spv_init_global: { MachineInstr *MI = MRI->getVRegDef(I.getOperand(1).getReg()); diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp index 6e444c9..65dffc7 100644 --- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp @@ -73,16 +73,23 @@ class SPIRVLegalizePointerCast : public FunctionPass { // Returns the loaded value. Value *loadVectorFromVector(IRBuilder<> &B, FixedVectorType *SourceType, FixedVectorType *TargetType, Value *Source) { - assert(TargetType->getNumElements() <= SourceType->getNumElements()); LoadInst *NewLoad = B.CreateLoad(SourceType, Source); buildAssignType(B, SourceType, NewLoad); Value *AssignValue = NewLoad; if (TargetType->getElementType() != SourceType->getElementType()) { + const DataLayout &DL = B.GetInsertBlock()->getModule()->getDataLayout(); + [[maybe_unused]] TypeSize TargetTypeSize = + DL.getTypeSizeInBits(TargetType); + [[maybe_unused]] TypeSize SourceTypeSize = + DL.getTypeSizeInBits(SourceType); + assert(TargetTypeSize == SourceTypeSize); AssignValue = B.CreateIntrinsic(Intrinsic::spv_bitcast, {TargetType, SourceType}, {NewLoad}); buildAssignType(B, TargetType, AssignValue); + return AssignValue; } + assert(TargetType->getNumElements() < SourceType->getNumElements()); SmallVector<int> Mask(/* Size= */ TargetType->getNumElements()); for (unsigned I = 0; I < TargetType->getNumElements(); ++I) Mask[I] = I; diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp index f7cdfcb..db036a5 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -613,8 +613,7 @@ static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI, << FinalFlags << "\n"; MachineInstr *OrigMINonConst = const_cast<MachineInstr *>(OrigMI); MachineOperand &OrigFlagsOp = OrigMINonConst->getOperand(2); - OrigFlagsOp = - MachineOperand::CreateImm(static_cast<unsigned>(FinalFlags)); + OrigFlagsOp = MachineOperand::CreateImm(FinalFlags); return; // Merge done, so we found a duplicate; don't add it to MAI.MS } } diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp index db6f2d6..d538009 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp @@ -192,31 +192,43 @@ static void buildOpBitcast(SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB, .addUse(OpReg); } -// We do instruction selections early instead of calling MIB.buildBitcast() -// generating the general op code G_BITCAST. When MachineVerifier validates -// G_BITCAST we see a check of a kind: if Source Type is equal to Destination -// Type then report error "bitcast must change the type". This doesn't take into -// account the notion of a typed pointer that is important for SPIR-V where a -// user may and should use bitcast between pointers with different pointee types -// (https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpBitcast). -// It's important for correct lowering in SPIR-V, because interpretation of the -// data type is not left to instructions that utilize the pointer, but encoded -// by the pointer declaration, and the SPIRV target can and must handle the -// declaration and use of pointers that specify the type of data they point to. -// It's not feasible to improve validation of G_BITCAST using just information -// provided by low level types of source and destination. Therefore we don't -// produce G_BITCAST as the general op code with semantics different from -// OpBitcast, but rather lower to OpBitcast immediately. As for now, the only -// difference would be that CombinerHelper couldn't transform known patterns -// around G_BUILD_VECTOR. See discussion -// in https://github.com/llvm/llvm-project/pull/110270 for even more context. -static void selectOpBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR, - MachineIRBuilder MIB) { +// We lower G_BITCAST to OpBitcast here to avoid a MachineVerifier error. +// The verifier checks if the source and destination LLTs of a G_BITCAST are +// different, but this check is too strict for SPIR-V's typed pointers, which +// may have the same LLT but different SPIRVType (e.g. pointers to different +// pointee types). By lowering to OpBitcast here, we bypass the verifier's +// check. See discussion in https://github.com/llvm/llvm-project/pull/110270 +// for more context. +// +// We also handle the llvm.spv.bitcast intrinsic here. If the source and +// destination SPIR-V types are the same, we lower it to a COPY to enable +// further optimizations like copy propagation. +static void lowerBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR, + MachineIRBuilder MIB) { SmallVector<MachineInstr *, 16> ToErase; for (MachineBasicBlock &MBB : MF) { for (MachineInstr &MI : MBB) { + if (isSpvIntrinsic(MI, Intrinsic::spv_bitcast)) { + Register DstReg = MI.getOperand(0).getReg(); + Register SrcReg = MI.getOperand(2).getReg(); + SPIRVType *DstType = GR->getSPIRVTypeForVReg(DstReg); + assert( + DstType && + "Expected destination SPIR-V type to have been assigned already."); + SPIRVType *SrcType = GR->getSPIRVTypeForVReg(SrcReg); + assert(SrcType && + "Expected source SPIR-V type to have been assigned already."); + if (DstType == SrcType) { + MIB.setInsertPt(*MI.getParent(), MI); + MIB.buildCopy(DstReg, SrcReg); + ToErase.push_back(&MI); + continue; + } + } + if (MI.getOpcode() != TargetOpcode::G_BITCAST) continue; + MIB.setInsertPt(*MI.getParent(), MI); buildOpBitcast(GR, MIB, MI.getOperand(0).getReg(), MI.getOperand(1).getReg()); @@ -237,16 +249,11 @@ static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR, SmallVector<MachineInstr *, 10> ToErase; for (MachineBasicBlock &MBB : MF) { for (MachineInstr &MI : MBB) { - if (!isSpvIntrinsic(MI, Intrinsic::spv_bitcast) && - !isSpvIntrinsic(MI, Intrinsic::spv_ptrcast)) + if (!isSpvIntrinsic(MI, Intrinsic::spv_ptrcast)) continue; assert(MI.getOperand(2).isReg()); MIB.setInsertPt(*MI.getParent(), MI); ToErase.push_back(&MI); - if (isSpvIntrinsic(MI, Intrinsic::spv_bitcast)) { - MIB.buildBitcast(MI.getOperand(0).getReg(), MI.getOperand(2).getReg()); - continue; - } Register Def = MI.getOperand(0).getReg(); Register Source = MI.getOperand(2).getReg(); Type *ElemTy = getMDOperandAsType(MI.getOperand(3).getMetadata(), 0); @@ -1089,7 +1096,7 @@ bool SPIRVPreLegalizer::runOnMachineFunction(MachineFunction &MF) { removeImplicitFallthroughs(MF, MIB); insertSpirvDecorations(MF, GR, MIB); insertInlineAsm(MF, GR, ST, MIB); - selectOpBitcasts(MF, GR, MIB); + lowerBitcasts(MF, GR, MIB); return true; } diff --git a/llvm/lib/Target/X86/AsmParser/X86AsmParser.cpp b/llvm/lib/Target/X86/AsmParser/X86AsmParser.cpp index 127ee67..b7ea672 100644 --- a/llvm/lib/Target/X86/AsmParser/X86AsmParser.cpp +++ b/llvm/lib/Target/X86/AsmParser/X86AsmParser.cpp @@ -1121,7 +1121,7 @@ private: void setTypeInfo(AsmTypeInfo Type) { CurType = Type; } }; - bool Error(SMLoc L, const Twine &Msg, SMRange Range = std::nullopt, + bool Error(SMLoc L, const Twine &Msg, SMRange Range = {}, bool MatchingInlineAsm = false) { MCAsmParser &Parser = getParser(); if (MatchingInlineAsm) { @@ -4322,7 +4322,7 @@ bool X86AsmParser::matchAndEmitATTInstruction( SMLoc IDLoc, unsigned &Opcode, MCInst &Inst, OperandVector &Operands, MCStreamer &Out, uint64_t &ErrorInfo, bool MatchingInlineAsm) { X86Operand &Op = static_cast<X86Operand &>(*Operands[0]); - SMRange EmptyRange = std::nullopt; + SMRange EmptyRange; // In 16-bit mode, if data32 is specified, temporarily switch to 32-bit mode // when matching the instruction. if (ForcedDataPrefix == X86::Is32Bit) @@ -4548,7 +4548,7 @@ bool X86AsmParser::matchAndEmitIntelInstruction( SMLoc IDLoc, unsigned &Opcode, MCInst &Inst, OperandVector &Operands, MCStreamer &Out, uint64_t &ErrorInfo, bool MatchingInlineAsm) { X86Operand &Op = static_cast<X86Operand &>(*Operands[0]); - SMRange EmptyRange = std::nullopt; + SMRange EmptyRange; // Find one unsized memory operand, if present. X86Operand *UnsizedMemOp = nullptr; for (const auto &Op : Operands) { diff --git a/llvm/lib/Target/X86/AsmParser/X86Operand.h b/llvm/lib/Target/X86/AsmParser/X86Operand.h index 89ac53e..a922725 100644 --- a/llvm/lib/Target/X86/AsmParser/X86Operand.h +++ b/llvm/lib/Target/X86/AsmParser/X86Operand.h @@ -620,37 +620,6 @@ struct X86Operand final : public MCParsedAsmOperand { Inst.addOperand(MCOperand::createReg(Reg)); } - bool isTILEPair() const { - return Kind == Register && - X86MCRegisterClasses[X86::TILERegClassID].contains(getReg()); - } - - void addTILEPairOperands(MCInst &Inst, unsigned N) const { - assert(N == 1 && "Invalid number of operands!"); - MCRegister Reg = getReg(); - switch (Reg.id()) { - default: - llvm_unreachable("Invalid tile register!"); - case X86::TMM0: - case X86::TMM1: - Reg = X86::TMM0_TMM1; - break; - case X86::TMM2: - case X86::TMM3: - Reg = X86::TMM2_TMM3; - break; - case X86::TMM4: - case X86::TMM5: - Reg = X86::TMM4_TMM5; - break; - case X86::TMM6: - case X86::TMM7: - Reg = X86::TMM6_TMM7; - break; - } - Inst.addOperand(MCOperand::createReg(Reg)); - } - void addMemOperands(MCInst &Inst, unsigned N) const { assert((N == 5) && "Invalid number of operands!"); if (getMemBaseReg()) diff --git a/llvm/lib/Target/X86/Disassembler/X86Disassembler.cpp b/llvm/lib/Target/X86/Disassembler/X86Disassembler.cpp index 4927b45..7d2b5eb 100644 --- a/llvm/lib/Target/X86/Disassembler/X86Disassembler.cpp +++ b/llvm/lib/Target/X86/Disassembler/X86Disassembler.cpp @@ -810,10 +810,6 @@ static int readModRM(struct InternalInstruction *insn) { if (index > 7) \ *valid = 0; \ return prefix##_TMM0 + index; \ - case TYPE_TMM_PAIR: \ - if (index > 7) \ - *valid = 0; \ - return prefix##_TMM0_TMM1 + (index / 2); \ case TYPE_VK: \ index &= 0xf; \ if (index > 7) \ @@ -2323,7 +2319,6 @@ static bool translateRM(MCInst &mcInst, const OperandSpecifier &operand, case TYPE_YMM: case TYPE_ZMM: case TYPE_TMM: - case TYPE_TMM_PAIR: case TYPE_VK_PAIR: case TYPE_VK: case TYPE_DEBUGREG: diff --git a/llvm/lib/Target/X86/Disassembler/X86DisassemblerDecoder.h b/llvm/lib/Target/X86/Disassembler/X86DisassemblerDecoder.h index dc9af2c..b0aa70b 100644 --- a/llvm/lib/Target/X86/Disassembler/X86DisassemblerDecoder.h +++ b/llvm/lib/Target/X86/Disassembler/X86DisassemblerDecoder.h @@ -535,12 +535,6 @@ namespace X86Disassembler { ENTRY(TMM6) \ ENTRY(TMM7) -#define REGS_TMM_PAIRS \ - ENTRY(TMM0_TMM1) \ - ENTRY(TMM2_TMM3) \ - ENTRY(TMM4_TMM5) \ - ENTRY(TMM6_TMM7) - #define ALL_EA_BASES \ EA_BASES_16BIT \ EA_BASES_32BIT \ @@ -565,7 +559,6 @@ namespace X86Disassembler { REGS_DEBUG \ REGS_CONTROL \ REGS_TMM \ - REGS_TMM_PAIRS \ ENTRY(RIP) /// All possible values of the base field for effective-address diff --git a/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.cpp b/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.cpp index 1c5f166..759d95e 100644 --- a/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.cpp +++ b/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.cpp @@ -467,22 +467,3 @@ void X86InstPrinterCommon::printVKPair(const MCInst *MI, unsigned OpNo, } llvm_unreachable("Unknown mask pair register name"); } - -void X86InstPrinterCommon::printTILEPair(const MCInst *MI, unsigned OpNo, - raw_ostream &OS) { - switch (MI->getOperand(OpNo).getReg()) { - case X86::TMM0_TMM1: - printRegName(OS, X86::TMM0); - return; - case X86::TMM2_TMM3: - printRegName(OS, X86::TMM2); - return; - case X86::TMM4_TMM5: - printRegName(OS, X86::TMM4); - return; - case X86::TMM6_TMM7: - printRegName(OS, X86::TMM6); - return; - } - llvm_unreachable("Unknown mask pair register name"); -} diff --git a/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.h b/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.h index 2c9467c..cb55f2f 100644 --- a/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.h +++ b/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.h @@ -40,7 +40,6 @@ protected: const MCSubtargetInfo &STI); void printOptionalSegReg(const MCInst *MI, unsigned OpNo, raw_ostream &O); void printVKPair(const MCInst *MI, unsigned OpNo, raw_ostream &OS); - void printTILEPair(const MCInst *MI, unsigned OpNo, raw_ostream &OS); }; } // end namespace llvm diff --git a/llvm/lib/Target/X86/X86.td b/llvm/lib/Target/X86/X86.td index a1fd366..9e291a6 100644 --- a/llvm/lib/Target/X86/X86.td +++ b/llvm/lib/Target/X86/X86.td @@ -274,9 +274,6 @@ def FeatureAMXFP8 : SubtargetFeature<"amx-fp8", "HasAMXFP8", "true", def FeatureAMXMOVRS : SubtargetFeature<"amx-movrs", "HasAMXMOVRS", "true", "Support AMX-MOVRS instructions", [FeatureAMXTILE]>; -def FeatureAMXTRANSPOSE : SubtargetFeature<"amx-transpose", "HasAMXTRANSPOSE", "true", - "Support AMX amx-transpose instructions", - [FeatureAMXTILE]>; def FeatureAMXAVX512 : SubtargetFeature<"amx-avx512", "HasAMXAVX512", "true", "Support AMX-AVX512 instructions", @@ -1177,8 +1174,7 @@ def ProcessorFeatures { FeatureAMXMOVRS, FeatureAMXAVX512, FeatureAMXFP8, - FeatureAMXTF32, - FeatureAMXTRANSPOSE]; + FeatureAMXTF32]; list<SubtargetFeature> DMRFeatures = !listconcat(GNRDFeatures, DMRAdditionalFeatures); diff --git a/llvm/lib/Target/X86/X86ExpandPseudo.cpp b/llvm/lib/Target/X86/X86ExpandPseudo.cpp index 4a9b824..e3c44c0 100644 --- a/llvm/lib/Target/X86/X86ExpandPseudo.cpp +++ b/llvm/lib/Target/X86/X86ExpandPseudo.cpp @@ -649,149 +649,6 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB, MI.setDesc(TII->get(Opc)); return true; } - // TILEPAIRLOAD is just for TILEPair spill, we don't have corresponding - // AMX instruction to support it. So, split it to 2 load instructions: - // "TILEPAIRLOAD TMM0:TMM1, Base, Scale, Index, Offset, Segment" --> - // "TILELOAD TMM0, Base, Scale, Index, Offset, Segment" + - // "TILELOAD TMM1, Base, Scale, Index, Offset + TMM_SIZE, Segment" - case X86::PTILEPAIRLOAD: { - int64_t Disp = MBBI->getOperand(1 + X86::AddrDisp).getImm(); - Register TReg = MBBI->getOperand(0).getReg(); - bool DstIsDead = MBBI->getOperand(0).isDead(); - Register TReg0 = TRI->getSubReg(TReg, X86::sub_t0); - Register TReg1 = TRI->getSubReg(TReg, X86::sub_t1); - unsigned TmmSize = TRI->getRegSizeInBits(X86::TILERegClass) / 8; - - MachineInstrBuilder MIBLo = - BuildMI(MBB, MBBI, DL, TII->get(X86::TILELOADD)) - .addReg(TReg0, RegState::Define | getDeadRegState(DstIsDead)); - MachineInstrBuilder MIBHi = - BuildMI(MBB, MBBI, DL, TII->get(X86::TILELOADD)) - .addReg(TReg1, RegState::Define | getDeadRegState(DstIsDead)); - - for (int i = 0; i < X86::AddrNumOperands; ++i) { - MIBLo.add(MBBI->getOperand(1 + i)); - if (i == X86::AddrDisp) - MIBHi.addImm(Disp + TmmSize); - else - MIBHi.add(MBBI->getOperand(1 + i)); - } - - // Make sure the first stride reg used in first tileload is alive. - MachineOperand &Stride = - MIBLo.getInstr()->getOperand(1 + X86::AddrIndexReg); - Stride.setIsKill(false); - - // Split the memory operand, adjusting the offset and size for the halves. - MachineMemOperand *OldMMO = MBBI->memoperands().front(); - MachineFunction *MF = MBB.getParent(); - MachineMemOperand *MMOLo = MF->getMachineMemOperand(OldMMO, 0, TmmSize); - MachineMemOperand *MMOHi = - MF->getMachineMemOperand(OldMMO, TmmSize, TmmSize); - - MIBLo.setMemRefs(MMOLo); - MIBHi.setMemRefs(MMOHi); - - // Delete the pseudo. - MBB.erase(MBBI); - return true; - } - // Similar with TILEPAIRLOAD, TILEPAIRSTORE is just for TILEPair spill, no - // corresponding AMX instruction to support it. So, split it too: - // "TILEPAIRSTORE Base, Scale, Index, Offset, Segment, TMM0:TMM1" --> - // "TILESTORE Base, Scale, Index, Offset, Segment, TMM0" + - // "TILESTORE Base, Scale, Index, Offset + TMM_SIZE, Segment, TMM1" - case X86::PTILEPAIRSTORE: { - int64_t Disp = MBBI->getOperand(X86::AddrDisp).getImm(); - Register TReg = MBBI->getOperand(X86::AddrNumOperands).getReg(); - bool SrcIsKill = MBBI->getOperand(X86::AddrNumOperands).isKill(); - Register TReg0 = TRI->getSubReg(TReg, X86::sub_t0); - Register TReg1 = TRI->getSubReg(TReg, X86::sub_t1); - unsigned TmmSize = TRI->getRegSizeInBits(X86::TILERegClass) / 8; - - MachineInstrBuilder MIBLo = - BuildMI(MBB, MBBI, DL, TII->get(X86::TILESTORED)); - MachineInstrBuilder MIBHi = - BuildMI(MBB, MBBI, DL, TII->get(X86::TILESTORED)); - - for (int i = 0; i < X86::AddrNumOperands; ++i) { - MIBLo.add(MBBI->getOperand(i)); - if (i == X86::AddrDisp) - MIBHi.addImm(Disp + TmmSize); - else - MIBHi.add(MBBI->getOperand(i)); - } - MIBLo.addReg(TReg0, getKillRegState(SrcIsKill)); - MIBHi.addReg(TReg1, getKillRegState(SrcIsKill)); - - // Make sure the first stride reg used in first tilestore is alive. - MachineOperand &Stride = MIBLo.getInstr()->getOperand(X86::AddrIndexReg); - Stride.setIsKill(false); - - // Split the memory operand, adjusting the offset and size for the halves. - MachineMemOperand *OldMMO = MBBI->memoperands().front(); - MachineFunction *MF = MBB.getParent(); - MachineMemOperand *MMOLo = MF->getMachineMemOperand(OldMMO, 0, TmmSize); - MachineMemOperand *MMOHi = - MF->getMachineMemOperand(OldMMO, TmmSize, TmmSize); - - MIBLo.setMemRefs(MMOLo); - MIBHi.setMemRefs(MMOHi); - - // Delete the pseudo. - MBB.erase(MBBI); - return true; - } - case X86::PT2RPNTLVWZ0V: - case X86::PT2RPNTLVWZ0T1V: - case X86::PT2RPNTLVWZ1V: - case X86::PT2RPNTLVWZ1T1V: - case X86::PT2RPNTLVWZ0RSV: - case X86::PT2RPNTLVWZ0RST1V: - case X86::PT2RPNTLVWZ1RSV: - case X86::PT2RPNTLVWZ1RST1V: { - for (unsigned i = 3; i > 0; --i) - MI.removeOperand(i); - unsigned Opc; - switch (Opcode) { - case X86::PT2RPNTLVWZ0V: - Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0); - break; - case X86::PT2RPNTLVWZ0T1V: - Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0T1); - break; - case X86::PT2RPNTLVWZ1V: - Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1); - break; - case X86::PT2RPNTLVWZ1T1V: - Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1T1); - break; - case X86::PT2RPNTLVWZ0RSV: - Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0RS); - break; - case X86::PT2RPNTLVWZ0RST1V: - Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0RST1); - break; - case X86::PT2RPNTLVWZ1RSV: - Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1RS); - break; - case X86::PT2RPNTLVWZ1RST1V: - Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1RST1); - break; - default: - llvm_unreachable("Impossible Opcode!"); - } - MI.setDesc(TII->get(Opc)); - return true; - } - case X86::PTTRANSPOSEDV: - case X86::PTCONJTFP16V: { - for (int i = 2; i > 0; --i) - MI.removeOperand(i); - MI.setDesc(TII->get(Opcode == X86::PTTRANSPOSEDV ? X86::TTRANSPOSED - : X86::TCONJTFP16)); - return true; - } case X86::PTCMMIMFP16PSV: case X86::PTCMMRLFP16PSV: case X86::PTDPBSSDV: @@ -800,13 +657,7 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB, case X86::PTDPBUUDV: case X86::PTDPBF16PSV: case X86::PTDPFP16PSV: - case X86::PTTDPBF16PSV: - case X86::PTTDPFP16PSV: - case X86::PTTCMMIMFP16PSV: - case X86::PTTCMMRLFP16PSV: - case X86::PTCONJTCMMIMFP16PSV: case X86::PTMMULTF32PSV: - case X86::PTTMMULTF32PSV: case X86::PTDPBF8PSV: case X86::PTDPBHF8PSV: case X86::PTDPHBF8PSV: @@ -816,6 +667,7 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB, MI.removeOperand(i); unsigned Opc; switch (Opcode) { + // clang-format off case X86::PTCMMIMFP16PSV: Opc = X86::TCMMIMFP16PS; break; case X86::PTCMMRLFP16PSV: Opc = X86::TCMMRLFP16PS; break; case X86::PTDPBSSDV: Opc = X86::TDPBSSD; break; @@ -824,40 +676,12 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB, case X86::PTDPBUUDV: Opc = X86::TDPBUUD; break; case X86::PTDPBF16PSV: Opc = X86::TDPBF16PS; break; case X86::PTDPFP16PSV: Opc = X86::TDPFP16PS; break; - case X86::PTTDPBF16PSV: - Opc = X86::TTDPBF16PS; - break; - case X86::PTTDPFP16PSV: - Opc = X86::TTDPFP16PS; - break; - case X86::PTTCMMIMFP16PSV: - Opc = X86::TTCMMIMFP16PS; - break; - case X86::PTTCMMRLFP16PSV: - Opc = X86::TTCMMRLFP16PS; - break; - case X86::PTCONJTCMMIMFP16PSV: - Opc = X86::TCONJTCMMIMFP16PS; - break; - case X86::PTMMULTF32PSV: - Opc = X86::TMMULTF32PS; - break; - case X86::PTTMMULTF32PSV: - Opc = X86::TTMMULTF32PS; - break; - case X86::PTDPBF8PSV: - Opc = X86::TDPBF8PS; - break; - case X86::PTDPBHF8PSV: - Opc = X86::TDPBHF8PS; - break; - case X86::PTDPHBF8PSV: - Opc = X86::TDPHBF8PS; - break; - case X86::PTDPHF8PSV: - Opc = X86::TDPHF8PS; - break; - + case X86::PTMMULTF32PSV: Opc = X86::TMMULTF32PS; break; + case X86::PTDPBF8PSV: Opc = X86::TDPBF8PS; break; + case X86::PTDPBHF8PSV: Opc = X86::TDPBHF8PS; break; + case X86::PTDPHBF8PSV: Opc = X86::TDPHBF8PS; break; + case X86::PTDPHF8PSV: Opc = X86::TDPHF8PS; break; + // clang-format on default: llvm_unreachable("Unexpected Opcode"); } diff --git a/llvm/lib/Target/X86/X86FastPreTileConfig.cpp b/llvm/lib/Target/X86/X86FastPreTileConfig.cpp index 787b71d..06f729a 100644 --- a/llvm/lib/Target/X86/X86FastPreTileConfig.cpp +++ b/llvm/lib/Target/X86/X86FastPreTileConfig.cpp @@ -267,24 +267,16 @@ void X86FastPreTileConfig::reload(MachineBasicBlock::iterator UseMI, << printReg(TileReg, TRI) << '\n'); } -static unsigned getTileDefNum(MachineRegisterInfo *MRI, Register Reg) { - if (Reg.isVirtual()) { - unsigned RegClassID = MRI->getRegClass(Reg)->getID(); - if (RegClassID == X86::TILERegClassID) - return 1; - if (RegClassID == X86::TILEPAIRRegClassID) - return 2; - } else { - if (Reg >= X86::TMM0 && Reg <= X86::TMM7) - return 1; - if (Reg >= X86::TMM0_TMM1 && Reg <= X86::TMM6_TMM7) - return 2; +static bool isTileRegister(MachineRegisterInfo *MRI, Register Reg) { + if (Reg.isVirtual() && + (MRI->getRegClass(Reg)->getID() == X86::TILERegClassID)) { + return true; } - return 0; -} -static bool isTileRegister(MachineRegisterInfo *MRI, Register VirtReg) { - return getTileDefNum(MRI, VirtReg) > 0; + if (Reg >= X86::TMM0 && Reg <= X86::TMM7) + return true; + + return false; } static bool isTileDef(MachineRegisterInfo *MRI, MachineInstr &MI) { @@ -296,7 +288,7 @@ static bool isTileDef(MachineRegisterInfo *MRI, MachineInstr &MI) { if (!MO.isReg()) return false; - return getTileDefNum(MRI, MO.getReg()) > 0; + return isTileRegister(MRI, MO.getReg()); } static ShapeT getShape(MachineRegisterInfo *MRI, Register TileReg) { @@ -636,19 +628,7 @@ bool X86FastPreTileConfig::configBasicBlock(MachineBasicBlock &MBB) { else if (dominates(MBB, LastShapeMI, ColMI)) LastShapeMI = ColMI; } - unsigned TileDefNum = getTileDefNum(MRI, MI.getOperand(0).getReg()); - if (TileDefNum > 1) { - for (unsigned I = 1; I < TileDefNum; I++) { - MachineOperand *ColxMO = &MI.getOperand(2 + I); - MachineInstr *ColxMI = MRI->getVRegDef(ColxMO->getReg()); - if (ColxMI->getParent() == &MBB) { - if (!LastShapeMI) - LastShapeMI = ColxMI; - else if (dominates(MBB, LastShapeMI, ColxMI)) - LastShapeMI = ColxMI; - } - } - } + // If there is user live out of the tilecfg, spill it and reload in // before the user. Register TileReg = MI.getOperand(0).getReg(); diff --git a/llvm/lib/Target/X86/X86FastTileConfig.cpp b/llvm/lib/Target/X86/X86FastTileConfig.cpp index 11d331b..d86ae36 100644 --- a/llvm/lib/Target/X86/X86FastTileConfig.cpp +++ b/llvm/lib/Target/X86/X86FastTileConfig.cpp @@ -77,14 +77,14 @@ INITIALIZE_PASS_BEGIN(X86FastTileConfig, DEBUG_TYPE, INITIALIZE_PASS_END(X86FastTileConfig, DEBUG_TYPE, "Fast Tile Register Configure", false, false) -static unsigned getNumDefTiles(MachineRegisterInfo *MRI, MachineInstr &MI) { +static bool isTileDef(MachineRegisterInfo *MRI, MachineInstr &MI) { // There is no phi instruction after register allocation. assert(MI.isPHI() == false); // The instruction must have 3 operands: tile def, row, col. // It should be AMX pseudo instruction that have shape operand. if (MI.isDebugInstr() || MI.isCopy() || MI.getNumOperands() < 3 || !MI.isPseudo()) - return 0; + return false; MachineOperand &MO = MI.getOperand(0); if (MO.isReg()) { @@ -93,24 +93,18 @@ static unsigned getNumDefTiles(MachineRegisterInfo *MRI, MachineInstr &MI) { // register is not rewritten yet. if (Reg.isVirtual()) { if (MRI->getRegClass(Reg)->getID() == X86::TILERegClassID) - return 1; - if (MRI->getRegClass(Reg)->getID() == X86::TILEPAIRRegClassID) - return 2; + return true; } if (Reg >= X86::TMM0 && Reg <= X86::TMM7) - return 1; - if (Reg >= X86::TMM0_TMM1 && Reg <= X86::TMM6_TMM7) - return 2; + return true; } - return 0; + return false; } static unsigned getTMMIndex(Register Reg) { if (Reg >= X86::TMM0 && Reg <= X86::TMM7) return Reg - X86::TMM0; - if (Reg >= X86::TMM0_TMM1 && Reg <= X86::TMM6_TMM7) - return (Reg - X86::TMM0_TMM1) * 2; llvm_unreachable("Invalid Tmm Reg!"); } @@ -120,17 +114,14 @@ bool X86FastTileConfig::configBasicBlock(MachineBasicBlock &MBB) { bool Change = false; SmallVector<std::pair<unsigned, ShapeT>, 6> ShapeInfos; for (MachineInstr &MI : reverse(MBB)) { - unsigned DefNum = getNumDefTiles(MRI, MI); - if (DefNum == 0 && MI.getOpcode() != X86::PLDTILECFGV) + if (!isTileDef(MRI, MI) && MI.getOpcode() != X86::PLDTILECFGV) continue; // AMX instructions that define tile register. if (MI.getOpcode() != X86::PLDTILECFGV) { MachineOperand &Row = MI.getOperand(1); unsigned TMMIdx = getTMMIndex(MI.getOperand(0).getReg()); - for (unsigned I = 0; I < DefNum; I++) { - MachineOperand &Col = MI.getOperand(2 + I); - ShapeInfos.push_back({TMMIdx + I, ShapeT(&Row, &Col)}); - } + MachineOperand &Col = MI.getOperand(2); + ShapeInfos.push_back({TMMIdx, ShapeT(&Row, &Col)}); } else { // PLDTILECFGV // Rewrite the shape information to memory. Stack slot should have // been initialized to zero in pre config. diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp index 4393f6e..d4418c8 100644 --- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp +++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp @@ -337,23 +337,8 @@ namespace { // lowering but before ISEL. bool isAMXSDNode(SDNode *N) const { // Check if N is AMX SDNode: - // 1. check specific opcode since these carry MVT::Untyped instead of - // x86amx_type; - // 2. check result type; - // 3. check operand type; - switch (N->getOpcode()) { - default: - break; - case X86::PT2RPNTLVWZ0V: - case X86::PT2RPNTLVWZ0T1V: - case X86::PT2RPNTLVWZ1V: - case X86::PT2RPNTLVWZ1T1V: - case X86::PT2RPNTLVWZ0RSV: - case X86::PT2RPNTLVWZ0RST1V: - case X86::PT2RPNTLVWZ1RSV: - case X86::PT2RPNTLVWZ1RST1V: - return true; - } + // 1. check result type; + // 2. check operand type; for (unsigned Idx = 0, E = N->getNumValues(); Idx != E; ++Idx) { if (N->getValueType(Idx) == MVT::x86amx) return true; @@ -5398,65 +5383,6 @@ void X86DAGToDAGISel::Select(SDNode *Node) { ReplaceNode(Node, CNode); return; } - case Intrinsic::x86_t2rpntlvwz0rs: - case Intrinsic::x86_t2rpntlvwz0rst1: - case Intrinsic::x86_t2rpntlvwz1rs: - case Intrinsic::x86_t2rpntlvwz1rst1: - if (!Subtarget->hasAMXMOVRS()) - break; - [[fallthrough]]; - case Intrinsic::x86_t2rpntlvwz0: - case Intrinsic::x86_t2rpntlvwz0t1: - case Intrinsic::x86_t2rpntlvwz1: - case Intrinsic::x86_t2rpntlvwz1t1: { - if (!Subtarget->hasAMXTRANSPOSE()) - break; - auto *MFI = - CurDAG->getMachineFunction().getInfo<X86MachineFunctionInfo>(); - MFI->setAMXProgModel(AMXProgModelEnum::DirectReg); - unsigned Opc; - switch (IntNo) { - default: - llvm_unreachable("Unexpected intrinsic!"); - case Intrinsic::x86_t2rpntlvwz0: - Opc = X86::PT2RPNTLVWZ0; - break; - case Intrinsic::x86_t2rpntlvwz0t1: - Opc = X86::PT2RPNTLVWZ0T1; - break; - case Intrinsic::x86_t2rpntlvwz1: - Opc = X86::PT2RPNTLVWZ1; - break; - case Intrinsic::x86_t2rpntlvwz1t1: - Opc = X86::PT2RPNTLVWZ1T1; - break; - case Intrinsic::x86_t2rpntlvwz0rs: - Opc = X86::PT2RPNTLVWZ0RS; - break; - case Intrinsic::x86_t2rpntlvwz0rst1: - Opc = X86::PT2RPNTLVWZ0RST1; - break; - case Intrinsic::x86_t2rpntlvwz1rs: - Opc = X86::PT2RPNTLVWZ1RS; - break; - case Intrinsic::x86_t2rpntlvwz1rst1: - Opc = X86::PT2RPNTLVWZ1RST1; - break; - } - // FIXME: Match displacement and scale. - unsigned TIndex = Node->getConstantOperandVal(2); - SDValue TReg = getI8Imm(TIndex, dl); - SDValue Base = Node->getOperand(3); - SDValue Scale = getI8Imm(1, dl); - SDValue Index = Node->getOperand(4); - SDValue Disp = CurDAG->getTargetConstant(0, dl, MVT::i32); - SDValue Segment = CurDAG->getRegister(0, MVT::i16); - SDValue Chain = Node->getOperand(0); - SDValue Ops[] = {TReg, Base, Scale, Index, Disp, Segment, Chain}; - MachineSDNode *CNode = CurDAG->getMachineNode(Opc, dl, MVT::Other, Ops); - ReplaceNode(Node, CNode); - return; - } } break; } diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 624cff2..007074c 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -27946,67 +27946,6 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget, return DAG.getNode(ISD::MERGE_VALUES, dl, Op->getVTList(), SetCC, Operation.getValue(1)); } - case Intrinsic::x86_t2rpntlvwz0rs_internal: - case Intrinsic::x86_t2rpntlvwz0rst1_internal: - case Intrinsic::x86_t2rpntlvwz1rs_internal: - case Intrinsic::x86_t2rpntlvwz1rst1_internal: - case Intrinsic::x86_t2rpntlvwz0_internal: - case Intrinsic::x86_t2rpntlvwz0t1_internal: - case Intrinsic::x86_t2rpntlvwz1_internal: - case Intrinsic::x86_t2rpntlvwz1t1_internal: { - auto *X86MFI = DAG.getMachineFunction().getInfo<X86MachineFunctionInfo>(); - X86MFI->setAMXProgModel(AMXProgModelEnum::ManagedRA); - unsigned IntNo = Op.getConstantOperandVal(1); - unsigned Opc = 0; - switch (IntNo) { - default: - llvm_unreachable("Unexpected intrinsic!"); - case Intrinsic::x86_t2rpntlvwz0_internal: - Opc = X86::PT2RPNTLVWZ0V; - break; - case Intrinsic::x86_t2rpntlvwz0t1_internal: - Opc = X86::PT2RPNTLVWZ0T1V; - break; - case Intrinsic::x86_t2rpntlvwz1_internal: - Opc = X86::PT2RPNTLVWZ1V; - break; - case Intrinsic::x86_t2rpntlvwz1t1_internal: - Opc = X86::PT2RPNTLVWZ1T1V; - break; - case Intrinsic::x86_t2rpntlvwz0rs_internal: - Opc = X86::PT2RPNTLVWZ0RSV; - break; - case Intrinsic::x86_t2rpntlvwz0rst1_internal: - Opc = X86::PT2RPNTLVWZ0RST1V; - break; - case Intrinsic::x86_t2rpntlvwz1rs_internal: - Opc = X86::PT2RPNTLVWZ1RSV; - break; - case Intrinsic::x86_t2rpntlvwz1rst1_internal: - Opc = X86::PT2RPNTLVWZ1RST1V; - break; - } - - SDLoc DL(Op); - SDVTList VTs = DAG.getVTList(MVT::Untyped, MVT::Other); - - SDValue Ops[] = {Op.getOperand(2), // Row - Op.getOperand(3), // Col0 - Op.getOperand(4), // Col1 - Op.getOperand(5), // Base - DAG.getTargetConstant(1, DL, MVT::i8), // Scale - Op.getOperand(6), // Index - DAG.getTargetConstant(0, DL, MVT::i32), // Disp - DAG.getRegister(0, MVT::i16), // Segment - Op.getOperand(0)}; // Chain - - MachineSDNode *Res = DAG.getMachineNode(Opc, DL, VTs, Ops); - SDValue Res0 = DAG.getTargetExtractSubreg(X86::sub_t0, DL, MVT::x86amx, - SDValue(Res, 0)); - SDValue Res1 = DAG.getTargetExtractSubreg(X86::sub_t1, DL, MVT::x86amx, - SDValue(Res, 0)); - return DAG.getMergeValues({Res0, Res1, SDValue(Res, 1)}, DL); - } case Intrinsic::x86_atomic_bts_rm: case Intrinsic::x86_atomic_btc_rm: case Intrinsic::x86_atomic_btr_rm: { @@ -37745,10 +37684,6 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, assert (Imm < 8 && "Illegal tmm index"); return X86::TMM0 + Imm; }; - auto TMMImmToTMMPair = [](unsigned Imm) { - assert(Imm < 8 && "Illegal tmm pair index."); - return X86::TMM0_TMM1 + Imm / 2; - }; switch (MI.getOpcode()) { default: llvm_unreachable("Unexpected instr type to insert"); @@ -38129,53 +38064,25 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, case X86::PTDPBHF8PS: case X86::PTDPHBF8PS: case X86::PTDPHF8PS: - case X86::PTTDPBF16PS: - case X86::PTTDPFP16PS: - case X86::PTTCMMIMFP16PS: - case X86::PTTCMMRLFP16PS: - case X86::PTCONJTCMMIMFP16PS: - case X86::PTMMULTF32PS: - case X86::PTTMMULTF32PS: { + case X86::PTMMULTF32PS: { unsigned Opc; switch (MI.getOpcode()) { default: llvm_unreachable("illegal opcode!"); + // clang-format off case X86::PTDPBSSD: Opc = X86::TDPBSSD; break; case X86::PTDPBSUD: Opc = X86::TDPBSUD; break; case X86::PTDPBUSD: Opc = X86::TDPBUSD; break; case X86::PTDPBUUD: Opc = X86::TDPBUUD; break; case X86::PTDPBF16PS: Opc = X86::TDPBF16PS; break; case X86::PTDPFP16PS: Opc = X86::TDPFP16PS; break; - case X86::PTCMMIMFP16PS: - Opc = X86::TCMMIMFP16PS; - break; - case X86::PTCMMRLFP16PS: - Opc = X86::TCMMRLFP16PS; - break; + case X86::PTCMMIMFP16PS: Opc = X86::TCMMIMFP16PS; break; + case X86::PTCMMRLFP16PS: Opc = X86::TCMMRLFP16PS; break; case X86::PTDPBF8PS: Opc = X86::TDPBF8PS; break; case X86::PTDPBHF8PS: Opc = X86::TDPBHF8PS; break; case X86::PTDPHBF8PS: Opc = X86::TDPHBF8PS; break; case X86::PTDPHF8PS: Opc = X86::TDPHF8PS; break; - case X86::PTTDPBF16PS: - Opc = X86::TTDPBF16PS; - break; - case X86::PTTDPFP16PS: - Opc = X86::TTDPFP16PS; - break; - case X86::PTTCMMIMFP16PS: - Opc = X86::TTCMMIMFP16PS; - break; - case X86::PTTCMMRLFP16PS: - Opc = X86::TTCMMRLFP16PS; - break; - case X86::PTCONJTCMMIMFP16PS: - Opc = X86::TCONJTCMMIMFP16PS; - break; - case X86::PTMMULTF32PS: - Opc = X86::TMMULTF32PS; - break; - case X86::PTTMMULTF32PS: - Opc = X86::TTMMULTF32PS; - break; + case X86::PTMMULTF32PS: Opc = X86::TMMULTF32PS; break; + // clang-format on } MachineInstrBuilder MIB = BuildMI(*BB, MI, MIMD, TII->get(Opc)); @@ -38246,70 +38153,6 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, MI.eraseFromParent(); // The pseudo is gone now. return BB; } - case X86::PT2RPNTLVWZ0: - case X86::PT2RPNTLVWZ0T1: - case X86::PT2RPNTLVWZ1: - case X86::PT2RPNTLVWZ1T1: - case X86::PT2RPNTLVWZ0RS: - case X86::PT2RPNTLVWZ0RST1: - case X86::PT2RPNTLVWZ1RS: - case X86::PT2RPNTLVWZ1RST1: { - const DebugLoc &DL = MI.getDebugLoc(); - unsigned Opc; -#define GET_EGPR_IF_ENABLED(OPC) (Subtarget.hasEGPR() ? OPC##_EVEX : OPC) - switch (MI.getOpcode()) { - default: - llvm_unreachable("Unexpected instruction!"); - case X86::PT2RPNTLVWZ0: - Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0); - break; - case X86::PT2RPNTLVWZ0T1: - Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0T1); - break; - case X86::PT2RPNTLVWZ1: - Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1); - break; - case X86::PT2RPNTLVWZ1T1: - Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1T1); - break; - case X86::PT2RPNTLVWZ0RS: - Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0RS); - break; - case X86::PT2RPNTLVWZ0RST1: - Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0RST1); - break; - case X86::PT2RPNTLVWZ1RS: - Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1RS); - break; - case X86::PT2RPNTLVWZ1RST1: - Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1RST1); - break; - } -#undef GET_EGPR_IF_ENABLED - MachineInstrBuilder MIB = BuildMI(*BB, MI, DL, TII->get(Opc)); - MIB.addReg(TMMImmToTMMPair(MI.getOperand(0).getImm()), RegState::Define); - - MIB.add(MI.getOperand(1)); // base - MIB.add(MI.getOperand(2)); // scale - MIB.add(MI.getOperand(3)); // index - MIB.add(MI.getOperand(4)); // displacement - MIB.add(MI.getOperand(5)); // segment - MI.eraseFromParent(); // The pseudo is gone now. - return BB; - } - case X86::PTTRANSPOSED: - case X86::PTCONJTFP16: { - const DebugLoc &DL = MI.getDebugLoc(); - unsigned Opc = MI.getOpcode() == X86::PTTRANSPOSED ? X86::TTRANSPOSED - : X86::TCONJTFP16; - - MachineInstrBuilder MIB = BuildMI(*BB, MI, DL, TII->get(Opc)); - MIB.addReg(TMMImmToTMMReg(MI.getOperand(0).getImm()), RegState::Define); - MIB.addReg(TMMImmToTMMReg(MI.getOperand(1).getImm()), RegState::Undef); - - MI.eraseFromParent(); // The pseudo is gone now. - return BB; - } case X86::PTCVTROWPS2BF16Hrri: case X86::PTCVTROWPS2BF16Lrri: case X86::PTCVTROWPS2PHHrri: @@ -48778,10 +48621,9 @@ static SDValue combinePTESTCC(SDValue EFLAGS, X86::CondCode &CC, SDValue BC0 = peekThroughBitcasts(Op0); if (BC0.getOpcode() == X86ISD::PCMPEQ && ISD::isBuildVectorAllZeros(BC0.getOperand(1).getNode())) { - SDLoc DL(EFLAGS); CC = (CC == X86::COND_B ? X86::COND_E : X86::COND_NE); - SDValue X = DAG.getBitcast(OpVT, BC0.getOperand(0)); - return DAG.getNode(EFLAGS.getOpcode(), DL, VT, X, X); + SDValue X = DAG.getBitcast(OpVT, DAG.getFreeze(BC0.getOperand(0))); + return DAG.getNode(EFLAGS.getOpcode(), SDLoc(EFLAGS), VT, X, X); } } } @@ -48837,7 +48679,7 @@ static SDValue combinePTESTCC(SDValue EFLAGS, X86::CondCode &CC, MVT FloatSVT = MVT::getFloatingPointVT(EltBits); MVT FloatVT = MVT::getVectorVT(FloatSVT, OpVT.getSizeInBits() / EltBits); - Res = DAG.getBitcast(FloatVT, Res); + Res = DAG.getBitcast(FloatVT, DAG.getFreeze(Res)); return DAG.getNode(X86ISD::TESTP, SDLoc(EFLAGS), VT, Res, Res); } else if (EltBits == 16) { MVT MovmskVT = BCVT.is128BitVector() ? MVT::v16i8 : MVT::v32i8; @@ -48856,8 +48698,30 @@ static SDValue combinePTESTCC(SDValue EFLAGS, X86::CondCode &CC, } // TESTZ(X,-1) == TESTZ(X,X) - if (ISD::isBuildVectorAllOnes(Op1.getNode())) + if (ISD::isBuildVectorAllOnes(Op1.getNode())) { + Op0 = DAG.getFreeze(Op0); return DAG.getNode(EFLAGS.getOpcode(), SDLoc(EFLAGS), VT, Op0, Op0); + } + + // Attempt to convert PTESTZ(X,SIGNMASK) -> VTESTPD/PSZ(X,X) on AVX targets. + if (EFLAGS.getOpcode() == X86ISD::PTEST && Subtarget.hasAVX()) { + KnownBits KnownOp1 = DAG.computeKnownBits(Op1); + assert(KnownOp1.getBitWidth() == 64 && + "Illegal PTEST vector element width"); + if (KnownOp1.isConstant()) { + const APInt &Mask = KnownOp1.getConstant(); + if (Mask.isSignMask()) { + MVT FpVT = MVT::getVectorVT(MVT::f64, OpVT.getSizeInBits() / 64); + Op0 = DAG.getBitcast(FpVT, DAG.getFreeze(Op0)); + return DAG.getNode(X86ISD::TESTP, SDLoc(EFLAGS), VT, Op0, Op0); + } + if (Mask.isSplat(32) && Mask.trunc(32).isSignMask()) { + MVT FpVT = MVT::getVectorVT(MVT::f32, OpVT.getSizeInBits() / 32); + Op0 = DAG.getBitcast(FpVT, DAG.getFreeze(Op0)); + return DAG.getNode(X86ISD::TESTP, SDLoc(EFLAGS), VT, Op0, Op0); + } + } + } // TESTZ(OR(LO(X),HI(X)),OR(LO(Y),HI(Y))) -> TESTZ(X,Y) // TODO: Add COND_NE handling? @@ -53480,6 +53344,105 @@ static SDValue combineMaskedStore(SDNode *N, SelectionDAG &DAG, return SDValue(); } +// Look for a RMW operation that only touches one bit of a larger than legal +// type and fold it to a BTC/BTR/BTS or bit insertion pattern acting on a single +// i32 sub value. +static SDValue narrowBitOpRMW(StoreSDNode *St, const SDLoc &DL, + SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + using namespace SDPatternMatch; + + // Only handle normal stores and its chain was a matching normal load. + auto *Ld = dyn_cast<LoadSDNode>(St->getChain()); + if (!ISD::isNormalStore(St) || !St->isSimple() || !Ld || + !ISD::isNormalLoad(Ld) || !Ld->isSimple() || + Ld->getBasePtr() != St->getBasePtr() || + Ld->getOffset() != St->getOffset()) + return SDValue(); + + SDValue LoadVal(Ld, 0); + SDValue StoredVal = St->getValue(); + EVT VT = StoredVal.getValueType(); + + // Only narrow larger than legal scalar integers. + if (!VT.isScalarInteger() || + VT.getSizeInBits() <= (Subtarget.is64Bit() ? 64 : 32)) + return SDValue(); + + // BTR: X & ~(1 << ShAmt) + // BTS: X | (1 << ShAmt) + // BTC: X ^ (1 << ShAmt) + // + // BitInsert: (X & ~(1 << ShAmt)) | (InsertBit << ShAmt) + SDValue InsertBit, ShAmt; + if (!StoredVal.hasOneUse() || + !(sd_match(StoredVal, m_And(m_Specific(LoadVal), + m_Not(m_Shl(m_One(), m_Value(ShAmt))))) || + sd_match(StoredVal, + m_Or(m_Specific(LoadVal), m_Shl(m_One(), m_Value(ShAmt)))) || + sd_match(StoredVal, + m_Xor(m_Specific(LoadVal), m_Shl(m_One(), m_Value(ShAmt)))) || + sd_match(StoredVal, + m_Or(m_And(m_Specific(LoadVal), + m_Not(m_Shl(m_One(), m_Value(ShAmt)))), + m_Shl(m_Value(InsertBit), m_Deferred(ShAmt)))))) + return SDValue(); + + // Ensure the shift amount is in bounds. + KnownBits KnownAmt = DAG.computeKnownBits(ShAmt); + if (KnownAmt.getMaxValue().uge(VT.getSizeInBits())) + return SDValue(); + + // If we're inserting a bit then it must be the LSB. + if (InsertBit) { + KnownBits KnownInsert = DAG.computeKnownBits(InsertBit); + if (KnownInsert.countMinLeadingZeros() < (VT.getSizeInBits() - 1)) + return SDValue(); + } + + // Split the shift into an alignment shift that moves the active i32 block to + // the bottom bits for truncation and a modulo shift that can act on the i32. + EVT AmtVT = ShAmt.getValueType(); + SDValue AlignAmt = DAG.getNode(ISD::AND, DL, AmtVT, ShAmt, + DAG.getSignedConstant(-32LL, DL, AmtVT)); + SDValue ModuloAmt = + DAG.getNode(ISD::AND, DL, AmtVT, ShAmt, DAG.getConstant(31, DL, AmtVT)); + ModuloAmt = DAG.getZExtOrTrunc(ModuloAmt, DL, MVT::i8); + + // Compute the byte offset for the i32 block that is changed by the RMW. + // combineTruncate will adjust the load for us in a similar way. + EVT PtrVT = St->getBasePtr().getValueType(); + SDValue PtrBitOfs = DAG.getZExtOrTrunc(AlignAmt, DL, PtrVT); + SDValue PtrByteOfs = DAG.getNode(ISD::SRL, DL, PtrVT, PtrBitOfs, + DAG.getShiftAmountConstant(3, PtrVT, DL)); + SDValue NewPtr = DAG.getMemBasePlusOffset(St->getBasePtr(), PtrByteOfs, DL, + SDNodeFlags::NoUnsignedWrap); + + // Reconstruct the BTC/BTR/BTS pattern for the i32 block and store. + SDValue X = DAG.getNode(ISD::SRL, DL, VT, LoadVal, AlignAmt); + X = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, X); + + SDValue Mask = DAG.getNode(ISD::SHL, DL, MVT::i32, + DAG.getConstant(1, DL, MVT::i32), ModuloAmt); + + SDValue Res; + if (InsertBit) { + SDValue BitMask = + DAG.getNode(ISD::SHL, DL, MVT::i32, + DAG.getZExtOrTrunc(InsertBit, DL, MVT::i32), ModuloAmt); + Res = + DAG.getNode(ISD::AND, DL, MVT::i32, X, DAG.getNOT(DL, Mask, MVT::i32)); + Res = DAG.getNode(ISD::OR, DL, MVT::i32, Res, BitMask); + } else { + if (StoredVal.getOpcode() == ISD::AND) + Mask = DAG.getNOT(DL, Mask, MVT::i32); + Res = DAG.getNode(StoredVal.getOpcode(), DL, MVT::i32, X, Mask); + } + + return DAG.getStore(St->getChain(), DL, Res, NewPtr, St->getPointerInfo(), + Align(), St->getMemOperand()->getFlags()); +} + static SDValue combineStore(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) { @@ -53706,6 +53669,9 @@ static SDValue combineStore(SDNode *N, SelectionDAG &DAG, } } + if (SDValue R = narrowBitOpRMW(St, dl, DAG, Subtarget)) + return R; + // Convert store(cmov(load(p), x, CC), p) to cstore(x, p, CC) // store(cmov(x, load(p), CC), p) to cstore(x, p, InvertCC) if ((VT == MVT::i16 || VT == MVT::i32 || VT == MVT::i64) && @@ -54493,6 +54459,7 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL, static SDValue detectPMADDUBSW(SDValue In, EVT VT, SelectionDAG &DAG, const X86Subtarget &Subtarget, const SDLoc &DL) { + using namespace SDPatternMatch; if (!VT.isVector() || !Subtarget.hasSSSE3()) return SDValue(); @@ -54502,42 +54469,19 @@ static SDValue detectPMADDUBSW(SDValue In, EVT VT, SelectionDAG &DAG, return SDValue(); SDValue SSatVal = detectSSatPattern(In, VT); - if (!SSatVal || SSatVal.getOpcode() != ISD::ADD) - return SDValue(); - - // Ok this is a signed saturation of an ADD. See if this ADD is adding pairs - // of multiplies from even/odd elements. - SDValue N0 = SSatVal.getOperand(0); - SDValue N1 = SSatVal.getOperand(1); - - if (N0.getOpcode() != ISD::MUL || N1.getOpcode() != ISD::MUL) + if (!SSatVal) return SDValue(); - SDValue N00 = N0.getOperand(0); - SDValue N01 = N0.getOperand(1); - SDValue N10 = N1.getOperand(0); - SDValue N11 = N1.getOperand(1); - + // See if this is a signed saturation of an ADD, adding pairs of multiplies + // from even/odd elements, from zero_extend/sign_extend operands. + // // TODO: Handle constant vectors and use knownbits/computenumsignbits? - // Canonicalize zero_extend to LHS. - if (N01.getOpcode() == ISD::ZERO_EXTEND) - std::swap(N00, N01); - if (N11.getOpcode() == ISD::ZERO_EXTEND) - std::swap(N10, N11); - - // Ensure we have a zero_extend and a sign_extend. - if (N00.getOpcode() != ISD::ZERO_EXTEND || - N01.getOpcode() != ISD::SIGN_EXTEND || - N10.getOpcode() != ISD::ZERO_EXTEND || - N11.getOpcode() != ISD::SIGN_EXTEND) + SDValue N00, N01, N10, N11; + if (!sd_match(SSatVal, + m_Add(m_Mul(m_ZExt(m_Value(N00)), m_SExt(m_Value(N01))), + m_Mul(m_ZExt(m_Value(N10)), m_SExt(m_Value(N11)))))) return SDValue(); - // Peek through the extends. - N00 = N00.getOperand(0); - N01 = N01.getOperand(0); - N10 = N10.getOperand(0); - N11 = N11.getOperand(0); - // Ensure the extend is from vXi8. if (N00.getValueType().getVectorElementType() != MVT::i8 || N01.getValueType().getVectorElementType() != MVT::i8 || @@ -54660,8 +54604,9 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG, // truncation, see if we can convert the shift into a pointer offset instead. // Limit this to normal (non-ext) scalar integer loads. if (SrcVT.isScalarInteger() && Src.getOpcode() == ISD::SRL && - Src.hasOneUse() && Src.getOperand(0).hasOneUse() && - ISD::isNormalLoad(Src.getOperand(0).getNode())) { + Src.hasOneUse() && ISD::isNormalLoad(Src.getOperand(0).getNode()) && + (Src.getOperand(0).hasOneUse() || + !DAG.getTargetLoweringInfo().isOperationLegal(ISD::LOAD, SrcVT))) { auto *Ld = cast<LoadSDNode>(Src.getOperand(0)); if (Ld->isSimple() && VT.isByteSized() && isPowerOf2_64(VT.getSizeInBits())) { @@ -54669,9 +54614,11 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG, KnownBits KnownAmt = DAG.computeKnownBits(ShAmt); // Check the shift amount is byte aligned. // Check the truncation doesn't use any shifted in (zero) top bits. + // Check the shift amount doesn't depend on the original load. if (KnownAmt.countMinTrailingZeros() >= 3 && KnownAmt.getMaxValue().ule(SrcVT.getSizeInBits() - - VT.getSizeInBits())) { + VT.getSizeInBits()) && + !Ld->isPredecessorOf(ShAmt.getNode())) { EVT PtrVT = Ld->getBasePtr().getValueType(); SDValue PtrBitOfs = DAG.getZExtOrTrunc(ShAmt, DL, PtrVT); SDValue PtrByteOfs = @@ -56459,6 +56406,7 @@ static SDValue combineAVX512SetCCToKMOV(EVT VT, SDValue Op0, ISD::CondCode CC, static SDValue combineSetCC(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) { + using namespace SDPatternMatch; const ISD::CondCode CC = cast<CondCodeSDNode>(N->getOperand(2))->get(); const SDValue LHS = N->getOperand(0); const SDValue RHS = N->getOperand(1); @@ -56517,6 +56465,37 @@ static SDValue combineSetCC(SDNode *N, SelectionDAG &DAG, if (SDValue AndN = MatchAndCmpEq(RHS, LHS)) return DAG.getSetCC(DL, VT, AndN, DAG.getConstant(0, DL, OpVT), CC); + // If we're performing a bit test on a larger than legal type, attempt + // to (aligned) shift down the value to the bottom 32-bits and then + // perform the bittest on the i32 value. + // ICMP_ZERO(AND(X,SHL(1,IDX))) + // --> ICMP_ZERO(AND(TRUNC(SRL(X,AND(IDX,-32))),SHL(1,AND(IDX,31)))) + if (isNullConstant(RHS) && + OpVT.getScalarSizeInBits() > (Subtarget.is64Bit() ? 64 : 32)) { + SDValue X, ShAmt; + if (sd_match(LHS, m_OneUse(m_And(m_Value(X), + m_Shl(m_One(), m_Value(ShAmt)))))) { + // Only attempt this if the shift amount is known to be in bounds. + KnownBits KnownAmt = DAG.computeKnownBits(ShAmt); + if (KnownAmt.getMaxValue().ult(OpVT.getScalarSizeInBits())) { + EVT AmtVT = ShAmt.getValueType(); + SDValue AlignAmt = + DAG.getNode(ISD::AND, DL, AmtVT, ShAmt, + DAG.getSignedConstant(-32LL, DL, AmtVT)); + SDValue ModuloAmt = DAG.getNode(ISD::AND, DL, AmtVT, ShAmt, + DAG.getConstant(31, DL, AmtVT)); + SDValue Mask = DAG.getNode( + ISD::SHL, DL, MVT::i32, DAG.getConstant(1, DL, MVT::i32), + DAG.getZExtOrTrunc(ModuloAmt, DL, MVT::i8)); + X = DAG.getNode(ISD::SRL, DL, OpVT, X, AlignAmt); + X = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, X); + X = DAG.getNode(ISD::AND, DL, MVT::i32, X, Mask); + return DAG.getSetCC(DL, VT, X, DAG.getConstant(0, DL, MVT::i32), + CC); + } + } + } + // cmpeq(trunc(x),C) --> cmpeq(x,C) // cmpne(trunc(x),C) --> cmpne(x,C) // iff x upper bits are zero. diff --git a/llvm/lib/Target/X86/X86InstrAMX.td b/llvm/lib/Target/X86/X86InstrAMX.td index 69a5115..522782a 100644 --- a/llvm/lib/Target/X86/X86InstrAMX.td +++ b/llvm/lib/Target/X86/X86InstrAMX.td @@ -338,188 +338,6 @@ let Predicates = [HasAMXFP8, In64BitMode] in { } } -let Predicates = [HasAMXTILE, In64BitMode], isPseudo = true, SchedRW = [WriteSystem] in { - let mayStore = 1 in - def PTILEPAIRSTORE : PseudoI<(outs), (ins opaquemem:$src1, TILEPair:$src2), []>; - let mayLoad = 1 in - def PTILEPAIRLOAD : PseudoI<(outs TILEPair:$dst), (ins opaquemem:$src), []>; -} - -multiclass T2RPNTLVW_Base<bits<8> op1, bits<8> op2, string rs, string suffix> { - def Z0#rs#suffix : I<op1, MRMSrcMemFSIB, (outs TILEPair:$dst), (ins sibmem:$src), - "t2rpntlvwz0" #!tolower(rs)# "\t{$src, $dst|$dst, $src}", []>, PS; - def Z0#rs#T1#suffix : I<op2, MRMSrcMemFSIB, (outs TILEPair:$dst), (ins sibmem:$src), - "t2rpntlvwz0" #!tolower(rs)# "t1\t{$src, $dst|$dst, $src}", []>, PS; - def Z1#rs#suffix : I<op1, MRMSrcMemFSIB, (outs TILEPair:$dst), (ins sibmem:$src), - "t2rpntlvwz1" #!tolower(rs)# "\t{$src, $dst|$dst, $src}", []>, PD; - def Z1#rs#T1#suffix : I<op2, MRMSrcMemFSIB, (outs TILEPair:$dst), (ins sibmem:$src), - "t2rpntlvwz1" #!tolower(rs)# "t1\t{$src, $dst|$dst, $src}", []>, PD; -} - -let Predicates = [HasAMXTRANSPOSE, In64BitMode], SchedRW = [WriteSystem] in - defm T2RPNTLVW : T2RPNTLVW_Base<0x6e, 0x6f, "", "">, T8, VEX; - -let Predicates = [HasAMXTRANSPOSE, HasEGPR, In64BitMode], SchedRW = [WriteSystem] in - defm T2RPNTLVW : T2RPNTLVW_Base<0x6e, 0x6f, "", "_EVEX">, T8, EVEX, NoCD8; - -let Predicates = [HasAMXMOVRS, HasAMXTRANSPOSE, In64BitMode], SchedRW = [WriteSystem] in - defm T2RPNTLVW : T2RPNTLVW_Base<0xf8, 0xf9, "RS", "">, T_MAP5, VEX; - -let Predicates = [HasAMXMOVRS, HasAMXTRANSPOSE, HasEGPR, In64BitMode], SchedRW = [WriteSystem] in - defm T2RPNTLVW : T2RPNTLVW_Base<0xf8, 0xf9, "RS", "_EVEX">, T_MAP5, EVEX, NoCD8; - -let Predicates = [HasAMXTRANSPOSE, In64BitMode] in { - let SchedRW = [WriteSystem] in { - def TTRANSPOSED : I<0x5f, MRMSrcReg, (outs TILE:$dst), (ins TILE:$src), - "ttransposed\t{$src, $dst|$dst, $src}", []>, VEX, T8, XS; - let isPseudo = true in { - def PT2RPNTLVWZ0V : PseudoI<(outs TILEPair:$dst), - (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4), - []>; - def PT2RPNTLVWZ0T1V : PseudoI<(outs TILEPair:$dst), - (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4), - []>; - def PT2RPNTLVWZ1V : PseudoI<(outs TILEPair:$dst), - (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4), - []>; - def PT2RPNTLVWZ1T1V : PseudoI<(outs TILEPair:$dst), - (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4), - []>; - } - - def PTTRANSPOSEDV : PseudoI<(outs TILE:$dst), - (ins GR16:$src1, GR16:$src2, TILE:$src), - [(set TILE: $dst, - (int_x86_ttransposed_internal GR16:$src1, GR16:$src2, - TILE:$src))]>; - - let usesCustomInserter = 1 in { - def PT2RPNTLVWZ0 : PseudoI<(outs), (ins u8imm:$dst, - sibmem:$src1), []>; - def PT2RPNTLVWZ0T1 : PseudoI<(outs), (ins u8imm:$dst, - sibmem:$src1), []>; - def PT2RPNTLVWZ1 : PseudoI<(outs), (ins u8imm:$dst, - sibmem:$src1), []>; - def PT2RPNTLVWZ1T1 : PseudoI<(outs), (ins u8imm:$dst, - sibmem:$src1), []>; - def PTTRANSPOSED : PseudoI<(outs), (ins u8imm:$dst, u8imm:$src), - [(int_x86_ttransposed timm:$dst, timm:$src)]>; - } - } -} // HasAMXTILE, HasAMXTRANSPOSE - -let Predicates = [HasAMXBF16, HasAMXTRANSPOSE, In64BitMode], SchedRW = [WriteSystem] in { - let Constraints = "$src1 = $dst" in - def TTDPBF16PS : I<0x6c, MRMSrcReg4VOp3, (outs TILE:$dst), - (ins TILE:$src1, TILE:$src2, TILE:$src3), - "ttdpbf16ps\t{$src3, $src2, $dst|$dst, $src2, $src3}", - []>, VEX, VVVV, T8,XS; - let Constraints = "$src4 = $dst" in - def PTTDPBF16PSV : PseudoI<(outs TILE:$dst), (ins GR16:$src1, - GR16:$src2, GR16:$src3, TILE:$src4, - TILE:$src5, TILE:$src6), - [(set TILE: $dst, - (int_x86_ttdpbf16ps_internal GR16:$src1, GR16:$src2, - GR16:$src3, TILE:$src4, TILE:$src5, TILE:$src6))]>; - let usesCustomInserter = 1 in - def PTTDPBF16PS : PseudoI<(outs), (ins u8imm:$src1, u8imm:$src2, u8imm:$src3), - [(int_x86_ttdpbf16ps timm:$src1, timm:$src2, timm:$src3)]>; -} - -let Predicates = [HasAMXFP16, HasAMXTRANSPOSE, In64BitMode], SchedRW = [WriteSystem] in { - let Constraints = "$src1 = $dst" in - def TTDPFP16PS : I<0x6c, MRMSrcReg4VOp3, (outs TILE:$dst), - (ins TILE:$src1, TILE:$src2, TILE:$src3), - "ttdpfp16ps\t{$src3, $src2, $dst|$dst, $src2, $src3}", - []>, VEX, VVVV, T8,XD; - let Constraints = "$src4 = $dst" in - def PTTDPFP16PSV : PseudoI<(outs TILE:$dst), (ins GR16:$src1, - GR16:$src2, GR16:$src3, TILE:$src4, - TILE:$src5, TILE:$src6), - [(set TILE: $dst, - (int_x86_ttdpfp16ps_internal GR16:$src1, GR16:$src2, - GR16:$src3, TILE:$src4, TILE:$src5, TILE:$src6))]>; - let usesCustomInserter = 1 in - def PTTDPFP16PS : PseudoI<(outs), (ins u8imm:$src1, u8imm:$src2, u8imm:$src3), - [(int_x86_ttdpfp16ps timm:$src1, timm:$src2, timm:$src3)]>; -} - -let Predicates = [HasAMXCOMPLEX, HasAMXTRANSPOSE, In64BitMode], SchedRW = [WriteSystem] in { - let Constraints = "$src1 = $dst" in { - def TTCMMIMFP16PS : I<0x6b, MRMSrcReg4VOp3, (outs TILE:$dst), - (ins TILE:$src1, TILE:$src2, TILE:$src3), - "ttcmmimfp16ps\t{$src3, $src2, $src1|$src1, $src2, $src3}", - []>, VEX, VVVV, T8,XD; - def TTCMMRLFP16PS: I<0x6b, MRMSrcReg4VOp3, (outs TILE:$dst), - (ins TILE:$src1, TILE:$src2, TILE:$src3), - "ttcmmrlfp16ps\t{$src3, $src2, $src1|$src1, $src2, $src3}", - []>, VEX, VVVV, T8,XS; - def TCONJTCMMIMFP16PS : I<0x6b, MRMSrcReg4VOp3, (outs TILE:$dst), - (ins TILE:$src1, TILE:$src2, TILE:$src3), - "tconjtcmmimfp16ps\t{$src3, $src2, $src1|$src1, $src2, $src3}", - []>, VEX, VVVV, WIG, T8,PS; - } - def TCONJTFP16 : I<0x6b, MRMSrcReg, (outs TILE:$dst), (ins TILE:$src), - "tconjtfp16\t{$src, $dst|$dst, $src}", []>, VEX, T8,PD; - - let Constraints = "$src4 = $dst" in { - def PTTCMMIMFP16PSV : PseudoI<(outs TILE:$dst), (ins GR16:$src1, - GR16:$src2, GR16:$src3, TILE:$src4, - TILE:$src5, TILE:$src6), - [(set TILE: $dst, - (int_x86_ttcmmimfp16ps_internal GR16:$src1, GR16:$src2, - GR16:$src3, TILE:$src4, TILE:$src5, TILE:$src6))]>; - def PTTCMMRLFP16PSV : PseudoI<(outs TILE:$dst), (ins GR16:$src1, - GR16:$src2, GR16:$src3, TILE:$src4, - TILE:$src5, TILE:$src6), - [(set TILE: $dst, - (int_x86_ttcmmrlfp16ps_internal GR16:$src1, GR16:$src2, - GR16:$src3, TILE:$src4, TILE:$src5, TILE:$src6))]>; - def PTCONJTCMMIMFP16PSV : PseudoI<(outs TILE:$dst), (ins GR16:$src1, - GR16:$src2, GR16:$src3, TILE:$src4, - TILE:$src5, TILE:$src6), - [(set TILE: $dst, - (int_x86_tconjtcmmimfp16ps_internal GR16:$src1, GR16:$src2, - GR16:$src3, TILE:$src4, TILE:$src5, TILE:$src6))]>; - } - def PTCONJTFP16V : PseudoI<(outs TILE:$dst), (ins GR16:$src1, GR16:$src2, TILE:$src3), - [(set TILE: $dst, (int_x86_tconjtfp16_internal GR16:$src1, GR16:$src2, TILE:$src3))]>; - - let usesCustomInserter = 1 in { - def PTTCMMIMFP16PS : PseudoI<(outs), (ins u8imm:$src1, u8imm:$src2, u8imm:$src3), - [(int_x86_ttcmmimfp16ps timm:$src1, timm:$src2, timm:$src3)]>; - def PTTCMMRLFP16PS : PseudoI<(outs), (ins u8imm:$src1, u8imm:$src2, u8imm:$src3), - [(int_x86_ttcmmrlfp16ps timm:$src1, timm:$src2, timm:$src3)]>; - def PTCONJTCMMIMFP16PS : PseudoI<(outs), (ins u8imm:$src1, u8imm:$src2, u8imm:$src3), - [(int_x86_tconjtcmmimfp16ps timm:$src1, timm:$src2, timm:$src3)]>; - def PTCONJTFP16 : PseudoI<(outs), (ins u8imm:$dst, u8imm:$src), - [(int_x86_tconjtfp16 timm:$dst, timm:$src)]>; - } -} - -let Predicates = [HasAMXMOVRS, HasAMXTRANSPOSE, In64BitMode], SchedRW = [WriteSystem] in { - let isPseudo = true in { - def PT2RPNTLVWZ0RSV : PseudoI<(outs TILEPair:$dst), - (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4), - []>; - def PT2RPNTLVWZ0RST1V : PseudoI<(outs TILEPair:$dst), - (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4), - []>; - def PT2RPNTLVWZ1RSV : PseudoI<(outs TILEPair:$dst), - (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4), - []>; - def PT2RPNTLVWZ1RST1V : PseudoI<(outs TILEPair:$dst), - (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4), - []>; - } - let usesCustomInserter = 1 in { - def PT2RPNTLVWZ0RS : PseudoI<(outs), (ins u8imm:$dst, sibmem:$src1), []>; - def PT2RPNTLVWZ0RST1 : PseudoI<(outs), (ins u8imm:$dst, sibmem:$src1), []>; - def PT2RPNTLVWZ1RS : PseudoI<(outs), (ins u8imm:$dst, sibmem:$src1), []>; - def PT2RPNTLVWZ1RST1 : PseudoI<(outs), (ins u8imm:$dst, sibmem:$src1), []>; - } -} // HasAMXMOVRS, HasAMXTRANSPOSE - multiclass TILELOADDRS_Base<string suffix> { def suffix : I<0x4a, MRMSrcMemFSIB, (outs TILE:$dst), (ins sibmem:$src1), "tileloaddrs\t{$src1, $dst|$dst, $src1}", []>, T8, XD; @@ -721,29 +539,3 @@ let Predicates = [HasAMXTF32, In64BitMode] in { } } // SchedRW = [WriteSystem] } // HasAMXTF32 - -let Predicates = [HasAMXTF32, HasAMXTRANSPOSE, In64BitMode] in { - let SchedRW = [WriteSystem] in { - let Constraints = "$src1 = $dst" in { - def TTMMULTF32PS: I<0x48, MRMSrcReg4VOp3, (outs TILE:$dst), - (ins TILE:$src1, TILE:$src2, TILE:$src3), - "ttmmultf32ps\t{$src3, $src2, $dst|$dst, $src2, $src3}", - []>, VEX, VVVV, T8, PS; - } - let Constraints = "$src4 = $dst" in { - def PTTMMULTF32PSV : PseudoI<(outs TILE:$dst), - (ins GR16:$src1, GR16:$src2, GR16:$src3, - TILE:$src4, TILE:$src5, TILE:$src6), - [(set TILE:$dst, - (int_x86_ttmmultf32ps_internal GR16:$src1, - GR16:$src2, GR16:$src3, TILE:$src4, - TILE:$src5, TILE:$src6))]>; - } - let usesCustomInserter = 1 in { - def PTTMMULTF32PS : PseudoI<(outs), - (ins u8imm:$src1, u8imm:$src2, u8imm:$src3), - [(int_x86_ttmmultf32ps timm:$src1, timm:$src2, - timm:$src3)]>; - } - } // SchedRW = [WriteSystem] -} // HasAMXTF32, HasAMXTRANSPOSE diff --git a/llvm/lib/Target/X86/X86InstrInfo.cpp b/llvm/lib/Target/X86/X86InstrInfo.cpp index 5c23f91..6b2a7a4 100644 --- a/llvm/lib/Target/X86/X86InstrInfo.cpp +++ b/llvm/lib/Target/X86/X86InstrInfo.cpp @@ -4544,11 +4544,6 @@ static unsigned getLoadStoreRegOpcode(Register Reg, return Load ? GET_EGPR_IF_ENABLED(X86::TILELOADD) : GET_EGPR_IF_ENABLED(X86::TILESTORED); #undef GET_EGPR_IF_ENABLED - case 2048: - assert(X86::TILEPAIRRegClass.hasSubClassEq(RC) && - "Unknown 2048-byte regclass"); - assert(STI.hasAMXTILE() && "Using 2048-bit register requires AMX-TILE"); - return Load ? X86::PTILEPAIRLOAD : X86::PTILEPAIRSTORE; } } @@ -4743,8 +4738,6 @@ static bool isAMXOpcode(unsigned Opc) { case X86::TILESTORED: case X86::TILELOADD_EVEX: case X86::TILESTORED_EVEX: - case X86::PTILEPAIRLOAD: - case X86::PTILEPAIRSTORE: return true; } } @@ -4757,8 +4750,7 @@ void X86InstrInfo::loadStoreTileReg(MachineBasicBlock &MBB, default: llvm_unreachable("Unexpected special opcode!"); case X86::TILESTORED: - case X86::TILESTORED_EVEX: - case X86::PTILEPAIRSTORE: { + case X86::TILESTORED_EVEX: { // tilestored %tmm, (%sp, %idx) MachineRegisterInfo &RegInfo = MBB.getParent()->getRegInfo(); Register VirtReg = RegInfo.createVirtualRegister(&X86::GR64_NOSPRegClass); @@ -4772,8 +4764,7 @@ void X86InstrInfo::loadStoreTileReg(MachineBasicBlock &MBB, break; } case X86::TILELOADD: - case X86::TILELOADD_EVEX: - case X86::PTILEPAIRLOAD: { + case X86::TILELOADD_EVEX: { // tileloadd (%sp, %idx), %tmm MachineRegisterInfo &RegInfo = MBB.getParent()->getRegInfo(); Register VirtReg = RegInfo.createVirtualRegister(&X86::GR64_NOSPRegClass); diff --git a/llvm/lib/Target/X86/X86InstrOperands.td b/llvm/lib/Target/X86/X86InstrOperands.td index 5207eca..6ba07f7 100644 --- a/llvm/lib/Target/X86/X86InstrOperands.td +++ b/llvm/lib/Target/X86/X86InstrOperands.td @@ -536,10 +536,3 @@ def VK8Pair : RegisterOperand<VK8PAIR, "printVKPair"> { def VK16Pair : RegisterOperand<VK16PAIR, "printVKPair"> { let ParserMatchClass = VK16PairAsmOperand; } - -let RenderMethod = "addTILEPairOperands" in - def TILEPairAsmOperand : AsmOperandClass { let Name = "TILEPair"; } - -def TILEPair : RegisterOperand<TILEPAIR, "printTILEPair"> { - let ParserMatchClass = TILEPairAsmOperand; -} diff --git a/llvm/lib/Target/X86/X86InstrPredicates.td b/llvm/lib/Target/X86/X86InstrPredicates.td index c20bb05..98104a6f 100644 --- a/llvm/lib/Target/X86/X86InstrPredicates.td +++ b/llvm/lib/Target/X86/X86InstrPredicates.td @@ -183,7 +183,6 @@ def HasAMXINT8 : Predicate<"Subtarget->hasAMXINT8()">; def HasAMXCOMPLEX : Predicate<"Subtarget->hasAMXCOMPLEX()">; def HasAMXFP8 : Predicate<"Subtarget->hasAMXFP8()">; def HasAMXMOVRS : Predicate<"Subtarget->hasAMXMOVRS()">; -def HasAMXTRANSPOSE : Predicate<"Subtarget->hasAMXTRANSPOSE()">; def HasAMXAVX512 : Predicate<"Subtarget->hasAMXAVX512()">; def HasAMXTF32 : Predicate<"Subtarget->hasAMXTF32()">; def HasUINTR : Predicate<"Subtarget->hasUINTR()">; diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp index 8ffd454..2fc5d38 100644 --- a/llvm/lib/Target/X86/X86LowerAMXType.cpp +++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp @@ -74,22 +74,6 @@ static bool isAMXCast(Instruction *II) { match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value())); } -// Some instructions may return more than one tiles. -// e.g: call { x86_amx, x86_amx } @llvm.x86.t2rpntlvwz0.internal -static unsigned getNumDefTiles(IntrinsicInst *II) { - Type *Ty = II->getType(); - if (Ty->isX86_AMXTy()) - return 1; - - unsigned Num = 0; - for (unsigned i = 0; i < Ty->getNumContainedTypes(); i++) { - Type *STy = Ty->getContainedType(i); - if (STy->isX86_AMXTy()) - Num++; - } - return Num; -} - static bool isAMXIntrinsic(Value *I) { auto *II = dyn_cast<IntrinsicInst>(I); if (!II) @@ -98,7 +82,7 @@ static bool isAMXIntrinsic(Value *I) { return false; // Check if return type or parameter is x86_amx. If it is x86_amx // the intrinsic must be x86 amx intrinsics. - if (getNumDefTiles(II) > 0) + if (II->getType()->isX86_AMXTy()) return true; for (Value *V : II->args()) { if (V->getType()->isX86_AMXTy()) @@ -137,27 +121,7 @@ static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) { llvm_unreachable("No terminator in the entry block!"); } -class ShapeCalculator { -private: - const TargetMachine *TM = nullptr; - - // In AMX intrinsics we let Shape = {Row, Col}, but the - // RealCol = Col / ElementSize. We may use the RealCol - // as a new Row for other new created AMX intrinsics. - std::map<Value *, Value *> Col2Row, Row2Col; - -public: - ShapeCalculator(const TargetMachine *TargetM) : TM(TargetM) {} - std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo); - std::pair<Value *, Value *> getShape(PHINode *Phi); - Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity); - Value *getColFromRow(Instruction *II, Value *V, unsigned Granularity); -}; - -Value *ShapeCalculator::getRowFromCol(Instruction *II, Value *V, - unsigned Granularity) { - if (auto It = Col2Row.find(V); It != Col2Row.end()) - return It->second; +static Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity) { IRBuilder<> Builder(II); Value *RealRow = nullptr; if (isa<ConstantInt>(V)) @@ -186,47 +150,16 @@ Value *ShapeCalculator::getRowFromCol(Instruction *II, Value *V, getFirstNonAllocaInTheEntryBlock(*II->getFunction())); RealRow = NewBuilder.CreateUDiv(V, NewBuilder.getInt16(Granularity)); } - Col2Row[V] = RealRow; return RealRow; } -Value *ShapeCalculator::getColFromRow(Instruction *II, Value *V, - unsigned Granularity) { - if (auto It = Row2Col.find(V); It != Row2Col.end()) - return It->second; - IRBuilder<> Builder(II); - Value *RealCol = nullptr; - if (isa<ConstantInt>(V)) - RealCol = - Builder.getInt16((cast<ConstantInt>(V)->getSExtValue()) * Granularity); - else if (isa<Instruction>(V)) { - Builder.SetInsertPoint(cast<Instruction>(V)); - RealCol = Builder.CreateNUWMul(V, Builder.getInt16(Granularity)); - cast<Instruction>(RealCol)->moveAfter(cast<Instruction>(V)); - } else { - // When it is not a const value and it is a function argument, we create - // Row at the entry bb. - IRBuilder<> NewBuilder( - getFirstNonAllocaInTheEntryBlock(*II->getFunction())); - RealCol = NewBuilder.CreateNUWMul(V, NewBuilder.getInt16(Granularity)); - } - Row2Col[V] = RealCol; - return RealCol; -} - // TODO: Refine the row and col-in-bytes of tile to row and col of matrix. -std::pair<Value *, Value *> ShapeCalculator::getShape(IntrinsicInst *II, - unsigned OpNo) { - (void)TM; +std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) { IRBuilder<> Builder(II); Value *Row = nullptr, *Col = nullptr; switch (II->getIntrinsicID()) { default: llvm_unreachable("Expect amx intrinsics"); - case Intrinsic::x86_t2rpntlvwz0_internal: - case Intrinsic::x86_t2rpntlvwz0t1_internal: - case Intrinsic::x86_t2rpntlvwz1_internal: - case Intrinsic::x86_t2rpntlvwz1t1_internal: case Intrinsic::x86_tileloadd64_internal: case Intrinsic::x86_tileloaddt164_internal: case Intrinsic::x86_tilestored64_internal: @@ -271,13 +204,6 @@ std::pair<Value *, Value *> ShapeCalculator::getShape(IntrinsicInst *II, } break; } - case Intrinsic::x86_ttransposed_internal: - case Intrinsic::x86_tconjtfp16_internal: { - assert((OpNo == 2) && "Illegal Operand Number."); - Row = getRowFromCol(II, II->getArgOperand(1), 4); - Col = getColFromRow(II, II->getArgOperand(0), 4); - break; - } case Intrinsic::x86_tcvtrowd2ps_internal: case Intrinsic::x86_tcvtrowps2bf16h_internal: case Intrinsic::x86_tcvtrowps2bf16l_internal: @@ -289,34 +215,12 @@ std::pair<Value *, Value *> ShapeCalculator::getShape(IntrinsicInst *II, Col = II->getArgOperand(1); break; } - case Intrinsic::x86_ttdpbf16ps_internal: - case Intrinsic::x86_ttdpfp16ps_internal: - case Intrinsic::x86_ttcmmimfp16ps_internal: - case Intrinsic::x86_ttcmmrlfp16ps_internal: - case Intrinsic::x86_tconjtcmmimfp16ps_internal: - case Intrinsic::x86_ttmmultf32ps_internal: { - switch (OpNo) { - case 3: - Row = II->getArgOperand(0); - Col = II->getArgOperand(1); - break; - case 4: - Row = getRowFromCol(II, II->getArgOperand(2), 4); - Col = getColFromRow(II, II->getArgOperand(0), 4); - break; - case 5: - Row = getRowFromCol(II, II->getArgOperand(2), 4); - Col = II->getArgOperand(1); - break; - } - break; - } } return std::make_pair(Row, Col); } -std::pair<Value *, Value *> ShapeCalculator::getShape(PHINode *Phi) { +static std::pair<Value *, Value *> getShape(PHINode *Phi) { Use &U = *(Phi->use_begin()); unsigned OpNo = U.getOperandNo(); User *V = U.getUser(); @@ -349,15 +253,14 @@ std::pair<Value *, Value *> ShapeCalculator::getShape(PHINode *Phi) { namespace { class X86LowerAMXType { Function &Func; - ShapeCalculator *SC; // In AMX intrinsics we let Shape = {Row, Col}, but the // RealCol = Col / ElementSize. We may use the RealCol // as a new Row for other new created AMX intrinsics. - std::map<Value *, Value *> Col2Row, Row2Col; + std::map<Value *, Value *> Col2Row; public: - X86LowerAMXType(Function &F, ShapeCalculator *ShapeC) : Func(F), SC(ShapeC) {} + X86LowerAMXType(Function &F) : Func(F) {} bool visit(); void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast); void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST); @@ -374,7 +277,7 @@ void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) { Use &U = *(Bitcast->use_begin()); unsigned OpNo = U.getOperandNo(); auto *II = cast<IntrinsicInst>(U.getUser()); - std::tie(Row, Col) = SC->getShape(II, OpNo); + std::tie(Row, Col) = getShape(II, OpNo); IRBuilder<> Builder(Bitcast); // Use the maximun column as stride. Value *Stride = Builder.getInt64(64); @@ -454,7 +357,7 @@ bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) { Builder.CreateStore(Src, AllocaAddr); // TODO we can pick an constant operand for the shape. Value *Row = nullptr, *Col = nullptr; - std::tie(Row, Col) = SC->getShape(II, OpNo); + std::tie(Row, Col) = getShape(II, OpNo); std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride}; Value *NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, Args); @@ -594,18 +497,11 @@ static Value *getAllocaPos(BasicBlock *BB) { static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) { assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!"); - auto *II = dyn_cast<IntrinsicInst>(TileDef); - unsigned Idx = 0; - // Extract tile from multiple tiles' def. - if (auto *Extr = dyn_cast<ExtractValueInst>(TileDef)) { - assert(Extr->hasIndices() && "Tile extract miss index!"); - Idx = Extr->getIndices()[0]; - II = cast<IntrinsicInst>(Extr->getOperand(0)); - } + auto *II = cast<IntrinsicInst>(TileDef); assert(II && "Not tile intrinsic!"); - Value *Row = II->getOperand(Idx); - Value *Col = II->getOperand(Idx + 1); + Value *Row = II->getOperand(0); + Value *Col = II->getOperand(1); BasicBlock *BB = TileDef->getParent(); BasicBlock::iterator Iter = TileDef->getIterator(); @@ -624,20 +520,14 @@ static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) { // Get tile shape. IntrinsicInst *II = nullptr; - unsigned Idx = 0; if (IsPHI) { Value *PhiOp = cast<PHINode>(V)->getIncomingValue(0); II = cast<IntrinsicInst>(PhiOp); - } else if (auto *Extr = dyn_cast<ExtractValueInst>(V)) { - // Extract tile from multiple tiles' def. - assert(Extr->hasIndices() && "Tile extract miss index!"); - Idx = Extr->getIndices()[0]; - II = cast<IntrinsicInst>(Extr->getOperand(0)); } else { II = cast<IntrinsicInst>(V); } - Value *Row = II->getOperand(Idx); - Value *Col = II->getOperand(Idx + 1); + Value *Row = II->getOperand(0); + Value *Col = II->getOperand(1); Instruction *UserI = cast<Instruction>(U.getUser()); IRBuilder<> Builder(UserI); @@ -848,12 +738,10 @@ namespace { class X86LowerAMXCast { Function &Func; - ShapeCalculator *SC; std::unique_ptr<DominatorTree> DT; public: - X86LowerAMXCast(Function &F, ShapeCalculator *ShapeC) - : Func(F), SC(ShapeC), DT(nullptr) {} + X86LowerAMXCast(Function &F) : Func(F), DT(nullptr) {} bool combineCastStore(IntrinsicInst *Cast, StoreInst *ST); bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD); bool combineTilezero(IntrinsicInst *Cast); @@ -932,7 +820,7 @@ bool X86LowerAMXCast::optimizeAMXCastFromPhi( if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue()) return false; Value *Row = nullptr, *Col = nullptr; - std::tie(Row, Col) = SC->getShape(OldPN); + std::tie(Row, Col) = getShape(OldPN); // TODO: If it is not constant the Row and Col must domoniate tilezero // that we are going to create. if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col)) @@ -1063,19 +951,6 @@ bool X86LowerAMXCast::optimizeAMXCastFromPhi( return true; } -static Value *getShapeFromAMXIntrinsic(Value *Inst, unsigned ShapeIdx, - bool IsRow) { - if (!isAMXIntrinsic(Inst)) - return nullptr; - - auto *II = cast<IntrinsicInst>(Inst); - if (IsRow) - return II->getOperand(0); - - assert(ShapeIdx < 2 && "Currently 2 shapes in 1 instruction at most!"); - return II->getOperand(ShapeIdx + 1); -} - // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42) // store <256 x i32> %43, <256 x i32>* %p, align 64 // --> @@ -1090,38 +965,13 @@ bool X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) { if (!Tile->hasOneUse()) return false; - // We don't fetch shape from tilestore, we only get shape from tiledef, - // so we can set the max tile shape to tilestore for special cases. + auto *II = cast<IntrinsicInst>(Tile); + // Tile is output from AMX intrinsic. The first operand of the + // intrinsic is row, the second operand of the intrinsic is column. + Value *Row = II->getOperand(0); + Value *Col = II->getOperand(1); + IRBuilder<> Builder(ST); - Value *Row = nullptr; - Value *Col = nullptr; - - if (isAMXIntrinsic(Tile)) { - auto *II = cast<IntrinsicInst>(Tile); - // Tile is output from AMX intrinsic. The first operand of the - // intrinsic is row, the second operand of the intrinsic is column. - Row = II->getOperand(0); - Col = II->getOperand(1); - } else { - // Now we supported multi-tiles value in structure, so we may get tile - // from extracting multi-tiles structure. - // For example: - // %6 = call { x86_amx, x86_amx } @llvm.x86.t2rpntlvwz0.internal(i16 %1, - // i16 %2, i16 %3, i8* %4, i64 %5) - // %7 = extractvalue { x86_amx, x86_amx } %6, 0 - // %8 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %7) - // store <256 x i32> %8, <256 x i32>* %0, align 1024 - // - // TODO: Currently we only handle extractvalue case, enhance me for other - // cases if possible. - auto *II = cast<ExtractValueInst>(Tile); - assert(II && "We meet unhandle source in fetching tile value!"); - unsigned ShapeIdx = II->getIndices()[0]; - Value *Tiles = II->getOperand(0); - Row = getShapeFromAMXIntrinsic(Tiles, ShapeIdx, true); - Col = getShapeFromAMXIntrinsic(Tiles, ShapeIdx, false); - } - assert(Row && Col && "Shape got failed!"); // Stride should be equal to col(measured by bytes) Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty()); @@ -1146,7 +996,7 @@ bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) { // shape information through def-use chain. if (!isAMXIntrinsic(II)) return false; - std::tie(Row, Col) = SC->getShape(II, OpNo); + std::tie(Row, Col) = getShape(II, OpNo); IRBuilder<> Builder(LD); // Stride should be equal to col(measured by bytes) Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty()); @@ -1189,7 +1039,7 @@ bool X86LowerAMXCast::combineTilezero(IntrinsicInst *Cast) { if (!isAMXIntrinsic(II)) return false; - std::tie(Row, Col) = SC->getShape(II, OpNo); + std::tie(Row, Col) = getShape(II, OpNo); IRBuilder<> Builder(Cast); Value *NewInst = @@ -1384,7 +1234,7 @@ bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) { Builder.CreateStore(Src, AllocaAddr); // TODO we can pick an constant operand for the shape. Value *Row = nullptr, *Col = nullptr; - std::tie(Row, Col) = SC->getShape(II, OpNo); + std::tie(Row, Col) = getShape(II, OpNo); std::array<Value *, 4> Args = { Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())}; Value *NewInst = @@ -1445,14 +1295,13 @@ bool lowerAmxType(Function &F, const TargetMachine *TM, return false; bool C = false; - ShapeCalculator SC(TM); - X86LowerAMXCast LAC(F, &SC); + X86LowerAMXCast LAC(F); C |= LAC.combineAMXcast(TLI); // There might be remaining AMXcast after combineAMXcast and they should be // handled elegantly. C |= LAC.transformAllAMXCast(); - X86LowerAMXType LAT(F, &SC); + X86LowerAMXType LAT(F); C |= LAT.visit(); // Prepare for fast register allocation at O0. diff --git a/llvm/lib/Target/X86/X86PreTileConfig.cpp b/llvm/lib/Target/X86/X86PreTileConfig.cpp index 2a1c499..8a1d00d 100644 --- a/llvm/lib/Target/X86/X86PreTileConfig.cpp +++ b/llvm/lib/Target/X86/X86PreTileConfig.cpp @@ -141,15 +141,10 @@ class X86PreTileConfig : public MachineFunctionPass { if (!MO.isReg() || !MO.getReg().isVirtual()) return false; - unsigned Shapes = 0; - if (MRI->getRegClass(MO.getReg())->getID() == X86::TILERegClassID) - Shapes = 1; - if (MRI->getRegClass(MO.getReg())->getID() == X86::TILEPAIRRegClassID) - Shapes = 2; - if (!Shapes) + if (MRI->getRegClass(MO.getReg())->getID() != X86::TILERegClassID) return false; - collectShapeInfo(MI, Shapes); + collectShapeInfo(MI); return true; } @@ -165,7 +160,7 @@ class X86PreTileConfig : public MachineFunctionPass { } /// Collect the shape def information for later use. - void collectShapeInfo(MachineInstr &MI, unsigned Shapes); + void collectShapeInfo(MachineInstr &MI); /// Try to hoist shapes definded below AMX instructions. bool hoistShapesInBB(MachineBasicBlock *MBB, SmallVectorImpl<MIRef> &Shapes) { @@ -231,7 +226,7 @@ INITIALIZE_PASS_DEPENDENCY(MachineLoopInfoWrapperPass) INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig", "Tile Register Pre-configure", false, false) -void X86PreTileConfig::collectShapeInfo(MachineInstr &MI, unsigned Shapes) { +void X86PreTileConfig::collectShapeInfo(MachineInstr &MI) { auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) { MIRef MIR(MI, MBB); auto &Refs = ShapeBBs[MBB]; @@ -240,10 +235,8 @@ void X86PreTileConfig::collectShapeInfo(MachineInstr &MI, unsigned Shapes) { Refs.insert(I, MIR); }; - // All shapes have same row in multi-tile operand. - SmallVector<Register, 8> WorkList; - for (unsigned I = 1; I < Shapes + 2; ++I) - WorkList.push_back(MI.getOperand(I).getReg()); + SmallVector<Register, 8> WorkList( + {MI.getOperand(1).getReg(), MI.getOperand(2).getReg()}); while (!WorkList.empty()) { Register R = WorkList.pop_back_val(); MachineInstr *DefMI = MRI->getVRegDef(R); @@ -252,13 +245,6 @@ void X86PreTileConfig::collectShapeInfo(MachineInstr &MI, unsigned Shapes) { if (DefMI->isMoveImmediate() || !DefVisited.insert(DefMI).second) continue; - // This happens when column = 0 in multi-tile operand. - if (DefMI->getOpcode() == X86::COPY) { - MachineInstr *MI = MRI->getVRegDef(DefMI->getOperand(1).getReg()); - if (MI && MI->isMoveImmediate()) - continue; - } - if (DefMI->isPHI()) { for (unsigned I = 1; I < DefMI->getNumOperands(); I += 2) if (isLoopBackEdge(DefMBB, DefMI->getOperand(I + 1).getMBB())) diff --git a/llvm/lib/Target/X86/X86RegisterInfo.cpp b/llvm/lib/Target/X86/X86RegisterInfo.cpp index 76979e3..72f3813 100644 --- a/llvm/lib/Target/X86/X86RegisterInfo.cpp +++ b/llvm/lib/Target/X86/X86RegisterInfo.cpp @@ -597,10 +597,6 @@ BitVector X86RegisterInfo::getReservedRegs(const MachineFunction &MF) const { Reserved.set(*AI); } - // Reserve low half pair registers in case they are used by RA aggressively. - Reserved.set(X86::TMM0_TMM1); - Reserved.set(X86::TMM2_TMM3); - assert(checkAllSuperRegsMarked(Reserved, {X86::SIL, X86::DIL, X86::BPL, X86::SPL, X86::SIH, X86::DIH, X86::BPH, X86::SPH})); @@ -621,7 +617,7 @@ unsigned X86RegisterInfo::getNumSupportedRegs(const MachineFunction &MF) const { // and try to return the minimum number of registers supported by the target. static_assert((X86::R15WH + 1 == X86::YMM0) && (X86::YMM15 + 1 == X86::K0) && (X86::K6_K7 + 1 == X86::TMMCFG) && - (X86::TMM6_TMM7 + 1 == X86::R16) && + (X86::TMM7 + 1 == X86::R16) && (X86::R31WH + 1 == X86::NUM_TARGET_REGS), "Register number may be incorrect"); @@ -694,8 +690,7 @@ bool X86RegisterInfo::isFixedRegister(const MachineFunction &MF, } bool X86RegisterInfo::isTileRegisterClass(const TargetRegisterClass *RC) const { - return RC->getID() == X86::TILERegClassID || - RC->getID() == X86::TILEPAIRRegClassID; + return RC->getID() == X86::TILERegClassID; } void X86RegisterInfo::adjustStackMapLiveOutMask(uint32_t *Mask) const { @@ -1062,17 +1057,9 @@ static ShapeT getTileShape(Register VirtReg, VirtRegMap *VRM, case X86::PTDPFP16PSV: case X86::PTCMMIMFP16PSV: case X86::PTCMMRLFP16PSV: - case X86::PTTRANSPOSEDV: - case X86::PTTDPBF16PSV: - case X86::PTTDPFP16PSV: - case X86::PTTCMMIMFP16PSV: - case X86::PTTCMMRLFP16PSV: - case X86::PTCONJTCMMIMFP16PSV: - case X86::PTCONJTFP16V: case X86::PTILELOADDRSV: case X86::PTILELOADDRST1V: case X86::PTMMULTF32PSV: - case X86::PTTMMULTF32PSV: case X86::PTDPBF8PSV: case X86::PTDPBHF8PSV: case X86::PTDPHBF8PSV: @@ -1083,56 +1070,7 @@ static ShapeT getTileShape(Register VirtReg, VirtRegMap *VRM, VRM->assignVirt2Shape(VirtReg, Shape); return Shape; } - case X86::PT2RPNTLVWZ0V: - case X86::PT2RPNTLVWZ0T1V: - case X86::PT2RPNTLVWZ1V: - case X86::PT2RPNTLVWZ1T1V: - case X86::PT2RPNTLVWZ0RSV: - case X86::PT2RPNTLVWZ0RST1V: - case X86::PT2RPNTLVWZ1RSV: - case X86::PT2RPNTLVWZ1RST1V: { - MachineOperand &MO1 = MI->getOperand(1); - MachineOperand &MO2 = MI->getOperand(2); - MachineOperand &MO3 = MI->getOperand(3); - ShapeT Shape({&MO1, &MO2, &MO1, &MO3}, MRI); - VRM->assignVirt2Shape(VirtReg, Shape); - return Shape; - } - } -} - -static bool canHintShape(ShapeT &PhysShape, ShapeT &VirtShape) { - unsigned PhysShapeNum = PhysShape.getShapeNum(); - unsigned VirtShapeNum = VirtShape.getShapeNum(); - - if (PhysShapeNum < VirtShapeNum) - return false; - - if (PhysShapeNum == VirtShapeNum) { - if (PhysShapeNum == 1) - return PhysShape == VirtShape; - - for (unsigned I = 0; I < PhysShapeNum; I++) { - ShapeT PShape(PhysShape.getRow(I), PhysShape.getCol(I)); - ShapeT VShape(VirtShape.getRow(I), VirtShape.getCol(I)); - if (VShape != PShape) - return false; - } - return true; - } - - // Hint subreg of mult-tile reg to single tile reg. - if (VirtShapeNum == 1) { - for (unsigned I = 0; I < PhysShapeNum; I++) { - ShapeT PShape(PhysShape.getRow(I), PhysShape.getCol(I)); - if (VirtShape == PShape) - return true; - } } - - // Note: Currently we have no requirement for case of - // (VirtShapeNum > 1 and PhysShapeNum > VirtShapeNum) - return false; } bool X86RegisterInfo::getRegAllocationHints(Register VirtReg, @@ -1153,7 +1091,7 @@ bool X86RegisterInfo::getRegAllocationHints(Register VirtReg, if (!VRM) return BaseImplRetVal; - if (ID != X86::TILERegClassID && ID != X86::TILEPAIRRegClassID) { + if (ID != X86::TILERegClassID) { if (DisableRegAllocNDDHints || !ST.hasNDD() || !TRI.isGeneralPurposeRegisterClass(&RC)) return BaseImplRetVal; @@ -1204,7 +1142,7 @@ bool X86RegisterInfo::getRegAllocationHints(Register VirtReg, return; } ShapeT PhysShape = getTileShape(VReg, const_cast<VirtRegMap *>(VRM), MRI); - if (canHintShape(PhysShape, VirtShape)) + if (PhysShape == VirtShape) Hints.push_back(PhysReg); }; diff --git a/llvm/lib/Target/X86/X86RegisterInfo.td b/llvm/lib/Target/X86/X86RegisterInfo.td index 99b7910..692e42a 100644 --- a/llvm/lib/Target/X86/X86RegisterInfo.td +++ b/llvm/lib/Target/X86/X86RegisterInfo.td @@ -30,8 +30,6 @@ let Namespace = "X86" in { def sub_ymm : SubRegIndex<256>; def sub_mask_0 : SubRegIndex<-1>; def sub_mask_1 : SubRegIndex<-1, -1>; - def sub_t0 : SubRegIndex<8192>; - def sub_t1 : SubRegIndex<8192, 8192>; } //===----------------------------------------------------------------------===// @@ -432,10 +430,6 @@ def TMM4: X86Reg<"tmm4", 4>; def TMM5: X86Reg<"tmm5", 5>; def TMM6: X86Reg<"tmm6", 6>; def TMM7: X86Reg<"tmm7", 7>; -// TMM register pairs -def TPAIRS : RegisterTuples<[sub_t0, sub_t1], - [(add TMM0, TMM2, TMM4, TMM6), - (add TMM1, TMM3, TMM5, TMM7)]>; } // Floating point stack registers. These don't map one-to-one to the FP @@ -862,9 +856,6 @@ def VK64WM : RegisterClass<"X86", [v64i1], 64, (add VK32WM)> {let Size = 64;} let CopyCost = -1 in // Don't allow copying of tile registers def TILE : RegisterClass<"X86", [x86amx], 8192, (sequence "TMM%u", 0, 7)> {let Size = 8192;} -// Need check alignment 3rd operand size=1024*2*8 -let isAllocatable = 1 in -def TILEPAIR : RegisterClass<"X86", [untyped], 512, (add TPAIRS)> {let Size = 16384;} //===----------------------------------------------------------------------===// // Register categories. diff --git a/llvm/lib/Target/X86/X86TileConfig.cpp b/llvm/lib/Target/X86/X86TileConfig.cpp index 17a44dd..09ef8fb 100644 --- a/llvm/lib/Target/X86/X86TileConfig.cpp +++ b/llvm/lib/Target/X86/X86TileConfig.cpp @@ -74,63 +74,6 @@ INITIALIZE_PASS_DEPENDENCY(VirtRegMapWrapperLegacy) INITIALIZE_PASS_END(X86TileConfig, DEBUG_TYPE, "Tile Register Configure", false, false) -unsigned getAMXRegNum(MachineRegisterInfo *MRI, Register Reg) { - if (Reg.isVirtual()) { - unsigned RegClassID = MRI->getRegClass(Reg)->getID(); - if (RegClassID == X86::TILERegClassID) - return 1; - if (RegClassID == X86::TILEPAIRRegClassID) - return 2; - } else { - if (Reg >= X86::TMM0 && Reg <= X86::TMM7) - return 1; - if (Reg >= X86::TMM0_TMM1 && Reg <= X86::TMM6_TMM7) - return 2; - } - return 0; -} - -static void collectVirtRegShapes(MachineRegisterInfo *MRI, VirtRegMap &VRM, - Register VirtReg, - SmallVector<ShapeT, 8> &Phys2Shapes) { - unsigned Num = getAMXRegNum(MRI, VirtReg); - MCRegister PhysReg = VRM.getPhys(VirtReg); - if (!PhysReg) - return; - - if (Num == 1) { - unsigned Index = PhysReg - X86::TMM0; - if (!Phys2Shapes[Index].isValid()) { - ShapeT Shape = VRM.getShape(VirtReg); - Phys2Shapes[Index] = std::move(Shape); - return; - } - } - // Split tile pair shape info to 2 single tile shape info. e.g: - // Put TMM0_TMM1's Shape to TMM0's shape + TMM1's Shape in Phys2Shapes. - if (Num == 2) { - unsigned Index0 = (PhysReg - X86::TMM0_TMM1) * 2; - unsigned Index1 = (PhysReg - X86::TMM0_TMM1) * 2 + 1; - - ShapeT Shape = VRM.getShape(VirtReg); - assert(Shape.getShapeNum() == 2 && "Unexpected shape number!"); - - if (!Phys2Shapes[Index0].isValid()) { - ShapeT Shape0(Shape.getRow(0), Shape.getCol(0), MRI); - Phys2Shapes[Index0] = std::move(Shape0); - } - - if (!Phys2Shapes[Index1].isValid()) { - ShapeT Shape1(Shape.getRow(1), Shape.getCol(1), MRI); - Phys2Shapes[Index1] = std::move(Shape1); - } - } -} - -static bool isAMXRegClass(MachineRegisterInfo *MRI, Register Reg) { - return getAMXRegNum(MRI, Reg) > 0; -} - bool X86TileConfig::runOnMachineFunction(MachineFunction &MF) { X86MachineFunctionInfo *X86FI = MF.getInfo<X86MachineFunctionInfo>(); // Early exit in the common case of non-AMX code. @@ -138,7 +81,7 @@ bool X86TileConfig::runOnMachineFunction(MachineFunction &MF) { return false; const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>(); - const TargetRegisterInfo *TRI = ST.getRegisterInfo(); + const X86RegisterInfo *TRI = ST.getRegisterInfo(); const TargetInstrInfo *TII = ST.getInstrInfo(); MachineRegisterInfo &MRI = MF.getRegInfo(); LiveIntervals &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS(); @@ -176,24 +119,29 @@ bool X86TileConfig::runOnMachineFunction(MachineFunction &MF) { assert(ConstMI && "Cannot find an insertion point"); unsigned AMXRegNum = TRI->getRegClass(X86::TILERegClassID)->getNumRegs(); - SmallVector<ShapeT, 8> Phys2Shapes(AMXRegNum, ShapeT()); + SmallVector<Register, 8> Phys2Virt(AMXRegNum, 0); for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { Register VirtReg = Register::index2VirtReg(I); if (MRI.reg_nodbg_empty(VirtReg)) continue; - if (!isAMXRegClass(&MRI, VirtReg)) + if (!TRI->isTileRegisterClass(MRI.getRegClass(VirtReg))) + continue; + MCRegister PhysReg = VRM.getPhys(VirtReg); + if (!PhysReg) continue; - collectVirtRegShapes(&MRI, VRM, VirtReg, Phys2Shapes); + unsigned Index = PhysReg - X86::TMM0; + if (!Phys2Virt[Index]) + Phys2Virt[Index] = VirtReg; } // Fill in the shape of each tile physical register. for (unsigned I = 0; I < AMXRegNum; ++I) { - ShapeT Shape = Phys2Shapes[I]; - if (!Shape.isValid()) + if (!Phys2Virt[I]) continue; DebugLoc DL; bool IsRow = true; MachineInstr *NewMI = nullptr; + ShapeT Shape = VRM.getShape(Phys2Virt[I]); for (auto &R : {Shape.getRow()->getReg(), Shape.getCol()->getReg()}) { // Here is the data format for the tile config. // 0 palette @@ -222,14 +170,7 @@ bool X86TileConfig::runOnMachineFunction(MachineFunction &MF) { "Cannot initialize with different shapes"); continue; } - if (DefMI.getOperand(1).isImm()) { - Imm = DefMI.getOperand(1).getImm(); - } else { - assert(DefMI.getOpcode() == X86::MOV32r0 && - "The opcode is assumed to be MOV32r0 if the operand is not " - "immediate."); - Imm = 0; - } + Imm = DefMI.getOperand(1).getImm(); NewMI = addFrameReference( BuildMI(MF.front(), ++ConstMI->getIterator(), DL, |
