diff options
Diffstat (limited to 'llvm/lib/Target/X86')
20 files changed, 265 insertions, 1151 deletions
| diff --git a/llvm/lib/Target/X86/AsmParser/X86Operand.h b/llvm/lib/Target/X86/AsmParser/X86Operand.h index 89ac53e..a922725 100644 --- a/llvm/lib/Target/X86/AsmParser/X86Operand.h +++ b/llvm/lib/Target/X86/AsmParser/X86Operand.h @@ -620,37 +620,6 @@ struct X86Operand final : public MCParsedAsmOperand {      Inst.addOperand(MCOperand::createReg(Reg));    } -  bool isTILEPair() const { -    return Kind == Register && -           X86MCRegisterClasses[X86::TILERegClassID].contains(getReg()); -  } - -  void addTILEPairOperands(MCInst &Inst, unsigned N) const { -    assert(N == 1 && "Invalid number of operands!"); -    MCRegister Reg = getReg(); -    switch (Reg.id()) { -    default: -      llvm_unreachable("Invalid tile register!"); -    case X86::TMM0: -    case X86::TMM1: -      Reg = X86::TMM0_TMM1; -      break; -    case X86::TMM2: -    case X86::TMM3: -      Reg = X86::TMM2_TMM3; -      break; -    case X86::TMM4: -    case X86::TMM5: -      Reg = X86::TMM4_TMM5; -      break; -    case X86::TMM6: -    case X86::TMM7: -      Reg = X86::TMM6_TMM7; -      break; -    } -    Inst.addOperand(MCOperand::createReg(Reg)); -  } -    void addMemOperands(MCInst &Inst, unsigned N) const {      assert((N == 5) && "Invalid number of operands!");      if (getMemBaseReg()) diff --git a/llvm/lib/Target/X86/Disassembler/X86Disassembler.cpp b/llvm/lib/Target/X86/Disassembler/X86Disassembler.cpp index 4927b45..7d2b5eb 100644 --- a/llvm/lib/Target/X86/Disassembler/X86Disassembler.cpp +++ b/llvm/lib/Target/X86/Disassembler/X86Disassembler.cpp @@ -810,10 +810,6 @@ static int readModRM(struct InternalInstruction *insn) {        if (index > 7)                                                           \          *valid = 0;                                                            \        return prefix##_TMM0 + index;                                            \ -    case TYPE_TMM_PAIR:                                                        \ -      if (index > 7)                                                           \ -        *valid = 0;                                                            \ -      return prefix##_TMM0_TMM1 + (index / 2);                                 \      case TYPE_VK:                                                              \        index &= 0xf;                                                            \        if (index > 7)                                                           \ @@ -2323,7 +2319,6 @@ static bool translateRM(MCInst &mcInst, const OperandSpecifier &operand,    case TYPE_YMM:    case TYPE_ZMM:    case TYPE_TMM: -  case TYPE_TMM_PAIR:    case TYPE_VK_PAIR:    case TYPE_VK:    case TYPE_DEBUGREG: diff --git a/llvm/lib/Target/X86/Disassembler/X86DisassemblerDecoder.h b/llvm/lib/Target/X86/Disassembler/X86DisassemblerDecoder.h index dc9af2c..b0aa70b 100644 --- a/llvm/lib/Target/X86/Disassembler/X86DisassemblerDecoder.h +++ b/llvm/lib/Target/X86/Disassembler/X86DisassemblerDecoder.h @@ -535,12 +535,6 @@ namespace X86Disassembler {    ENTRY(TMM6)                                                                  \    ENTRY(TMM7) -#define REGS_TMM_PAIRS                                                         \ -  ENTRY(TMM0_TMM1)                                                             \ -  ENTRY(TMM2_TMM3)                                                             \ -  ENTRY(TMM4_TMM5)                                                             \ -  ENTRY(TMM6_TMM7) -  #define ALL_EA_BASES                                                           \    EA_BASES_16BIT                                                               \    EA_BASES_32BIT                                                               \ @@ -565,7 +559,6 @@ namespace X86Disassembler {    REGS_DEBUG                                                                   \    REGS_CONTROL                                                                 \    REGS_TMM                                                                     \ -  REGS_TMM_PAIRS                                                               \    ENTRY(RIP)  /// All possible values of the base field for effective-address diff --git a/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.cpp b/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.cpp index 1c5f166..759d95e 100644 --- a/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.cpp +++ b/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.cpp @@ -467,22 +467,3 @@ void X86InstPrinterCommon::printVKPair(const MCInst *MI, unsigned OpNo,    }    llvm_unreachable("Unknown mask pair register name");  } - -void X86InstPrinterCommon::printTILEPair(const MCInst *MI, unsigned OpNo, -                                         raw_ostream &OS) { -  switch (MI->getOperand(OpNo).getReg()) { -  case X86::TMM0_TMM1: -    printRegName(OS, X86::TMM0); -    return; -  case X86::TMM2_TMM3: -    printRegName(OS, X86::TMM2); -    return; -  case X86::TMM4_TMM5: -    printRegName(OS, X86::TMM4); -    return; -  case X86::TMM6_TMM7: -    printRegName(OS, X86::TMM6); -    return; -  } -  llvm_unreachable("Unknown mask pair register name"); -} diff --git a/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.h b/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.h index 2c9467c..cb55f2f 100644 --- a/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.h +++ b/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.h @@ -40,7 +40,6 @@ protected:                        const MCSubtargetInfo &STI);    void printOptionalSegReg(const MCInst *MI, unsigned OpNo, raw_ostream &O);    void printVKPair(const MCInst *MI, unsigned OpNo, raw_ostream &OS); -  void printTILEPair(const MCInst *MI, unsigned OpNo, raw_ostream &OS);  };  } // end namespace llvm diff --git a/llvm/lib/Target/X86/X86.td b/llvm/lib/Target/X86/X86.td index a1fd366..9e291a6 100644 --- a/llvm/lib/Target/X86/X86.td +++ b/llvm/lib/Target/X86/X86.td @@ -274,9 +274,6 @@ def FeatureAMXFP8 : SubtargetFeature<"amx-fp8", "HasAMXFP8", "true",  def FeatureAMXMOVRS : SubtargetFeature<"amx-movrs", "HasAMXMOVRS", "true",                                         "Support AMX-MOVRS instructions",                                         [FeatureAMXTILE]>; -def FeatureAMXTRANSPOSE : SubtargetFeature<"amx-transpose", "HasAMXTRANSPOSE", "true", -                                           "Support AMX amx-transpose instructions", -                                           [FeatureAMXTILE]>;  def FeatureAMXAVX512 : SubtargetFeature<"amx-avx512",                                          "HasAMXAVX512", "true",                                          "Support AMX-AVX512 instructions", @@ -1177,8 +1174,7 @@ def ProcessorFeatures {                                                    FeatureAMXMOVRS,                                                    FeatureAMXAVX512,                                                    FeatureAMXFP8, -                                                  FeatureAMXTF32, -                                                  FeatureAMXTRANSPOSE]; +                                                  FeatureAMXTF32];    list<SubtargetFeature> DMRFeatures =      !listconcat(GNRDFeatures, DMRAdditionalFeatures); diff --git a/llvm/lib/Target/X86/X86ExpandPseudo.cpp b/llvm/lib/Target/X86/X86ExpandPseudo.cpp index 4a9b824..e3c44c0 100644 --- a/llvm/lib/Target/X86/X86ExpandPseudo.cpp +++ b/llvm/lib/Target/X86/X86ExpandPseudo.cpp @@ -649,149 +649,6 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB,      MI.setDesc(TII->get(Opc));      return true;    } -  // TILEPAIRLOAD is just for TILEPair spill, we don't have corresponding -  // AMX instruction to support it. So, split it to 2 load instructions: -  // "TILEPAIRLOAD TMM0:TMM1, Base, Scale, Index, Offset, Segment" --> -  // "TILELOAD TMM0, Base, Scale, Index, Offset, Segment" + -  // "TILELOAD TMM1, Base, Scale, Index, Offset + TMM_SIZE, Segment" -  case X86::PTILEPAIRLOAD: { -    int64_t Disp = MBBI->getOperand(1 + X86::AddrDisp).getImm(); -    Register TReg = MBBI->getOperand(0).getReg(); -    bool DstIsDead = MBBI->getOperand(0).isDead(); -    Register TReg0 = TRI->getSubReg(TReg, X86::sub_t0); -    Register TReg1 = TRI->getSubReg(TReg, X86::sub_t1); -    unsigned TmmSize = TRI->getRegSizeInBits(X86::TILERegClass) / 8; - -    MachineInstrBuilder MIBLo = -        BuildMI(MBB, MBBI, DL, TII->get(X86::TILELOADD)) -            .addReg(TReg0, RegState::Define | getDeadRegState(DstIsDead)); -    MachineInstrBuilder MIBHi = -        BuildMI(MBB, MBBI, DL, TII->get(X86::TILELOADD)) -            .addReg(TReg1, RegState::Define | getDeadRegState(DstIsDead)); - -    for (int i = 0; i < X86::AddrNumOperands; ++i) { -      MIBLo.add(MBBI->getOperand(1 + i)); -      if (i == X86::AddrDisp) -        MIBHi.addImm(Disp + TmmSize); -      else -        MIBHi.add(MBBI->getOperand(1 + i)); -    } - -    // Make sure the first stride reg used in first tileload is alive. -    MachineOperand &Stride = -        MIBLo.getInstr()->getOperand(1 + X86::AddrIndexReg); -    Stride.setIsKill(false); - -    // Split the memory operand, adjusting the offset and size for the halves. -    MachineMemOperand *OldMMO = MBBI->memoperands().front(); -    MachineFunction *MF = MBB.getParent(); -    MachineMemOperand *MMOLo = MF->getMachineMemOperand(OldMMO, 0, TmmSize); -    MachineMemOperand *MMOHi = -        MF->getMachineMemOperand(OldMMO, TmmSize, TmmSize); - -    MIBLo.setMemRefs(MMOLo); -    MIBHi.setMemRefs(MMOHi); - -    // Delete the pseudo. -    MBB.erase(MBBI); -    return true; -  } -  // Similar with TILEPAIRLOAD, TILEPAIRSTORE is just for TILEPair spill, no -  // corresponding AMX instruction to support it. So, split it too: -  // "TILEPAIRSTORE Base, Scale, Index, Offset, Segment, TMM0:TMM1" --> -  // "TILESTORE Base, Scale, Index, Offset, Segment, TMM0" + -  // "TILESTORE Base, Scale, Index, Offset + TMM_SIZE, Segment, TMM1" -  case X86::PTILEPAIRSTORE: { -    int64_t Disp = MBBI->getOperand(X86::AddrDisp).getImm(); -    Register TReg = MBBI->getOperand(X86::AddrNumOperands).getReg(); -    bool SrcIsKill = MBBI->getOperand(X86::AddrNumOperands).isKill(); -    Register TReg0 = TRI->getSubReg(TReg, X86::sub_t0); -    Register TReg1 = TRI->getSubReg(TReg, X86::sub_t1); -    unsigned TmmSize = TRI->getRegSizeInBits(X86::TILERegClass) / 8; - -    MachineInstrBuilder MIBLo = -        BuildMI(MBB, MBBI, DL, TII->get(X86::TILESTORED)); -    MachineInstrBuilder MIBHi = -        BuildMI(MBB, MBBI, DL, TII->get(X86::TILESTORED)); - -    for (int i = 0; i < X86::AddrNumOperands; ++i) { -      MIBLo.add(MBBI->getOperand(i)); -      if (i == X86::AddrDisp) -        MIBHi.addImm(Disp + TmmSize); -      else -        MIBHi.add(MBBI->getOperand(i)); -    } -    MIBLo.addReg(TReg0, getKillRegState(SrcIsKill)); -    MIBHi.addReg(TReg1, getKillRegState(SrcIsKill)); - -    // Make sure the first stride reg used in first tilestore is alive. -    MachineOperand &Stride = MIBLo.getInstr()->getOperand(X86::AddrIndexReg); -    Stride.setIsKill(false); - -    // Split the memory operand, adjusting the offset and size for the halves. -    MachineMemOperand *OldMMO = MBBI->memoperands().front(); -    MachineFunction *MF = MBB.getParent(); -    MachineMemOperand *MMOLo = MF->getMachineMemOperand(OldMMO, 0, TmmSize); -    MachineMemOperand *MMOHi = -        MF->getMachineMemOperand(OldMMO, TmmSize, TmmSize); - -    MIBLo.setMemRefs(MMOLo); -    MIBHi.setMemRefs(MMOHi); - -    // Delete the pseudo. -    MBB.erase(MBBI); -    return true; -  } -  case X86::PT2RPNTLVWZ0V: -  case X86::PT2RPNTLVWZ0T1V: -  case X86::PT2RPNTLVWZ1V: -  case X86::PT2RPNTLVWZ1T1V: -  case X86::PT2RPNTLVWZ0RSV: -  case X86::PT2RPNTLVWZ0RST1V: -  case X86::PT2RPNTLVWZ1RSV: -  case X86::PT2RPNTLVWZ1RST1V: { -    for (unsigned i = 3; i > 0; --i) -      MI.removeOperand(i); -    unsigned Opc; -    switch (Opcode) { -    case X86::PT2RPNTLVWZ0V: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0); -      break; -    case X86::PT2RPNTLVWZ0T1V: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0T1); -      break; -    case X86::PT2RPNTLVWZ1V: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1); -      break; -    case X86::PT2RPNTLVWZ1T1V: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1T1); -      break; -    case X86::PT2RPNTLVWZ0RSV: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0RS); -      break; -    case X86::PT2RPNTLVWZ0RST1V: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0RST1); -      break; -    case X86::PT2RPNTLVWZ1RSV: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1RS); -      break; -    case X86::PT2RPNTLVWZ1RST1V: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1RST1); -      break; -    default: -      llvm_unreachable("Impossible Opcode!"); -    } -    MI.setDesc(TII->get(Opc)); -    return true; -  } -  case X86::PTTRANSPOSEDV: -  case X86::PTCONJTFP16V: { -    for (int i = 2; i > 0; --i) -      MI.removeOperand(i); -    MI.setDesc(TII->get(Opcode == X86::PTTRANSPOSEDV ? X86::TTRANSPOSED -                                                     : X86::TCONJTFP16)); -    return true; -  }    case X86::PTCMMIMFP16PSV:    case X86::PTCMMRLFP16PSV:    case X86::PTDPBSSDV: @@ -800,13 +657,7 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB,    case X86::PTDPBUUDV:    case X86::PTDPBF16PSV:    case X86::PTDPFP16PSV: -  case X86::PTTDPBF16PSV: -  case X86::PTTDPFP16PSV: -  case X86::PTTCMMIMFP16PSV: -  case X86::PTTCMMRLFP16PSV: -  case X86::PTCONJTCMMIMFP16PSV:    case X86::PTMMULTF32PSV: -  case X86::PTTMMULTF32PSV:    case X86::PTDPBF8PSV:    case X86::PTDPBHF8PSV:    case X86::PTDPHBF8PSV: @@ -816,6 +667,7 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB,        MI.removeOperand(i);      unsigned Opc;      switch (Opcode) { +      // clang-format off      case X86::PTCMMIMFP16PSV:  Opc = X86::TCMMIMFP16PS; break;      case X86::PTCMMRLFP16PSV:  Opc = X86::TCMMRLFP16PS; break;      case X86::PTDPBSSDV:   Opc = X86::TDPBSSD; break; @@ -824,40 +676,12 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB,      case X86::PTDPBUUDV:   Opc = X86::TDPBUUD; break;      case X86::PTDPBF16PSV: Opc = X86::TDPBF16PS; break;      case X86::PTDPFP16PSV: Opc = X86::TDPFP16PS; break; -    case X86::PTTDPBF16PSV: -      Opc = X86::TTDPBF16PS; -      break; -    case X86::PTTDPFP16PSV: -      Opc = X86::TTDPFP16PS; -      break; -    case X86::PTTCMMIMFP16PSV: -      Opc = X86::TTCMMIMFP16PS; -      break; -    case X86::PTTCMMRLFP16PSV: -      Opc = X86::TTCMMRLFP16PS; -      break; -    case X86::PTCONJTCMMIMFP16PSV: -      Opc = X86::TCONJTCMMIMFP16PS; -      break; -    case X86::PTMMULTF32PSV: -      Opc = X86::TMMULTF32PS; -      break; -    case X86::PTTMMULTF32PSV: -      Opc = X86::TTMMULTF32PS; -      break; -    case X86::PTDPBF8PSV: -      Opc = X86::TDPBF8PS; -      break; -    case X86::PTDPBHF8PSV: -      Opc = X86::TDPBHF8PS; -      break; -    case X86::PTDPHBF8PSV: -      Opc = X86::TDPHBF8PS; -      break; -    case X86::PTDPHF8PSV: -      Opc = X86::TDPHF8PS; -      break; - +    case X86::PTMMULTF32PSV: Opc = X86::TMMULTF32PS; break; +    case X86::PTDPBF8PSV: Opc = X86::TDPBF8PS; break; +    case X86::PTDPBHF8PSV: Opc = X86::TDPBHF8PS; break; +    case X86::PTDPHBF8PSV: Opc = X86::TDPHBF8PS; break; +    case X86::PTDPHF8PSV: Opc = X86::TDPHF8PS; break; +    // clang-format on      default:        llvm_unreachable("Unexpected Opcode");      } diff --git a/llvm/lib/Target/X86/X86FastPreTileConfig.cpp b/llvm/lib/Target/X86/X86FastPreTileConfig.cpp index 787b71d..06f729a 100644 --- a/llvm/lib/Target/X86/X86FastPreTileConfig.cpp +++ b/llvm/lib/Target/X86/X86FastPreTileConfig.cpp @@ -267,24 +267,16 @@ void X86FastPreTileConfig::reload(MachineBasicBlock::iterator UseMI,                      << printReg(TileReg, TRI) << '\n');  } -static unsigned getTileDefNum(MachineRegisterInfo *MRI, Register Reg) { -  if (Reg.isVirtual()) { -    unsigned RegClassID = MRI->getRegClass(Reg)->getID(); -    if (RegClassID == X86::TILERegClassID) -      return 1; -    if (RegClassID == X86::TILEPAIRRegClassID) -      return 2; -  } else { -    if (Reg >= X86::TMM0 && Reg <= X86::TMM7) -      return 1; -    if (Reg >= X86::TMM0_TMM1 && Reg <= X86::TMM6_TMM7) -      return 2; +static bool isTileRegister(MachineRegisterInfo *MRI, Register Reg) { +  if (Reg.isVirtual() && +      (MRI->getRegClass(Reg)->getID() == X86::TILERegClassID)) { +    return true;    } -  return 0; -} -static bool isTileRegister(MachineRegisterInfo *MRI, Register VirtReg) { -  return getTileDefNum(MRI, VirtReg) > 0; +  if (Reg >= X86::TMM0 && Reg <= X86::TMM7) +    return true; + +  return false;  }  static bool isTileDef(MachineRegisterInfo *MRI, MachineInstr &MI) { @@ -296,7 +288,7 @@ static bool isTileDef(MachineRegisterInfo *MRI, MachineInstr &MI) {    if (!MO.isReg())      return false; -  return getTileDefNum(MRI, MO.getReg()) > 0; +  return isTileRegister(MRI, MO.getReg());  }  static ShapeT getShape(MachineRegisterInfo *MRI, Register TileReg) { @@ -636,19 +628,7 @@ bool X86FastPreTileConfig::configBasicBlock(MachineBasicBlock &MBB) {        else if (dominates(MBB, LastShapeMI, ColMI))          LastShapeMI = ColMI;      } -    unsigned TileDefNum = getTileDefNum(MRI, MI.getOperand(0).getReg()); -    if (TileDefNum > 1) { -      for (unsigned I = 1; I < TileDefNum; I++) { -        MachineOperand *ColxMO = &MI.getOperand(2 + I); -        MachineInstr *ColxMI = MRI->getVRegDef(ColxMO->getReg()); -        if (ColxMI->getParent() == &MBB) { -          if (!LastShapeMI) -            LastShapeMI = ColxMI; -          else if (dominates(MBB, LastShapeMI, ColxMI)) -            LastShapeMI = ColxMI; -        } -      } -    } +      // If there is user live out of the tilecfg, spill it and reload in      // before the user.      Register TileReg = MI.getOperand(0).getReg(); diff --git a/llvm/lib/Target/X86/X86FastTileConfig.cpp b/llvm/lib/Target/X86/X86FastTileConfig.cpp index 11d331b..d86ae36 100644 --- a/llvm/lib/Target/X86/X86FastTileConfig.cpp +++ b/llvm/lib/Target/X86/X86FastTileConfig.cpp @@ -77,14 +77,14 @@ INITIALIZE_PASS_BEGIN(X86FastTileConfig, DEBUG_TYPE,  INITIALIZE_PASS_END(X86FastTileConfig, DEBUG_TYPE,                      "Fast Tile Register Configure", false, false) -static unsigned getNumDefTiles(MachineRegisterInfo *MRI, MachineInstr &MI) { +static bool isTileDef(MachineRegisterInfo *MRI, MachineInstr &MI) {    // There is no phi instruction after register allocation.    assert(MI.isPHI() == false);    // The instruction must have 3 operands: tile def, row, col.    // It should be AMX pseudo instruction that have shape operand.    if (MI.isDebugInstr() || MI.isCopy() || MI.getNumOperands() < 3 ||        !MI.isPseudo()) -    return 0; +    return false;    MachineOperand &MO = MI.getOperand(0);    if (MO.isReg()) { @@ -93,24 +93,18 @@ static unsigned getNumDefTiles(MachineRegisterInfo *MRI, MachineInstr &MI) {      // register is not rewritten yet.      if (Reg.isVirtual()) {        if (MRI->getRegClass(Reg)->getID() == X86::TILERegClassID) -        return 1; -      if (MRI->getRegClass(Reg)->getID() == X86::TILEPAIRRegClassID) -        return 2; +        return true;      }      if (Reg >= X86::TMM0 && Reg <= X86::TMM7) -      return 1; -    if (Reg >= X86::TMM0_TMM1 && Reg <= X86::TMM6_TMM7) -      return 2; +      return true;    } -  return 0; +  return false;  }  static unsigned getTMMIndex(Register Reg) {    if (Reg >= X86::TMM0 && Reg <= X86::TMM7)      return Reg - X86::TMM0; -  if (Reg >= X86::TMM0_TMM1 && Reg <= X86::TMM6_TMM7) -    return (Reg - X86::TMM0_TMM1) * 2;    llvm_unreachable("Invalid Tmm Reg!");  } @@ -120,17 +114,14 @@ bool X86FastTileConfig::configBasicBlock(MachineBasicBlock &MBB) {    bool Change = false;    SmallVector<std::pair<unsigned, ShapeT>, 6> ShapeInfos;    for (MachineInstr &MI : reverse(MBB)) { -    unsigned DefNum = getNumDefTiles(MRI, MI); -    if (DefNum == 0 && MI.getOpcode() != X86::PLDTILECFGV) +    if (!isTileDef(MRI, MI) && MI.getOpcode() != X86::PLDTILECFGV)        continue;      // AMX instructions that define tile register.      if (MI.getOpcode() != X86::PLDTILECFGV) {        MachineOperand &Row = MI.getOperand(1);        unsigned TMMIdx = getTMMIndex(MI.getOperand(0).getReg()); -      for (unsigned I = 0; I < DefNum; I++) { -        MachineOperand &Col = MI.getOperand(2 + I); -        ShapeInfos.push_back({TMMIdx + I, ShapeT(&Row, &Col)}); -      } +      MachineOperand &Col = MI.getOperand(2); +      ShapeInfos.push_back({TMMIdx, ShapeT(&Row, &Col)});      } else { // PLDTILECFGV        // Rewrite the shape information to memory. Stack slot should have        // been initialized to zero in pre config. diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp index 4393f6e..d4418c8 100644 --- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp +++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp @@ -337,23 +337,8 @@ namespace {      // lowering but before ISEL.      bool isAMXSDNode(SDNode *N) const {        // Check if N is AMX SDNode: -      // 1. check specific opcode since these carry MVT::Untyped instead of -      // x86amx_type; -      // 2. check result type; -      // 3. check operand type; -      switch (N->getOpcode()) { -      default: -        break; -      case X86::PT2RPNTLVWZ0V: -      case X86::PT2RPNTLVWZ0T1V: -      case X86::PT2RPNTLVWZ1V: -      case X86::PT2RPNTLVWZ1T1V: -      case X86::PT2RPNTLVWZ0RSV: -      case X86::PT2RPNTLVWZ0RST1V: -      case X86::PT2RPNTLVWZ1RSV: -      case X86::PT2RPNTLVWZ1RST1V: -        return true; -      } +      // 1. check result type; +      // 2. check operand type;        for (unsigned Idx = 0, E = N->getNumValues(); Idx != E; ++Idx) {          if (N->getValueType(Idx) == MVT::x86amx)            return true; @@ -5398,65 +5383,6 @@ void X86DAGToDAGISel::Select(SDNode *Node) {        ReplaceNode(Node, CNode);        return;      } -    case Intrinsic::x86_t2rpntlvwz0rs: -    case Intrinsic::x86_t2rpntlvwz0rst1: -    case Intrinsic::x86_t2rpntlvwz1rs: -    case Intrinsic::x86_t2rpntlvwz1rst1: -      if (!Subtarget->hasAMXMOVRS()) -        break; -      [[fallthrough]]; -    case Intrinsic::x86_t2rpntlvwz0: -    case Intrinsic::x86_t2rpntlvwz0t1: -    case Intrinsic::x86_t2rpntlvwz1: -    case Intrinsic::x86_t2rpntlvwz1t1: { -      if (!Subtarget->hasAMXTRANSPOSE()) -        break; -      auto *MFI = -          CurDAG->getMachineFunction().getInfo<X86MachineFunctionInfo>(); -      MFI->setAMXProgModel(AMXProgModelEnum::DirectReg); -      unsigned Opc; -      switch (IntNo) { -      default: -        llvm_unreachable("Unexpected intrinsic!"); -      case Intrinsic::x86_t2rpntlvwz0: -        Opc = X86::PT2RPNTLVWZ0; -        break; -      case Intrinsic::x86_t2rpntlvwz0t1: -        Opc = X86::PT2RPNTLVWZ0T1; -        break; -      case Intrinsic::x86_t2rpntlvwz1: -        Opc = X86::PT2RPNTLVWZ1; -        break; -      case Intrinsic::x86_t2rpntlvwz1t1: -        Opc = X86::PT2RPNTLVWZ1T1; -        break; -      case Intrinsic::x86_t2rpntlvwz0rs: -        Opc = X86::PT2RPNTLVWZ0RS; -        break; -      case Intrinsic::x86_t2rpntlvwz0rst1: -        Opc = X86::PT2RPNTLVWZ0RST1; -        break; -      case Intrinsic::x86_t2rpntlvwz1rs: -        Opc = X86::PT2RPNTLVWZ1RS; -        break; -      case Intrinsic::x86_t2rpntlvwz1rst1: -        Opc = X86::PT2RPNTLVWZ1RST1; -        break; -      } -      // FIXME: Match displacement and scale. -      unsigned TIndex = Node->getConstantOperandVal(2); -      SDValue TReg = getI8Imm(TIndex, dl); -      SDValue Base = Node->getOperand(3); -      SDValue Scale = getI8Imm(1, dl); -      SDValue Index = Node->getOperand(4); -      SDValue Disp = CurDAG->getTargetConstant(0, dl, MVT::i32); -      SDValue Segment = CurDAG->getRegister(0, MVT::i16); -      SDValue Chain = Node->getOperand(0); -      SDValue Ops[] = {TReg, Base, Scale, Index, Disp, Segment, Chain}; -      MachineSDNode *CNode = CurDAG->getMachineNode(Opc, dl, MVT::Other, Ops); -      ReplaceNode(Node, CNode); -      return; -    }      }      break;    } diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 5785440..007074c 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -12213,7 +12213,7 @@ static int matchShuffleAsShift(MVT &ShiftVT, unsigned &Opcode,      MVT ShiftSVT = MVT::getIntegerVT(ScalarSizeInBits * Scale);      ShiftVT = ByteShift ? MVT::getVectorVT(MVT::i8, SizeInBits / 8)                          : MVT::getVectorVT(ShiftSVT, Size / Scale); -    return (int)ShiftAmt; +    return ShiftAmt;    };    // SSE/AVX supports logical shifts up to 64-bit integers - so we can just @@ -27946,67 +27946,6 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget,        return DAG.getNode(ISD::MERGE_VALUES, dl, Op->getVTList(), SetCC,                           Operation.getValue(1));      } -    case Intrinsic::x86_t2rpntlvwz0rs_internal: -    case Intrinsic::x86_t2rpntlvwz0rst1_internal: -    case Intrinsic::x86_t2rpntlvwz1rs_internal: -    case Intrinsic::x86_t2rpntlvwz1rst1_internal: -    case Intrinsic::x86_t2rpntlvwz0_internal: -    case Intrinsic::x86_t2rpntlvwz0t1_internal: -    case Intrinsic::x86_t2rpntlvwz1_internal: -    case Intrinsic::x86_t2rpntlvwz1t1_internal: { -      auto *X86MFI = DAG.getMachineFunction().getInfo<X86MachineFunctionInfo>(); -      X86MFI->setAMXProgModel(AMXProgModelEnum::ManagedRA); -      unsigned IntNo = Op.getConstantOperandVal(1); -      unsigned Opc = 0; -      switch (IntNo) { -      default: -        llvm_unreachable("Unexpected intrinsic!"); -      case Intrinsic::x86_t2rpntlvwz0_internal: -        Opc = X86::PT2RPNTLVWZ0V; -        break; -      case Intrinsic::x86_t2rpntlvwz0t1_internal: -        Opc = X86::PT2RPNTLVWZ0T1V; -        break; -      case Intrinsic::x86_t2rpntlvwz1_internal: -        Opc = X86::PT2RPNTLVWZ1V; -        break; -      case Intrinsic::x86_t2rpntlvwz1t1_internal: -        Opc = X86::PT2RPNTLVWZ1T1V; -        break; -      case Intrinsic::x86_t2rpntlvwz0rs_internal: -        Opc = X86::PT2RPNTLVWZ0RSV; -        break; -      case Intrinsic::x86_t2rpntlvwz0rst1_internal: -        Opc = X86::PT2RPNTLVWZ0RST1V; -        break; -      case Intrinsic::x86_t2rpntlvwz1rs_internal: -        Opc = X86::PT2RPNTLVWZ1RSV; -        break; -      case Intrinsic::x86_t2rpntlvwz1rst1_internal: -        Opc = X86::PT2RPNTLVWZ1RST1V; -        break; -      } - -      SDLoc DL(Op); -      SDVTList VTs = DAG.getVTList(MVT::Untyped, MVT::Other); - -      SDValue Ops[] = {Op.getOperand(2),                       // Row -                       Op.getOperand(3),                       // Col0 -                       Op.getOperand(4),                       // Col1 -                       Op.getOperand(5),                       // Base -                       DAG.getTargetConstant(1, DL, MVT::i8),  // Scale -                       Op.getOperand(6),                       // Index -                       DAG.getTargetConstant(0, DL, MVT::i32), // Disp -                       DAG.getRegister(0, MVT::i16),           // Segment -                       Op.getOperand(0)};                      // Chain - -      MachineSDNode *Res = DAG.getMachineNode(Opc, DL, VTs, Ops); -      SDValue Res0 = DAG.getTargetExtractSubreg(X86::sub_t0, DL, MVT::x86amx, -                                                SDValue(Res, 0)); -      SDValue Res1 = DAG.getTargetExtractSubreg(X86::sub_t1, DL, MVT::x86amx, -                                                SDValue(Res, 0)); -      return DAG.getMergeValues({Res0, Res1, SDValue(Res, 1)}, DL); -    }      case Intrinsic::x86_atomic_bts_rm:      case Intrinsic::x86_atomic_btc_rm:      case Intrinsic::x86_atomic_btr_rm: { @@ -37745,10 +37684,6 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,      assert (Imm < 8 && "Illegal tmm index");      return X86::TMM0 + Imm;    }; -  auto TMMImmToTMMPair = [](unsigned Imm) { -    assert(Imm < 8 && "Illegal tmm pair index."); -    return X86::TMM0_TMM1 + Imm / 2; -  };    switch (MI.getOpcode()) {    default:      llvm_unreachable("Unexpected instr type to insert"); @@ -38129,53 +38064,25 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,    case X86::PTDPBHF8PS:    case X86::PTDPHBF8PS:    case X86::PTDPHF8PS: -  case X86::PTTDPBF16PS: -  case X86::PTTDPFP16PS: -  case X86::PTTCMMIMFP16PS: -  case X86::PTTCMMRLFP16PS: -  case X86::PTCONJTCMMIMFP16PS: -  case X86::PTMMULTF32PS: -  case X86::PTTMMULTF32PS: { +  case X86::PTMMULTF32PS: {      unsigned Opc;      switch (MI.getOpcode()) {      default: llvm_unreachable("illegal opcode!"); +      // clang-format off      case X86::PTDPBSSD: Opc = X86::TDPBSSD; break;      case X86::PTDPBSUD: Opc = X86::TDPBSUD; break;      case X86::PTDPBUSD: Opc = X86::TDPBUSD; break;      case X86::PTDPBUUD: Opc = X86::TDPBUUD; break;      case X86::PTDPBF16PS: Opc = X86::TDPBF16PS; break;      case X86::PTDPFP16PS: Opc = X86::TDPFP16PS; break; -    case X86::PTCMMIMFP16PS: -      Opc = X86::TCMMIMFP16PS; -      break; -    case X86::PTCMMRLFP16PS: -      Opc = X86::TCMMRLFP16PS; -      break; +    case X86::PTCMMIMFP16PS: Opc = X86::TCMMIMFP16PS; break; +    case X86::PTCMMRLFP16PS: Opc = X86::TCMMRLFP16PS; break;      case X86::PTDPBF8PS: Opc = X86::TDPBF8PS; break;      case X86::PTDPBHF8PS: Opc = X86::TDPBHF8PS; break;      case X86::PTDPHBF8PS: Opc = X86::TDPHBF8PS; break;      case X86::PTDPHF8PS: Opc = X86::TDPHF8PS; break; -    case X86::PTTDPBF16PS: -      Opc = X86::TTDPBF16PS; -      break; -    case X86::PTTDPFP16PS: -      Opc = X86::TTDPFP16PS; -      break; -    case X86::PTTCMMIMFP16PS: -      Opc = X86::TTCMMIMFP16PS; -      break; -    case X86::PTTCMMRLFP16PS: -      Opc = X86::TTCMMRLFP16PS; -      break; -    case X86::PTCONJTCMMIMFP16PS: -      Opc = X86::TCONJTCMMIMFP16PS; -      break; -    case X86::PTMMULTF32PS: -      Opc = X86::TMMULTF32PS; -      break; -    case X86::PTTMMULTF32PS: -      Opc = X86::TTMMULTF32PS; -      break; +    case X86::PTMMULTF32PS: Opc = X86::TMMULTF32PS; break; +      // clang-format on      }      MachineInstrBuilder MIB = BuildMI(*BB, MI, MIMD, TII->get(Opc)); @@ -38246,70 +38153,6 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,      MI.eraseFromParent(); // The pseudo is gone now.      return BB;    } -  case X86::PT2RPNTLVWZ0: -  case X86::PT2RPNTLVWZ0T1: -  case X86::PT2RPNTLVWZ1: -  case X86::PT2RPNTLVWZ1T1: -  case X86::PT2RPNTLVWZ0RS: -  case X86::PT2RPNTLVWZ0RST1: -  case X86::PT2RPNTLVWZ1RS: -  case X86::PT2RPNTLVWZ1RST1: { -    const DebugLoc &DL = MI.getDebugLoc(); -    unsigned Opc; -#define GET_EGPR_IF_ENABLED(OPC) (Subtarget.hasEGPR() ? OPC##_EVEX : OPC) -    switch (MI.getOpcode()) { -    default: -      llvm_unreachable("Unexpected instruction!"); -    case X86::PT2RPNTLVWZ0: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0); -      break; -    case X86::PT2RPNTLVWZ0T1: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0T1); -      break; -    case X86::PT2RPNTLVWZ1: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1); -      break; -    case X86::PT2RPNTLVWZ1T1: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1T1); -      break; -    case X86::PT2RPNTLVWZ0RS: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0RS); -      break; -    case X86::PT2RPNTLVWZ0RST1: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0RST1); -      break; -    case X86::PT2RPNTLVWZ1RS: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1RS); -      break; -    case X86::PT2RPNTLVWZ1RST1: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1RST1); -      break; -    } -#undef GET_EGPR_IF_ENABLED -    MachineInstrBuilder MIB = BuildMI(*BB, MI, DL, TII->get(Opc)); -    MIB.addReg(TMMImmToTMMPair(MI.getOperand(0).getImm()), RegState::Define); - -    MIB.add(MI.getOperand(1)); // base -    MIB.add(MI.getOperand(2)); // scale -    MIB.add(MI.getOperand(3)); // index -    MIB.add(MI.getOperand(4)); // displacement -    MIB.add(MI.getOperand(5)); // segment -    MI.eraseFromParent();      // The pseudo is gone now. -    return BB; -  } -  case X86::PTTRANSPOSED: -  case X86::PTCONJTFP16: { -    const DebugLoc &DL = MI.getDebugLoc(); -    unsigned Opc = MI.getOpcode() == X86::PTTRANSPOSED ? X86::TTRANSPOSED -                                                       : X86::TCONJTFP16; - -    MachineInstrBuilder MIB = BuildMI(*BB, MI, DL, TII->get(Opc)); -    MIB.addReg(TMMImmToTMMReg(MI.getOperand(0).getImm()), RegState::Define); -    MIB.addReg(TMMImmToTMMReg(MI.getOperand(1).getImm()), RegState::Undef); - -    MI.eraseFromParent(); // The pseudo is gone now. -    return BB; -  }    case X86::PTCVTROWPS2BF16Hrri:    case X86::PTCVTROWPS2BF16Lrri:    case X86::PTCVTROWPS2PHHrri: @@ -48778,15 +48621,19 @@ static SDValue combinePTESTCC(SDValue EFLAGS, X86::CondCode &CC,        SDValue BC0 = peekThroughBitcasts(Op0);        if (BC0.getOpcode() == X86ISD::PCMPEQ &&            ISD::isBuildVectorAllZeros(BC0.getOperand(1).getNode())) { -        SDLoc DL(EFLAGS);          CC = (CC == X86::COND_B ? X86::COND_E : X86::COND_NE); -        SDValue X = DAG.getBitcast(OpVT, BC0.getOperand(0)); -        return DAG.getNode(EFLAGS.getOpcode(), DL, VT, X, X); +        SDValue X = DAG.getBitcast(OpVT, DAG.getFreeze(BC0.getOperand(0))); +        return DAG.getNode(EFLAGS.getOpcode(), SDLoc(EFLAGS), VT, X, X);        }      }    }    if (CC == X86::COND_E || CC == X86::COND_NE) { +    // Canonicalize constant to RHS if we're just using ZF. +    if (Op0 != Op1 && DAG.isConstantIntBuildVectorOrConstantInt(Op0) && +        !DAG.isConstantIntBuildVectorOrConstantInt(Op1)) +      return DAG.getNode(EFLAGS.getOpcode(), SDLoc(EFLAGS), VT, Op1, Op0); +      // TESTZ(X,~Y) == TESTC(Y,X)      if (SDValue NotOp1 = IsNOT(Op1, DAG)) {        CC = (CC == X86::COND_E ? X86::COND_B : X86::COND_AE); @@ -48832,7 +48679,7 @@ static SDValue combinePTESTCC(SDValue EFLAGS, X86::CondCode &CC,                MVT FloatSVT = MVT::getFloatingPointVT(EltBits);                MVT FloatVT =                    MVT::getVectorVT(FloatSVT, OpVT.getSizeInBits() / EltBits); -              Res = DAG.getBitcast(FloatVT, Res); +              Res = DAG.getBitcast(FloatVT, DAG.getFreeze(Res));                return DAG.getNode(X86ISD::TESTP, SDLoc(EFLAGS), VT, Res, Res);              } else if (EltBits == 16) {                MVT MovmskVT = BCVT.is128BitVector() ? MVT::v16i8 : MVT::v32i8; @@ -48850,13 +48697,31 @@ static SDValue combinePTESTCC(SDValue EFLAGS, X86::CondCode &CC,        }      } -    // TESTZ(-1,X) == TESTZ(X,X) -    if (ISD::isBuildVectorAllOnes(Op0.getNode())) -      return DAG.getNode(EFLAGS.getOpcode(), SDLoc(EFLAGS), VT, Op1, Op1); -      // TESTZ(X,-1) == TESTZ(X,X) -    if (ISD::isBuildVectorAllOnes(Op1.getNode())) +    if (ISD::isBuildVectorAllOnes(Op1.getNode())) { +      Op0 = DAG.getFreeze(Op0);        return DAG.getNode(EFLAGS.getOpcode(), SDLoc(EFLAGS), VT, Op0, Op0); +    } + +    // Attempt to convert PTESTZ(X,SIGNMASK) -> VTESTPD/PSZ(X,X) on AVX targets. +    if (EFLAGS.getOpcode() == X86ISD::PTEST && Subtarget.hasAVX()) { +      KnownBits KnownOp1 = DAG.computeKnownBits(Op1); +      assert(KnownOp1.getBitWidth() == 64 && +             "Illegal PTEST vector element width"); +      if (KnownOp1.isConstant()) { +        const APInt &Mask = KnownOp1.getConstant(); +        if (Mask.isSignMask()) { +          MVT FpVT = MVT::getVectorVT(MVT::f64, OpVT.getSizeInBits() / 64); +          Op0 = DAG.getBitcast(FpVT, DAG.getFreeze(Op0)); +          return DAG.getNode(X86ISD::TESTP, SDLoc(EFLAGS), VT, Op0, Op0); +        } +        if (Mask.isSplat(32) && Mask.trunc(32).isSignMask()) { +          MVT FpVT = MVT::getVectorVT(MVT::f32, OpVT.getSizeInBits() / 32); +          Op0 = DAG.getBitcast(FpVT, DAG.getFreeze(Op0)); +          return DAG.getNode(X86ISD::TESTP, SDLoc(EFLAGS), VT, Op0, Op0); +        } +      } +    }      // TESTZ(OR(LO(X),HI(X)),OR(LO(Y),HI(Y))) -> TESTZ(X,Y)      // TODO: Add COND_NE handling? @@ -53479,6 +53344,105 @@ static SDValue combineMaskedStore(SDNode *N, SelectionDAG &DAG,    return SDValue();  } +// Look for a RMW operation that only touches one bit of a larger than legal +// type and fold it to a BTC/BTR/BTS or bit insertion pattern acting on a single +// i32 sub value. +static SDValue narrowBitOpRMW(StoreSDNode *St, const SDLoc &DL, +                              SelectionDAG &DAG, +                              const X86Subtarget &Subtarget) { +  using namespace SDPatternMatch; + +  // Only handle normal stores and its chain was a matching normal load. +  auto *Ld = dyn_cast<LoadSDNode>(St->getChain()); +  if (!ISD::isNormalStore(St) || !St->isSimple() || !Ld || +      !ISD::isNormalLoad(Ld) || !Ld->isSimple() || +      Ld->getBasePtr() != St->getBasePtr() || +      Ld->getOffset() != St->getOffset()) +    return SDValue(); + +  SDValue LoadVal(Ld, 0); +  SDValue StoredVal = St->getValue(); +  EVT VT = StoredVal.getValueType(); + +  // Only narrow larger than legal scalar integers. +  if (!VT.isScalarInteger() || +      VT.getSizeInBits() <= (Subtarget.is64Bit() ? 64 : 32)) +    return SDValue(); + +  // BTR: X & ~(1 << ShAmt) +  // BTS: X | (1 << ShAmt) +  // BTC: X ^ (1 << ShAmt) +  // +  // BitInsert: (X & ~(1 << ShAmt)) | (InsertBit << ShAmt) +  SDValue InsertBit, ShAmt; +  if (!StoredVal.hasOneUse() || +      !(sd_match(StoredVal, m_And(m_Specific(LoadVal), +                                  m_Not(m_Shl(m_One(), m_Value(ShAmt))))) || +        sd_match(StoredVal, +                 m_Or(m_Specific(LoadVal), m_Shl(m_One(), m_Value(ShAmt)))) || +        sd_match(StoredVal, +                 m_Xor(m_Specific(LoadVal), m_Shl(m_One(), m_Value(ShAmt)))) || +        sd_match(StoredVal, +                 m_Or(m_And(m_Specific(LoadVal), +                            m_Not(m_Shl(m_One(), m_Value(ShAmt)))), +                      m_Shl(m_Value(InsertBit), m_Deferred(ShAmt)))))) +    return SDValue(); + +  // Ensure the shift amount is in bounds. +  KnownBits KnownAmt = DAG.computeKnownBits(ShAmt); +  if (KnownAmt.getMaxValue().uge(VT.getSizeInBits())) +    return SDValue(); + +  // If we're inserting a bit then it must be the LSB. +  if (InsertBit) { +    KnownBits KnownInsert = DAG.computeKnownBits(InsertBit); +    if (KnownInsert.countMinLeadingZeros() < (VT.getSizeInBits() - 1)) +      return SDValue(); +  } + +  // Split the shift into an alignment shift that moves the active i32 block to +  // the bottom bits for truncation and a modulo shift that can act on the i32. +  EVT AmtVT = ShAmt.getValueType(); +  SDValue AlignAmt = DAG.getNode(ISD::AND, DL, AmtVT, ShAmt, +                                 DAG.getSignedConstant(-32LL, DL, AmtVT)); +  SDValue ModuloAmt = +      DAG.getNode(ISD::AND, DL, AmtVT, ShAmt, DAG.getConstant(31, DL, AmtVT)); +  ModuloAmt = DAG.getZExtOrTrunc(ModuloAmt, DL, MVT::i8); + +  // Compute the byte offset for the i32 block that is changed by the RMW. +  // combineTruncate will adjust the load for us in a similar way. +  EVT PtrVT = St->getBasePtr().getValueType(); +  SDValue PtrBitOfs = DAG.getZExtOrTrunc(AlignAmt, DL, PtrVT); +  SDValue PtrByteOfs = DAG.getNode(ISD::SRL, DL, PtrVT, PtrBitOfs, +                                   DAG.getShiftAmountConstant(3, PtrVT, DL)); +  SDValue NewPtr = DAG.getMemBasePlusOffset(St->getBasePtr(), PtrByteOfs, DL, +                                            SDNodeFlags::NoUnsignedWrap); + +  // Reconstruct the BTC/BTR/BTS pattern for the i32 block and store. +  SDValue X = DAG.getNode(ISD::SRL, DL, VT, LoadVal, AlignAmt); +  X = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, X); + +  SDValue Mask = DAG.getNode(ISD::SHL, DL, MVT::i32, +                             DAG.getConstant(1, DL, MVT::i32), ModuloAmt); + +  SDValue Res; +  if (InsertBit) { +    SDValue BitMask = +        DAG.getNode(ISD::SHL, DL, MVT::i32, +                    DAG.getZExtOrTrunc(InsertBit, DL, MVT::i32), ModuloAmt); +    Res = +        DAG.getNode(ISD::AND, DL, MVT::i32, X, DAG.getNOT(DL, Mask, MVT::i32)); +    Res = DAG.getNode(ISD::OR, DL, MVT::i32, Res, BitMask); +  } else { +    if (StoredVal.getOpcode() == ISD::AND) +      Mask = DAG.getNOT(DL, Mask, MVT::i32); +    Res = DAG.getNode(StoredVal.getOpcode(), DL, MVT::i32, X, Mask); +  } + +  return DAG.getStore(St->getChain(), DL, Res, NewPtr, St->getPointerInfo(), +                      Align(), St->getMemOperand()->getFlags()); +} +  static SDValue combineStore(SDNode *N, SelectionDAG &DAG,                              TargetLowering::DAGCombinerInfo &DCI,                              const X86Subtarget &Subtarget) { @@ -53705,6 +53669,9 @@ static SDValue combineStore(SDNode *N, SelectionDAG &DAG,      }    } +  if (SDValue R = narrowBitOpRMW(St, dl, DAG, Subtarget)) +    return R; +    // Convert store(cmov(load(p), x, CC), p) to cstore(x, p, CC)    //         store(cmov(x, load(p), CC), p) to cstore(x, p, InvertCC)    if ((VT == MVT::i16 || VT == MVT::i32 || VT == MVT::i64) && @@ -54492,6 +54459,7 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,  static SDValue detectPMADDUBSW(SDValue In, EVT VT, SelectionDAG &DAG,                                 const X86Subtarget &Subtarget,                                 const SDLoc &DL) { +  using namespace SDPatternMatch;    if (!VT.isVector() || !Subtarget.hasSSSE3())      return SDValue(); @@ -54501,42 +54469,19 @@ static SDValue detectPMADDUBSW(SDValue In, EVT VT, SelectionDAG &DAG,      return SDValue();    SDValue SSatVal = detectSSatPattern(In, VT); -  if (!SSatVal || SSatVal.getOpcode() != ISD::ADD) +  if (!SSatVal)      return SDValue(); -  // Ok this is a signed saturation of an ADD. See if this ADD is adding pairs -  // of multiplies from even/odd elements. -  SDValue N0 = SSatVal.getOperand(0); -  SDValue N1 = SSatVal.getOperand(1); - -  if (N0.getOpcode() != ISD::MUL || N1.getOpcode() != ISD::MUL) -    return SDValue(); - -  SDValue N00 = N0.getOperand(0); -  SDValue N01 = N0.getOperand(1); -  SDValue N10 = N1.getOperand(0); -  SDValue N11 = N1.getOperand(1); - +  // See if this is a signed saturation of an ADD, adding pairs of multiplies +  // from even/odd elements, from zero_extend/sign_extend operands. +  //    // TODO: Handle constant vectors and use knownbits/computenumsignbits? -  // Canonicalize zero_extend to LHS. -  if (N01.getOpcode() == ISD::ZERO_EXTEND) -    std::swap(N00, N01); -  if (N11.getOpcode() == ISD::ZERO_EXTEND) -    std::swap(N10, N11); - -  // Ensure we have a zero_extend and a sign_extend. -  if (N00.getOpcode() != ISD::ZERO_EXTEND || -      N01.getOpcode() != ISD::SIGN_EXTEND || -      N10.getOpcode() != ISD::ZERO_EXTEND || -      N11.getOpcode() != ISD::SIGN_EXTEND) +  SDValue N00, N01, N10, N11; +  if (!sd_match(SSatVal, +                m_Add(m_Mul(m_ZExt(m_Value(N00)), m_SExt(m_Value(N01))), +                      m_Mul(m_ZExt(m_Value(N10)), m_SExt(m_Value(N11))))))      return SDValue(); -  // Peek through the extends. -  N00 = N00.getOperand(0); -  N01 = N01.getOperand(0); -  N10 = N10.getOperand(0); -  N11 = N11.getOperand(0); -    // Ensure the extend is from vXi8.    if (N00.getValueType().getVectorElementType() != MVT::i8 ||        N01.getValueType().getVectorElementType() != MVT::i8 || @@ -54659,8 +54604,9 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,    // truncation, see if we can convert the shift into a pointer offset instead.    // Limit this to normal (non-ext) scalar integer loads.    if (SrcVT.isScalarInteger() && Src.getOpcode() == ISD::SRL && -      Src.hasOneUse() && Src.getOperand(0).hasOneUse() && -      ISD::isNormalLoad(Src.getOperand(0).getNode())) { +      Src.hasOneUse() && ISD::isNormalLoad(Src.getOperand(0).getNode()) && +      (Src.getOperand(0).hasOneUse() || +       !DAG.getTargetLoweringInfo().isOperationLegal(ISD::LOAD, SrcVT))) {      auto *Ld = cast<LoadSDNode>(Src.getOperand(0));      if (Ld->isSimple() && VT.isByteSized() &&          isPowerOf2_64(VT.getSizeInBits())) { @@ -54668,9 +54614,11 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,        KnownBits KnownAmt = DAG.computeKnownBits(ShAmt);        // Check the shift amount is byte aligned.        // Check the truncation doesn't use any shifted in (zero) top bits. +      // Check the shift amount doesn't depend on the original load.        if (KnownAmt.countMinTrailingZeros() >= 3 &&            KnownAmt.getMaxValue().ule(SrcVT.getSizeInBits() - -                                     VT.getSizeInBits())) { +                                     VT.getSizeInBits()) && +          !Ld->isPredecessorOf(ShAmt.getNode())) {          EVT PtrVT = Ld->getBasePtr().getValueType();          SDValue PtrBitOfs = DAG.getZExtOrTrunc(ShAmt, DL, PtrVT);          SDValue PtrByteOfs = @@ -56458,6 +56406,7 @@ static SDValue combineAVX512SetCCToKMOV(EVT VT, SDValue Op0, ISD::CondCode CC,  static SDValue combineSetCC(SDNode *N, SelectionDAG &DAG,                              TargetLowering::DAGCombinerInfo &DCI,                              const X86Subtarget &Subtarget) { +  using namespace SDPatternMatch;    const ISD::CondCode CC = cast<CondCodeSDNode>(N->getOperand(2))->get();    const SDValue LHS = N->getOperand(0);    const SDValue RHS = N->getOperand(1); @@ -56516,6 +56465,37 @@ static SDValue combineSetCC(SDNode *N, SelectionDAG &DAG,        if (SDValue AndN = MatchAndCmpEq(RHS, LHS))          return DAG.getSetCC(DL, VT, AndN, DAG.getConstant(0, DL, OpVT), CC); +      // If we're performing a bit test on a larger than legal type, attempt +      // to (aligned) shift down the value to the bottom 32-bits and then +      // perform the bittest on the i32 value. +      // ICMP_ZERO(AND(X,SHL(1,IDX))) +      // --> ICMP_ZERO(AND(TRUNC(SRL(X,AND(IDX,-32))),SHL(1,AND(IDX,31)))) +      if (isNullConstant(RHS) && +          OpVT.getScalarSizeInBits() > (Subtarget.is64Bit() ? 64 : 32)) { +        SDValue X, ShAmt; +        if (sd_match(LHS, m_OneUse(m_And(m_Value(X), +                                         m_Shl(m_One(), m_Value(ShAmt)))))) { +          // Only attempt this if the shift amount is known to be in bounds. +          KnownBits KnownAmt = DAG.computeKnownBits(ShAmt); +          if (KnownAmt.getMaxValue().ult(OpVT.getScalarSizeInBits())) { +            EVT AmtVT = ShAmt.getValueType(); +            SDValue AlignAmt = +                DAG.getNode(ISD::AND, DL, AmtVT, ShAmt, +                            DAG.getSignedConstant(-32LL, DL, AmtVT)); +            SDValue ModuloAmt = DAG.getNode(ISD::AND, DL, AmtVT, ShAmt, +                                            DAG.getConstant(31, DL, AmtVT)); +            SDValue Mask = DAG.getNode( +                ISD::SHL, DL, MVT::i32, DAG.getConstant(1, DL, MVT::i32), +                DAG.getZExtOrTrunc(ModuloAmt, DL, MVT::i8)); +            X = DAG.getNode(ISD::SRL, DL, OpVT, X, AlignAmt); +            X = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, X); +            X = DAG.getNode(ISD::AND, DL, MVT::i32, X, Mask); +            return DAG.getSetCC(DL, VT, X, DAG.getConstant(0, DL, MVT::i32), +                                CC); +          } +        } +      } +        // cmpeq(trunc(x),C) --> cmpeq(x,C)        // cmpne(trunc(x),C) --> cmpne(x,C)        // iff x upper bits are zero. diff --git a/llvm/lib/Target/X86/X86InstrAMX.td b/llvm/lib/Target/X86/X86InstrAMX.td index 69a5115..522782a 100644 --- a/llvm/lib/Target/X86/X86InstrAMX.td +++ b/llvm/lib/Target/X86/X86InstrAMX.td @@ -338,188 +338,6 @@ let Predicates = [HasAMXFP8, In64BitMode] in {    }  } -let Predicates = [HasAMXTILE, In64BitMode], isPseudo = true, SchedRW = [WriteSystem] in { -  let mayStore = 1 in -  def PTILEPAIRSTORE : PseudoI<(outs), (ins opaquemem:$src1, TILEPair:$src2), []>; -  let mayLoad = 1 in -  def PTILEPAIRLOAD : PseudoI<(outs TILEPair:$dst), (ins opaquemem:$src), []>; -} - -multiclass T2RPNTLVW_Base<bits<8> op1, bits<8> op2, string rs, string suffix> { -  def Z0#rs#suffix    : I<op1, MRMSrcMemFSIB, (outs TILEPair:$dst), (ins sibmem:$src), -                          "t2rpntlvwz0" #!tolower(rs)# "\t{$src, $dst|$dst, $src}", []>, PS; -  def Z0#rs#T1#suffix : I<op2, MRMSrcMemFSIB, (outs TILEPair:$dst), (ins sibmem:$src), -                          "t2rpntlvwz0" #!tolower(rs)# "t1\t{$src, $dst|$dst, $src}", []>, PS; -  def Z1#rs#suffix    : I<op1, MRMSrcMemFSIB, (outs TILEPair:$dst), (ins sibmem:$src), -                          "t2rpntlvwz1" #!tolower(rs)# "\t{$src, $dst|$dst, $src}", []>, PD; -  def Z1#rs#T1#suffix : I<op2, MRMSrcMemFSIB, (outs TILEPair:$dst), (ins sibmem:$src), -                          "t2rpntlvwz1" #!tolower(rs)# "t1\t{$src, $dst|$dst, $src}", []>, PD; -} - -let Predicates = [HasAMXTRANSPOSE, In64BitMode], SchedRW = [WriteSystem] in -  defm T2RPNTLVW : T2RPNTLVW_Base<0x6e, 0x6f, "", "">, T8, VEX; - -let Predicates = [HasAMXTRANSPOSE, HasEGPR, In64BitMode], SchedRW = [WriteSystem] in -  defm T2RPNTLVW : T2RPNTLVW_Base<0x6e, 0x6f, "", "_EVEX">, T8, EVEX, NoCD8; - -let Predicates = [HasAMXMOVRS, HasAMXTRANSPOSE, In64BitMode], SchedRW = [WriteSystem] in -  defm T2RPNTLVW : T2RPNTLVW_Base<0xf8, 0xf9, "RS", "">, T_MAP5, VEX; - -let Predicates = [HasAMXMOVRS, HasAMXTRANSPOSE, HasEGPR, In64BitMode], SchedRW = [WriteSystem] in -  defm T2RPNTLVW : T2RPNTLVW_Base<0xf8, 0xf9, "RS", "_EVEX">, T_MAP5, EVEX, NoCD8; - -let Predicates = [HasAMXTRANSPOSE, In64BitMode] in { -  let SchedRW = [WriteSystem] in { -    def TTRANSPOSED : I<0x5f, MRMSrcReg, (outs TILE:$dst), (ins TILE:$src), -                        "ttransposed\t{$src, $dst|$dst, $src}", []>, VEX, T8, XS; -    let isPseudo = true in { -      def PT2RPNTLVWZ0V : PseudoI<(outs TILEPair:$dst), -                                  (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4), -                                  []>; -      def PT2RPNTLVWZ0T1V : PseudoI<(outs TILEPair:$dst), -                                  (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4), -                                  []>; -      def PT2RPNTLVWZ1V : PseudoI<(outs TILEPair:$dst), -                                  (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4), -                                  []>; -      def PT2RPNTLVWZ1T1V : PseudoI<(outs TILEPair:$dst), -                                  (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4), -                                  []>; -    } - -    def PTTRANSPOSEDV : PseudoI<(outs TILE:$dst), -                                (ins GR16:$src1, GR16:$src2, TILE:$src), -                                [(set TILE: $dst, -                                 (int_x86_ttransposed_internal GR16:$src1, GR16:$src2, -                                  TILE:$src))]>; - -    let usesCustomInserter = 1 in { -      def PT2RPNTLVWZ0 : PseudoI<(outs), (ins u8imm:$dst, -                                 sibmem:$src1), []>; -      def PT2RPNTLVWZ0T1 : PseudoI<(outs), (ins u8imm:$dst, -                                   sibmem:$src1), []>; -      def PT2RPNTLVWZ1 : PseudoI<(outs), (ins u8imm:$dst, -                                 sibmem:$src1), []>; -      def PT2RPNTLVWZ1T1 : PseudoI<(outs), (ins u8imm:$dst, -                                   sibmem:$src1), []>; -      def PTTRANSPOSED : PseudoI<(outs), (ins u8imm:$dst, u8imm:$src), -                                 [(int_x86_ttransposed timm:$dst, timm:$src)]>; -    } -  } -} // HasAMXTILE, HasAMXTRANSPOSE - -let Predicates = [HasAMXBF16, HasAMXTRANSPOSE, In64BitMode], SchedRW = [WriteSystem] in { -  let Constraints = "$src1 = $dst" in -    def TTDPBF16PS : I<0x6c, MRMSrcReg4VOp3, (outs TILE:$dst), -                       (ins TILE:$src1, TILE:$src2, TILE:$src3), -                       "ttdpbf16ps\t{$src3, $src2, $dst|$dst, $src2, $src3}", -                       []>, VEX, VVVV, T8,XS; -  let Constraints = "$src4 = $dst" in -    def PTTDPBF16PSV : PseudoI<(outs TILE:$dst), (ins GR16:$src1, -                                GR16:$src2, GR16:$src3, TILE:$src4, -                                TILE:$src5, TILE:$src6), -                                [(set TILE: $dst, -                                  (int_x86_ttdpbf16ps_internal GR16:$src1, GR16:$src2, -                                   GR16:$src3, TILE:$src4, TILE:$src5, TILE:$src6))]>; -  let usesCustomInserter = 1 in -    def PTTDPBF16PS : PseudoI<(outs), (ins u8imm:$src1, u8imm:$src2, u8imm:$src3), -                              [(int_x86_ttdpbf16ps timm:$src1, timm:$src2, timm:$src3)]>; -} - -let Predicates = [HasAMXFP16, HasAMXTRANSPOSE, In64BitMode], SchedRW = [WriteSystem] in { -  let Constraints = "$src1 = $dst" in -    def TTDPFP16PS : I<0x6c, MRMSrcReg4VOp3, (outs TILE:$dst), -                       (ins TILE:$src1, TILE:$src2, TILE:$src3), -                       "ttdpfp16ps\t{$src3, $src2, $dst|$dst, $src2, $src3}", -                       []>, VEX, VVVV, T8,XD; -  let Constraints = "$src4 = $dst" in -    def PTTDPFP16PSV : PseudoI<(outs TILE:$dst), (ins GR16:$src1, -                                GR16:$src2, GR16:$src3, TILE:$src4, -                                TILE:$src5, TILE:$src6), -                                [(set TILE: $dst, -                                  (int_x86_ttdpfp16ps_internal GR16:$src1, GR16:$src2, -                                   GR16:$src3, TILE:$src4, TILE:$src5, TILE:$src6))]>; -  let usesCustomInserter = 1 in -    def PTTDPFP16PS : PseudoI<(outs), (ins u8imm:$src1, u8imm:$src2, u8imm:$src3), -                              [(int_x86_ttdpfp16ps timm:$src1, timm:$src2, timm:$src3)]>; -} - -let Predicates = [HasAMXCOMPLEX, HasAMXTRANSPOSE, In64BitMode], SchedRW = [WriteSystem] in { -  let Constraints = "$src1 = $dst" in { -    def TTCMMIMFP16PS : I<0x6b, MRMSrcReg4VOp3, (outs TILE:$dst), -                          (ins TILE:$src1, TILE:$src2, TILE:$src3), -                          "ttcmmimfp16ps\t{$src3, $src2, $src1|$src1, $src2, $src3}", -                          []>, VEX, VVVV, T8,XD; -    def TTCMMRLFP16PS: I<0x6b, MRMSrcReg4VOp3, (outs TILE:$dst), -                         (ins TILE:$src1, TILE:$src2, TILE:$src3), -                         "ttcmmrlfp16ps\t{$src3, $src2, $src1|$src1, $src2, $src3}", -                         []>, VEX, VVVV, T8,XS; -    def TCONJTCMMIMFP16PS : I<0x6b, MRMSrcReg4VOp3, (outs TILE:$dst), -                          (ins TILE:$src1, TILE:$src2, TILE:$src3), -                          "tconjtcmmimfp16ps\t{$src3, $src2, $src1|$src1, $src2, $src3}", -                          []>, VEX, VVVV, WIG, T8,PS; -  } -  def TCONJTFP16 : I<0x6b, MRMSrcReg, (outs TILE:$dst), (ins TILE:$src), -                     "tconjtfp16\t{$src, $dst|$dst, $src}", []>, VEX, T8,PD; - -  let Constraints = "$src4 = $dst" in { -    def PTTCMMIMFP16PSV : PseudoI<(outs TILE:$dst), (ins GR16:$src1, -                                  GR16:$src2, GR16:$src3, TILE:$src4, -                                  TILE:$src5, TILE:$src6), -                                  [(set TILE: $dst, -                                    (int_x86_ttcmmimfp16ps_internal GR16:$src1, GR16:$src2, -                                     GR16:$src3, TILE:$src4, TILE:$src5, TILE:$src6))]>; -    def PTTCMMRLFP16PSV : PseudoI<(outs TILE:$dst), (ins GR16:$src1, -                                  GR16:$src2, GR16:$src3, TILE:$src4, -                                  TILE:$src5, TILE:$src6), -                                  [(set TILE: $dst, -                                    (int_x86_ttcmmrlfp16ps_internal GR16:$src1, GR16:$src2, -                                     GR16:$src3, TILE:$src4, TILE:$src5, TILE:$src6))]>; -    def PTCONJTCMMIMFP16PSV : PseudoI<(outs TILE:$dst), (ins GR16:$src1, -                                      GR16:$src2, GR16:$src3, TILE:$src4, -                                      TILE:$src5, TILE:$src6), -                                      [(set TILE: $dst, -                                        (int_x86_tconjtcmmimfp16ps_internal GR16:$src1, GR16:$src2, -                                         GR16:$src3, TILE:$src4, TILE:$src5, TILE:$src6))]>; -  } -  def PTCONJTFP16V : PseudoI<(outs TILE:$dst), (ins GR16:$src1, GR16:$src2, TILE:$src3), -                             [(set TILE: $dst, (int_x86_tconjtfp16_internal GR16:$src1, GR16:$src2, TILE:$src3))]>; - -  let usesCustomInserter = 1 in { -    def PTTCMMIMFP16PS : PseudoI<(outs), (ins u8imm:$src1, u8imm:$src2, u8imm:$src3), -                                 [(int_x86_ttcmmimfp16ps timm:$src1, timm:$src2, timm:$src3)]>; -    def PTTCMMRLFP16PS : PseudoI<(outs), (ins u8imm:$src1, u8imm:$src2, u8imm:$src3), -                                 [(int_x86_ttcmmrlfp16ps timm:$src1, timm:$src2, timm:$src3)]>; -    def PTCONJTCMMIMFP16PS : PseudoI<(outs), (ins u8imm:$src1, u8imm:$src2, u8imm:$src3), -                                     [(int_x86_tconjtcmmimfp16ps timm:$src1, timm:$src2, timm:$src3)]>; -    def PTCONJTFP16 : PseudoI<(outs), (ins u8imm:$dst, u8imm:$src), -                              [(int_x86_tconjtfp16 timm:$dst, timm:$src)]>; -  } -} - -let Predicates = [HasAMXMOVRS, HasAMXTRANSPOSE, In64BitMode], SchedRW = [WriteSystem] in { -  let isPseudo = true in { -    def PT2RPNTLVWZ0RSV   : PseudoI<(outs TILEPair:$dst), -                              (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4), -                              []>; -    def PT2RPNTLVWZ0RST1V : PseudoI<(outs TILEPair:$dst), -                              (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4), -                              []>; -    def PT2RPNTLVWZ1RSV   : PseudoI<(outs TILEPair:$dst), -                              (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4), -                              []>; -    def PT2RPNTLVWZ1RST1V : PseudoI<(outs TILEPair:$dst), -                              (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4), -                              []>; -  } -  let  usesCustomInserter = 1 in { -    def PT2RPNTLVWZ0RS   : PseudoI<(outs), (ins u8imm:$dst, sibmem:$src1), []>; -    def PT2RPNTLVWZ0RST1 : PseudoI<(outs), (ins u8imm:$dst, sibmem:$src1), []>; -    def PT2RPNTLVWZ1RS   : PseudoI<(outs), (ins u8imm:$dst, sibmem:$src1), []>; -    def PT2RPNTLVWZ1RST1 : PseudoI<(outs), (ins u8imm:$dst, sibmem:$src1), []>; -  } -} // HasAMXMOVRS, HasAMXTRANSPOSE -  multiclass TILELOADDRS_Base<string suffix> {    def suffix    : I<0x4a, MRMSrcMemFSIB, (outs TILE:$dst), (ins sibmem:$src1),                      "tileloaddrs\t{$src1, $dst|$dst, $src1}", []>, T8, XD; @@ -721,29 +539,3 @@ let Predicates = [HasAMXTF32, In64BitMode] in {      }    } // SchedRW = [WriteSystem]  } // HasAMXTF32 - -let Predicates = [HasAMXTF32, HasAMXTRANSPOSE, In64BitMode] in { -  let SchedRW = [WriteSystem] in { -    let Constraints = "$src1 = $dst" in { -      def TTMMULTF32PS: I<0x48, MRMSrcReg4VOp3, (outs TILE:$dst), -                         (ins TILE:$src1, TILE:$src2, TILE:$src3), -                         "ttmmultf32ps\t{$src3, $src2, $dst|$dst, $src2, $src3}", -                         []>, VEX, VVVV, T8, PS; -    } -    let Constraints = "$src4 = $dst" in { -      def PTTMMULTF32PSV : PseudoI<(outs TILE:$dst), -                                   (ins GR16:$src1, GR16:$src2, GR16:$src3, -                                    TILE:$src4, TILE:$src5, TILE:$src6), -                                   [(set TILE:$dst, -                                     (int_x86_ttmmultf32ps_internal GR16:$src1, -                                      GR16:$src2, GR16:$src3, TILE:$src4, -                                      TILE:$src5, TILE:$src6))]>; -    } -    let usesCustomInserter = 1 in { -      def PTTMMULTF32PS : PseudoI<(outs), -                                  (ins u8imm:$src1, u8imm:$src2, u8imm:$src3), -                                  [(int_x86_ttmmultf32ps timm:$src1, timm:$src2, -                                    timm:$src3)]>; -    } -  } // SchedRW = [WriteSystem] -} // HasAMXTF32, HasAMXTRANSPOSE diff --git a/llvm/lib/Target/X86/X86InstrInfo.cpp b/llvm/lib/Target/X86/X86InstrInfo.cpp index 5c23f91..6b2a7a4 100644 --- a/llvm/lib/Target/X86/X86InstrInfo.cpp +++ b/llvm/lib/Target/X86/X86InstrInfo.cpp @@ -4544,11 +4544,6 @@ static unsigned getLoadStoreRegOpcode(Register Reg,      return Load ? GET_EGPR_IF_ENABLED(X86::TILELOADD)                  : GET_EGPR_IF_ENABLED(X86::TILESTORED);  #undef GET_EGPR_IF_ENABLED -  case 2048: -    assert(X86::TILEPAIRRegClass.hasSubClassEq(RC) && -           "Unknown 2048-byte regclass"); -    assert(STI.hasAMXTILE() && "Using 2048-bit register requires AMX-TILE"); -    return Load ? X86::PTILEPAIRLOAD : X86::PTILEPAIRSTORE;    }  } @@ -4743,8 +4738,6 @@ static bool isAMXOpcode(unsigned Opc) {    case X86::TILESTORED:    case X86::TILELOADD_EVEX:    case X86::TILESTORED_EVEX: -  case X86::PTILEPAIRLOAD: -  case X86::PTILEPAIRSTORE:      return true;    }  } @@ -4757,8 +4750,7 @@ void X86InstrInfo::loadStoreTileReg(MachineBasicBlock &MBB,    default:      llvm_unreachable("Unexpected special opcode!");    case X86::TILESTORED: -  case X86::TILESTORED_EVEX: -  case X86::PTILEPAIRSTORE: { +  case X86::TILESTORED_EVEX: {      // tilestored %tmm, (%sp, %idx)      MachineRegisterInfo &RegInfo = MBB.getParent()->getRegInfo();      Register VirtReg = RegInfo.createVirtualRegister(&X86::GR64_NOSPRegClass); @@ -4772,8 +4764,7 @@ void X86InstrInfo::loadStoreTileReg(MachineBasicBlock &MBB,      break;    }    case X86::TILELOADD: -  case X86::TILELOADD_EVEX: -  case X86::PTILEPAIRLOAD: { +  case X86::TILELOADD_EVEX: {      // tileloadd (%sp, %idx), %tmm      MachineRegisterInfo &RegInfo = MBB.getParent()->getRegInfo();      Register VirtReg = RegInfo.createVirtualRegister(&X86::GR64_NOSPRegClass); diff --git a/llvm/lib/Target/X86/X86InstrOperands.td b/llvm/lib/Target/X86/X86InstrOperands.td index 5207eca..6ba07f7 100644 --- a/llvm/lib/Target/X86/X86InstrOperands.td +++ b/llvm/lib/Target/X86/X86InstrOperands.td @@ -536,10 +536,3 @@ def VK8Pair : RegisterOperand<VK8PAIR, "printVKPair"> {  def VK16Pair : RegisterOperand<VK16PAIR, "printVKPair"> {    let ParserMatchClass = VK16PairAsmOperand;  } - -let RenderMethod = "addTILEPairOperands" in -  def TILEPairAsmOperand : AsmOperandClass { let Name = "TILEPair"; } - -def TILEPair : RegisterOperand<TILEPAIR, "printTILEPair"> { -  let ParserMatchClass = TILEPairAsmOperand; -} diff --git a/llvm/lib/Target/X86/X86InstrPredicates.td b/llvm/lib/Target/X86/X86InstrPredicates.td index c20bb05..98104a6f 100644 --- a/llvm/lib/Target/X86/X86InstrPredicates.td +++ b/llvm/lib/Target/X86/X86InstrPredicates.td @@ -183,7 +183,6 @@ def HasAMXINT8   : Predicate<"Subtarget->hasAMXINT8()">;  def HasAMXCOMPLEX : Predicate<"Subtarget->hasAMXCOMPLEX()">;  def HasAMXFP8    : Predicate<"Subtarget->hasAMXFP8()">;  def HasAMXMOVRS  : Predicate<"Subtarget->hasAMXMOVRS()">; -def HasAMXTRANSPOSE : Predicate<"Subtarget->hasAMXTRANSPOSE()">;  def HasAMXAVX512 : Predicate<"Subtarget->hasAMXAVX512()">;  def HasAMXTF32   : Predicate<"Subtarget->hasAMXTF32()">;  def HasUINTR     : Predicate<"Subtarget->hasUINTR()">; diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp index 8ffd454..2fc5d38 100644 --- a/llvm/lib/Target/X86/X86LowerAMXType.cpp +++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp @@ -74,22 +74,6 @@ static bool isAMXCast(Instruction *II) {           match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));  } -// Some instructions may return more than one tiles. -// e.g: call { x86_amx, x86_amx } @llvm.x86.t2rpntlvwz0.internal -static unsigned getNumDefTiles(IntrinsicInst *II) { -  Type *Ty = II->getType(); -  if (Ty->isX86_AMXTy()) -    return 1; - -  unsigned Num = 0; -  for (unsigned i = 0; i < Ty->getNumContainedTypes(); i++) { -    Type *STy = Ty->getContainedType(i); -    if (STy->isX86_AMXTy()) -      Num++; -  } -  return Num; -} -  static bool isAMXIntrinsic(Value *I) {    auto *II = dyn_cast<IntrinsicInst>(I);    if (!II) @@ -98,7 +82,7 @@ static bool isAMXIntrinsic(Value *I) {      return false;    // Check if return type or parameter is x86_amx. If it is x86_amx    // the intrinsic must be x86 amx intrinsics. -  if (getNumDefTiles(II) > 0) +  if (II->getType()->isX86_AMXTy())      return true;    for (Value *V : II->args()) {      if (V->getType()->isX86_AMXTy()) @@ -137,27 +121,7 @@ static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {    llvm_unreachable("No terminator in the entry block!");  } -class ShapeCalculator { -private: -  const TargetMachine *TM = nullptr; - -  // In AMX intrinsics we let Shape = {Row, Col}, but the -  // RealCol = Col / ElementSize. We may use the RealCol -  // as a new Row for other new created AMX intrinsics. -  std::map<Value *, Value *> Col2Row, Row2Col; - -public: -  ShapeCalculator(const TargetMachine *TargetM) : TM(TargetM) {} -  std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo); -  std::pair<Value *, Value *> getShape(PHINode *Phi); -  Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity); -  Value *getColFromRow(Instruction *II, Value *V, unsigned Granularity); -}; - -Value *ShapeCalculator::getRowFromCol(Instruction *II, Value *V, -                                      unsigned Granularity) { -  if (auto It = Col2Row.find(V); It != Col2Row.end()) -    return It->second; +static Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity) {    IRBuilder<> Builder(II);    Value *RealRow = nullptr;    if (isa<ConstantInt>(V)) @@ -186,47 +150,16 @@ Value *ShapeCalculator::getRowFromCol(Instruction *II, Value *V,          getFirstNonAllocaInTheEntryBlock(*II->getFunction()));      RealRow = NewBuilder.CreateUDiv(V, NewBuilder.getInt16(Granularity));    } -  Col2Row[V] = RealRow;    return RealRow;  } -Value *ShapeCalculator::getColFromRow(Instruction *II, Value *V, -                                      unsigned Granularity) { -  if (auto It = Row2Col.find(V); It != Row2Col.end()) -    return It->second; -  IRBuilder<> Builder(II); -  Value *RealCol = nullptr; -  if (isa<ConstantInt>(V)) -    RealCol = -        Builder.getInt16((cast<ConstantInt>(V)->getSExtValue()) * Granularity); -  else if (isa<Instruction>(V)) { -    Builder.SetInsertPoint(cast<Instruction>(V)); -    RealCol = Builder.CreateNUWMul(V, Builder.getInt16(Granularity)); -    cast<Instruction>(RealCol)->moveAfter(cast<Instruction>(V)); -  } else { -    // When it is not a const value and it is a function argument, we create -    // Row at the entry bb. -    IRBuilder<> NewBuilder( -        getFirstNonAllocaInTheEntryBlock(*II->getFunction())); -    RealCol = NewBuilder.CreateNUWMul(V, NewBuilder.getInt16(Granularity)); -  } -  Row2Col[V] = RealCol; -  return RealCol; -} -  // TODO: Refine the row and col-in-bytes of tile to row and col of matrix. -std::pair<Value *, Value *> ShapeCalculator::getShape(IntrinsicInst *II, -                                                      unsigned OpNo) { -  (void)TM; +std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {    IRBuilder<> Builder(II);    Value *Row = nullptr, *Col = nullptr;    switch (II->getIntrinsicID()) {    default:      llvm_unreachable("Expect amx intrinsics"); -  case Intrinsic::x86_t2rpntlvwz0_internal: -  case Intrinsic::x86_t2rpntlvwz0t1_internal: -  case Intrinsic::x86_t2rpntlvwz1_internal: -  case Intrinsic::x86_t2rpntlvwz1t1_internal:    case Intrinsic::x86_tileloadd64_internal:    case Intrinsic::x86_tileloaddt164_internal:    case Intrinsic::x86_tilestored64_internal: @@ -271,13 +204,6 @@ std::pair<Value *, Value *> ShapeCalculator::getShape(IntrinsicInst *II,      }      break;    } -  case Intrinsic::x86_ttransposed_internal: -  case Intrinsic::x86_tconjtfp16_internal: { -    assert((OpNo == 2) && "Illegal Operand Number."); -    Row = getRowFromCol(II, II->getArgOperand(1), 4); -    Col = getColFromRow(II, II->getArgOperand(0), 4); -    break; -  }    case Intrinsic::x86_tcvtrowd2ps_internal:    case Intrinsic::x86_tcvtrowps2bf16h_internal:    case Intrinsic::x86_tcvtrowps2bf16l_internal: @@ -289,34 +215,12 @@ std::pair<Value *, Value *> ShapeCalculator::getShape(IntrinsicInst *II,      Col = II->getArgOperand(1);      break;    } -  case Intrinsic::x86_ttdpbf16ps_internal: -  case Intrinsic::x86_ttdpfp16ps_internal: -  case Intrinsic::x86_ttcmmimfp16ps_internal: -  case Intrinsic::x86_ttcmmrlfp16ps_internal: -  case Intrinsic::x86_tconjtcmmimfp16ps_internal: -  case Intrinsic::x86_ttmmultf32ps_internal: { -    switch (OpNo) { -    case 3: -      Row = II->getArgOperand(0); -      Col = II->getArgOperand(1); -      break; -    case 4: -      Row = getRowFromCol(II, II->getArgOperand(2), 4); -      Col = getColFromRow(II, II->getArgOperand(0), 4); -      break; -    case 5: -      Row = getRowFromCol(II, II->getArgOperand(2), 4); -      Col = II->getArgOperand(1); -      break; -    } -    break; -  }    }    return std::make_pair(Row, Col);  } -std::pair<Value *, Value *> ShapeCalculator::getShape(PHINode *Phi) { +static std::pair<Value *, Value *> getShape(PHINode *Phi) {    Use &U = *(Phi->use_begin());    unsigned OpNo = U.getOperandNo();    User *V = U.getUser(); @@ -349,15 +253,14 @@ std::pair<Value *, Value *> ShapeCalculator::getShape(PHINode *Phi) {  namespace {  class X86LowerAMXType {    Function &Func; -  ShapeCalculator *SC;    // In AMX intrinsics we let Shape = {Row, Col}, but the    // RealCol = Col / ElementSize. We may use the RealCol    // as a new Row for other new created AMX intrinsics. -  std::map<Value *, Value *> Col2Row, Row2Col; +  std::map<Value *, Value *> Col2Row;  public: -  X86LowerAMXType(Function &F, ShapeCalculator *ShapeC) : Func(F), SC(ShapeC) {} +  X86LowerAMXType(Function &F) : Func(F) {}    bool visit();    void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);    void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST); @@ -374,7 +277,7 @@ void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {    Use &U = *(Bitcast->use_begin());    unsigned OpNo = U.getOperandNo();    auto *II = cast<IntrinsicInst>(U.getUser()); -  std::tie(Row, Col) = SC->getShape(II, OpNo); +  std::tie(Row, Col) = getShape(II, OpNo);    IRBuilder<> Builder(Bitcast);    // Use the maximun column as stride.    Value *Stride = Builder.getInt64(64); @@ -454,7 +357,7 @@ bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {      Builder.CreateStore(Src, AllocaAddr);      // TODO we can pick an constant operand for the shape.      Value *Row = nullptr, *Col = nullptr; -    std::tie(Row, Col) = SC->getShape(II, OpNo); +    std::tie(Row, Col) = getShape(II, OpNo);      std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};      Value *NewInst =          Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, Args); @@ -594,18 +497,11 @@ static Value *getAllocaPos(BasicBlock *BB) {  static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {    assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!"); -  auto *II = dyn_cast<IntrinsicInst>(TileDef); -  unsigned Idx = 0; -  // Extract tile from multiple tiles' def. -  if (auto *Extr = dyn_cast<ExtractValueInst>(TileDef)) { -    assert(Extr->hasIndices() && "Tile extract miss index!"); -    Idx = Extr->getIndices()[0]; -    II = cast<IntrinsicInst>(Extr->getOperand(0)); -  } +  auto *II = cast<IntrinsicInst>(TileDef);    assert(II && "Not tile intrinsic!"); -  Value *Row = II->getOperand(Idx); -  Value *Col = II->getOperand(Idx + 1); +  Value *Row = II->getOperand(0); +  Value *Col = II->getOperand(1);    BasicBlock *BB = TileDef->getParent();    BasicBlock::iterator Iter = TileDef->getIterator(); @@ -624,20 +520,14 @@ static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {    // Get tile shape.    IntrinsicInst *II = nullptr; -  unsigned Idx = 0;    if (IsPHI) {      Value *PhiOp = cast<PHINode>(V)->getIncomingValue(0);      II = cast<IntrinsicInst>(PhiOp); -  } else if (auto *Extr = dyn_cast<ExtractValueInst>(V)) { -    // Extract tile from multiple tiles' def. -    assert(Extr->hasIndices() && "Tile extract miss index!"); -    Idx = Extr->getIndices()[0]; -    II = cast<IntrinsicInst>(Extr->getOperand(0));    } else {      II = cast<IntrinsicInst>(V);    } -  Value *Row = II->getOperand(Idx); -  Value *Col = II->getOperand(Idx + 1); +  Value *Row = II->getOperand(0); +  Value *Col = II->getOperand(1);    Instruction *UserI = cast<Instruction>(U.getUser());    IRBuilder<> Builder(UserI); @@ -848,12 +738,10 @@ namespace {  class X86LowerAMXCast {    Function &Func; -  ShapeCalculator *SC;    std::unique_ptr<DominatorTree> DT;  public: -  X86LowerAMXCast(Function &F, ShapeCalculator *ShapeC) -      : Func(F), SC(ShapeC), DT(nullptr) {} +  X86LowerAMXCast(Function &F) : Func(F), DT(nullptr) {}    bool combineCastStore(IntrinsicInst *Cast, StoreInst *ST);    bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);    bool combineTilezero(IntrinsicInst *Cast); @@ -932,7 +820,7 @@ bool X86LowerAMXCast::optimizeAMXCastFromPhi(          if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())            return false;          Value *Row = nullptr, *Col = nullptr; -        std::tie(Row, Col) = SC->getShape(OldPN); +        std::tie(Row, Col) = getShape(OldPN);          // TODO: If it is not constant the Row and Col must domoniate tilezero          // that we are going to create.          if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col)) @@ -1063,19 +951,6 @@ bool X86LowerAMXCast::optimizeAMXCastFromPhi(    return true;  } -static Value *getShapeFromAMXIntrinsic(Value *Inst, unsigned ShapeIdx, -                                       bool IsRow) { -  if (!isAMXIntrinsic(Inst)) -    return nullptr; - -  auto *II = cast<IntrinsicInst>(Inst); -  if (IsRow) -    return II->getOperand(0); - -  assert(ShapeIdx < 2 && "Currently 2 shapes in 1 instruction at most!"); -  return II->getOperand(ShapeIdx + 1); -} -  // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42)  // store <256 x i32> %43, <256 x i32>* %p, align 64  // --> @@ -1090,38 +965,13 @@ bool X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {    if (!Tile->hasOneUse())      return false; -  // We don't fetch shape from tilestore, we only get shape from tiledef, -  // so we can set the max tile shape to tilestore for special cases. +  auto *II = cast<IntrinsicInst>(Tile); +  // Tile is output from AMX intrinsic. The first operand of the +  // intrinsic is row, the second operand of the intrinsic is column. +  Value *Row = II->getOperand(0); +  Value *Col = II->getOperand(1); +    IRBuilder<> Builder(ST); -  Value *Row = nullptr; -  Value *Col = nullptr; - -  if (isAMXIntrinsic(Tile)) { -    auto *II = cast<IntrinsicInst>(Tile); -    // Tile is output from AMX intrinsic. The first operand of the -    // intrinsic is row, the second operand of the intrinsic is column. -    Row = II->getOperand(0); -    Col = II->getOperand(1); -  } else { -    // Now we supported multi-tiles value in structure, so we may get tile -    // from extracting multi-tiles structure. -    // For example: -    // %6 = call { x86_amx, x86_amx } @llvm.x86.t2rpntlvwz0.internal(i16 %1, -    //      i16 %2, i16 %3, i8* %4, i64 %5) -    // %7 = extractvalue { x86_amx, x86_amx } %6, 0 -    // %8 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %7) -    // store <256 x i32> %8, <256 x i32>* %0, align 1024 -    // -    // TODO: Currently we only handle extractvalue case, enhance me for other -    // cases if possible. -    auto *II = cast<ExtractValueInst>(Tile); -    assert(II && "We meet unhandle source in fetching tile value!"); -    unsigned ShapeIdx = II->getIndices()[0]; -    Value *Tiles = II->getOperand(0); -    Row = getShapeFromAMXIntrinsic(Tiles, ShapeIdx, true); -    Col = getShapeFromAMXIntrinsic(Tiles, ShapeIdx, false); -  } -  assert(Row && Col && "Shape got failed!");    // Stride should be equal to col(measured by bytes)    Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty()); @@ -1146,7 +996,7 @@ bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {    // shape information through def-use chain.    if (!isAMXIntrinsic(II))      return false; -  std::tie(Row, Col) = SC->getShape(II, OpNo); +  std::tie(Row, Col) = getShape(II, OpNo);    IRBuilder<> Builder(LD);    // Stride should be equal to col(measured by bytes)    Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty()); @@ -1189,7 +1039,7 @@ bool X86LowerAMXCast::combineTilezero(IntrinsicInst *Cast) {    if (!isAMXIntrinsic(II))      return false; -  std::tie(Row, Col) = SC->getShape(II, OpNo); +  std::tie(Row, Col) = getShape(II, OpNo);    IRBuilder<> Builder(Cast);    Value *NewInst = @@ -1384,7 +1234,7 @@ bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {      Builder.CreateStore(Src, AllocaAddr);      // TODO we can pick an constant operand for the shape.      Value *Row = nullptr, *Col = nullptr; -    std::tie(Row, Col) = SC->getShape(II, OpNo); +    std::tie(Row, Col) = getShape(II, OpNo);      std::array<Value *, 4> Args = {          Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};      Value *NewInst = @@ -1445,14 +1295,13 @@ bool lowerAmxType(Function &F, const TargetMachine *TM,      return false;    bool C = false; -  ShapeCalculator SC(TM); -  X86LowerAMXCast LAC(F, &SC); +  X86LowerAMXCast LAC(F);    C |= LAC.combineAMXcast(TLI);    // There might be remaining AMXcast after combineAMXcast and they should be    // handled elegantly.    C |= LAC.transformAllAMXCast(); -  X86LowerAMXType LAT(F, &SC); +  X86LowerAMXType LAT(F);    C |= LAT.visit();    // Prepare for fast register allocation at O0. diff --git a/llvm/lib/Target/X86/X86PreTileConfig.cpp b/llvm/lib/Target/X86/X86PreTileConfig.cpp index 2a1c499..8a1d00d 100644 --- a/llvm/lib/Target/X86/X86PreTileConfig.cpp +++ b/llvm/lib/Target/X86/X86PreTileConfig.cpp @@ -141,15 +141,10 @@ class X86PreTileConfig : public MachineFunctionPass {      if (!MO.isReg() || !MO.getReg().isVirtual())        return false; -    unsigned Shapes = 0; -    if (MRI->getRegClass(MO.getReg())->getID() == X86::TILERegClassID) -      Shapes = 1; -    if (MRI->getRegClass(MO.getReg())->getID() == X86::TILEPAIRRegClassID) -      Shapes = 2; -    if (!Shapes) +    if (MRI->getRegClass(MO.getReg())->getID() != X86::TILERegClassID)        return false; -    collectShapeInfo(MI, Shapes); +    collectShapeInfo(MI);      return true;    } @@ -165,7 +160,7 @@ class X86PreTileConfig : public MachineFunctionPass {    }    /// Collect the shape def information for later use. -  void collectShapeInfo(MachineInstr &MI, unsigned Shapes); +  void collectShapeInfo(MachineInstr &MI);    /// Try to hoist shapes definded below AMX instructions.    bool hoistShapesInBB(MachineBasicBlock *MBB, SmallVectorImpl<MIRef> &Shapes) { @@ -231,7 +226,7 @@ INITIALIZE_PASS_DEPENDENCY(MachineLoopInfoWrapperPass)  INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig",                      "Tile Register Pre-configure", false, false) -void X86PreTileConfig::collectShapeInfo(MachineInstr &MI, unsigned Shapes) { +void X86PreTileConfig::collectShapeInfo(MachineInstr &MI) {    auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) {      MIRef MIR(MI, MBB);      auto &Refs = ShapeBBs[MBB]; @@ -240,10 +235,8 @@ void X86PreTileConfig::collectShapeInfo(MachineInstr &MI, unsigned Shapes) {        Refs.insert(I, MIR);    }; -  // All shapes have same row in multi-tile operand. -  SmallVector<Register, 8> WorkList; -  for (unsigned I = 1; I < Shapes + 2; ++I) -    WorkList.push_back(MI.getOperand(I).getReg()); +  SmallVector<Register, 8> WorkList( +      {MI.getOperand(1).getReg(), MI.getOperand(2).getReg()});    while (!WorkList.empty()) {      Register R = WorkList.pop_back_val();      MachineInstr *DefMI = MRI->getVRegDef(R); @@ -252,13 +245,6 @@ void X86PreTileConfig::collectShapeInfo(MachineInstr &MI, unsigned Shapes) {      if (DefMI->isMoveImmediate() || !DefVisited.insert(DefMI).second)        continue; -    // This happens when column = 0 in multi-tile operand. -    if (DefMI->getOpcode() == X86::COPY) { -      MachineInstr *MI = MRI->getVRegDef(DefMI->getOperand(1).getReg()); -      if (MI && MI->isMoveImmediate()) -        continue; -    } -      if (DefMI->isPHI()) {        for (unsigned I = 1; I < DefMI->getNumOperands(); I += 2)          if (isLoopBackEdge(DefMBB, DefMI->getOperand(I + 1).getMBB())) diff --git a/llvm/lib/Target/X86/X86RegisterInfo.cpp b/llvm/lib/Target/X86/X86RegisterInfo.cpp index 76979e3..72f3813 100644 --- a/llvm/lib/Target/X86/X86RegisterInfo.cpp +++ b/llvm/lib/Target/X86/X86RegisterInfo.cpp @@ -597,10 +597,6 @@ BitVector X86RegisterInfo::getReservedRegs(const MachineFunction &MF) const {        Reserved.set(*AI);    } -  // Reserve low half pair registers in case they are used by RA aggressively. -  Reserved.set(X86::TMM0_TMM1); -  Reserved.set(X86::TMM2_TMM3); -    assert(checkAllSuperRegsMarked(Reserved,                                   {X86::SIL, X86::DIL, X86::BPL, X86::SPL,                                    X86::SIH, X86::DIH, X86::BPH, X86::SPH})); @@ -621,7 +617,7 @@ unsigned X86RegisterInfo::getNumSupportedRegs(const MachineFunction &MF) const {    // and try to return the minimum number of registers supported by the target.    static_assert((X86::R15WH + 1 == X86::YMM0) && (X86::YMM15 + 1 == X86::K0) &&                      (X86::K6_K7 + 1 == X86::TMMCFG) && -                    (X86::TMM6_TMM7 + 1 == X86::R16) && +                    (X86::TMM7 + 1 == X86::R16) &&                      (X86::R31WH + 1 == X86::NUM_TARGET_REGS),                  "Register number may be incorrect"); @@ -694,8 +690,7 @@ bool X86RegisterInfo::isFixedRegister(const MachineFunction &MF,  }  bool X86RegisterInfo::isTileRegisterClass(const TargetRegisterClass *RC) const { -  return RC->getID() == X86::TILERegClassID || -         RC->getID() == X86::TILEPAIRRegClassID; +  return RC->getID() == X86::TILERegClassID;  }  void X86RegisterInfo::adjustStackMapLiveOutMask(uint32_t *Mask) const { @@ -1062,17 +1057,9 @@ static ShapeT getTileShape(Register VirtReg, VirtRegMap *VRM,    case X86::PTDPFP16PSV:    case X86::PTCMMIMFP16PSV:    case X86::PTCMMRLFP16PSV: -  case X86::PTTRANSPOSEDV: -  case X86::PTTDPBF16PSV: -  case X86::PTTDPFP16PSV: -  case X86::PTTCMMIMFP16PSV: -  case X86::PTTCMMRLFP16PSV: -  case X86::PTCONJTCMMIMFP16PSV: -  case X86::PTCONJTFP16V:    case X86::PTILELOADDRSV:    case X86::PTILELOADDRST1V:    case X86::PTMMULTF32PSV: -  case X86::PTTMMULTF32PSV:    case X86::PTDPBF8PSV:    case X86::PTDPBHF8PSV:    case X86::PTDPHBF8PSV: @@ -1083,56 +1070,7 @@ static ShapeT getTileShape(Register VirtReg, VirtRegMap *VRM,      VRM->assignVirt2Shape(VirtReg, Shape);      return Shape;    } -  case X86::PT2RPNTLVWZ0V: -  case X86::PT2RPNTLVWZ0T1V: -  case X86::PT2RPNTLVWZ1V: -  case X86::PT2RPNTLVWZ1T1V: -  case X86::PT2RPNTLVWZ0RSV: -  case X86::PT2RPNTLVWZ0RST1V: -  case X86::PT2RPNTLVWZ1RSV: -  case X86::PT2RPNTLVWZ1RST1V: { -    MachineOperand &MO1 = MI->getOperand(1); -    MachineOperand &MO2 = MI->getOperand(2); -    MachineOperand &MO3 = MI->getOperand(3); -    ShapeT Shape({&MO1, &MO2, &MO1, &MO3}, MRI); -    VRM->assignVirt2Shape(VirtReg, Shape); -    return Shape; -  } -  } -} - -static bool canHintShape(ShapeT &PhysShape, ShapeT &VirtShape) { -  unsigned PhysShapeNum = PhysShape.getShapeNum(); -  unsigned VirtShapeNum = VirtShape.getShapeNum(); - -  if (PhysShapeNum < VirtShapeNum) -    return false; - -  if (PhysShapeNum == VirtShapeNum) { -    if (PhysShapeNum == 1) -      return PhysShape == VirtShape; - -    for (unsigned I = 0; I < PhysShapeNum; I++) { -      ShapeT PShape(PhysShape.getRow(I), PhysShape.getCol(I)); -      ShapeT VShape(VirtShape.getRow(I), VirtShape.getCol(I)); -      if (VShape != PShape) -        return false; -    } -    return true; -  } - -  // Hint subreg of mult-tile reg to single tile reg. -  if (VirtShapeNum == 1) { -    for (unsigned I = 0; I < PhysShapeNum; I++) { -      ShapeT PShape(PhysShape.getRow(I), PhysShape.getCol(I)); -      if (VirtShape == PShape) -        return true; -    }    } - -  // Note: Currently we have no requirement for case of -  // (VirtShapeNum > 1 and PhysShapeNum > VirtShapeNum) -  return false;  }  bool X86RegisterInfo::getRegAllocationHints(Register VirtReg, @@ -1153,7 +1091,7 @@ bool X86RegisterInfo::getRegAllocationHints(Register VirtReg,    if (!VRM)      return BaseImplRetVal; -  if (ID != X86::TILERegClassID && ID != X86::TILEPAIRRegClassID) { +  if (ID != X86::TILERegClassID) {      if (DisableRegAllocNDDHints || !ST.hasNDD() ||          !TRI.isGeneralPurposeRegisterClass(&RC))        return BaseImplRetVal; @@ -1204,7 +1142,7 @@ bool X86RegisterInfo::getRegAllocationHints(Register VirtReg,        return;      }      ShapeT PhysShape = getTileShape(VReg, const_cast<VirtRegMap *>(VRM), MRI); -    if (canHintShape(PhysShape, VirtShape)) +    if (PhysShape == VirtShape)        Hints.push_back(PhysReg);    }; diff --git a/llvm/lib/Target/X86/X86RegisterInfo.td b/llvm/lib/Target/X86/X86RegisterInfo.td index 99b7910..692e42a 100644 --- a/llvm/lib/Target/X86/X86RegisterInfo.td +++ b/llvm/lib/Target/X86/X86RegisterInfo.td @@ -30,8 +30,6 @@ let Namespace = "X86" in {    def sub_ymm      : SubRegIndex<256>;    def sub_mask_0   : SubRegIndex<-1>;    def sub_mask_1   : SubRegIndex<-1, -1>; -  def sub_t0       : SubRegIndex<8192>; -  def sub_t1       : SubRegIndex<8192, 8192>;  }  //===----------------------------------------------------------------------===// @@ -432,10 +430,6 @@ def TMM4:  X86Reg<"tmm4",   4>;  def TMM5:  X86Reg<"tmm5",   5>;  def TMM6:  X86Reg<"tmm6",   6>;  def TMM7:  X86Reg<"tmm7",   7>; -// TMM register pairs -def TPAIRS : RegisterTuples<[sub_t0, sub_t1], -                            [(add TMM0, TMM2, TMM4, TMM6), -                             (add TMM1, TMM3, TMM5, TMM7)]>;  }  // Floating point stack registers. These don't map one-to-one to the FP @@ -862,9 +856,6 @@ def VK64WM  : RegisterClass<"X86", [v64i1], 64, (add VK32WM)> {let Size = 64;}  let CopyCost = -1 in // Don't allow copying of tile registers  def TILE : RegisterClass<"X86", [x86amx], 8192,                           (sequence "TMM%u", 0, 7)> {let Size = 8192;} -// Need check alignment 3rd operand size=1024*2*8 -let isAllocatable = 1 in -def TILEPAIR : RegisterClass<"X86", [untyped], 512, (add TPAIRS)> {let Size = 16384;}  //===----------------------------------------------------------------------===//  // Register categories. diff --git a/llvm/lib/Target/X86/X86TileConfig.cpp b/llvm/lib/Target/X86/X86TileConfig.cpp index 17a44dd..09ef8fb 100644 --- a/llvm/lib/Target/X86/X86TileConfig.cpp +++ b/llvm/lib/Target/X86/X86TileConfig.cpp @@ -74,63 +74,6 @@ INITIALIZE_PASS_DEPENDENCY(VirtRegMapWrapperLegacy)  INITIALIZE_PASS_END(X86TileConfig, DEBUG_TYPE, "Tile Register Configure", false,                      false) -unsigned getAMXRegNum(MachineRegisterInfo *MRI, Register Reg) { -  if (Reg.isVirtual()) { -    unsigned RegClassID = MRI->getRegClass(Reg)->getID(); -    if (RegClassID == X86::TILERegClassID) -      return 1; -    if (RegClassID == X86::TILEPAIRRegClassID) -      return 2; -  } else { -    if (Reg >= X86::TMM0 && Reg <= X86::TMM7) -      return 1; -    if (Reg >= X86::TMM0_TMM1 && Reg <= X86::TMM6_TMM7) -      return 2; -  } -  return 0; -} - -static void collectVirtRegShapes(MachineRegisterInfo *MRI, VirtRegMap &VRM, -                                 Register VirtReg, -                                 SmallVector<ShapeT, 8> &Phys2Shapes) { -  unsigned Num = getAMXRegNum(MRI, VirtReg); -  MCRegister PhysReg = VRM.getPhys(VirtReg); -  if (!PhysReg) -    return; - -  if (Num == 1) { -    unsigned Index = PhysReg - X86::TMM0; -    if (!Phys2Shapes[Index].isValid()) { -      ShapeT Shape = VRM.getShape(VirtReg); -      Phys2Shapes[Index] = std::move(Shape); -      return; -    } -  } -  // Split tile pair shape info to 2 single tile shape info. e.g: -  // Put TMM0_TMM1's Shape to TMM0's shape + TMM1's Shape in Phys2Shapes. -  if (Num == 2) { -    unsigned Index0 = (PhysReg - X86::TMM0_TMM1) * 2; -    unsigned Index1 = (PhysReg - X86::TMM0_TMM1) * 2 + 1; - -    ShapeT Shape = VRM.getShape(VirtReg); -    assert(Shape.getShapeNum() == 2 && "Unexpected shape number!"); - -    if (!Phys2Shapes[Index0].isValid()) { -      ShapeT Shape0(Shape.getRow(0), Shape.getCol(0), MRI); -      Phys2Shapes[Index0] = std::move(Shape0); -    } - -    if (!Phys2Shapes[Index1].isValid()) { -      ShapeT Shape1(Shape.getRow(1), Shape.getCol(1), MRI); -      Phys2Shapes[Index1] = std::move(Shape1); -    } -  } -} - -static bool isAMXRegClass(MachineRegisterInfo *MRI, Register Reg) { -  return getAMXRegNum(MRI, Reg) > 0; -} -  bool X86TileConfig::runOnMachineFunction(MachineFunction &MF) {    X86MachineFunctionInfo *X86FI = MF.getInfo<X86MachineFunctionInfo>();    // Early exit in the common case of non-AMX code. @@ -138,7 +81,7 @@ bool X86TileConfig::runOnMachineFunction(MachineFunction &MF) {      return false;    const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>(); -  const TargetRegisterInfo *TRI = ST.getRegisterInfo(); +  const X86RegisterInfo *TRI = ST.getRegisterInfo();    const TargetInstrInfo *TII = ST.getInstrInfo();    MachineRegisterInfo &MRI = MF.getRegInfo();    LiveIntervals &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS(); @@ -176,24 +119,29 @@ bool X86TileConfig::runOnMachineFunction(MachineFunction &MF) {    assert(ConstMI && "Cannot find an insertion point");    unsigned AMXRegNum = TRI->getRegClass(X86::TILERegClassID)->getNumRegs(); -  SmallVector<ShapeT, 8> Phys2Shapes(AMXRegNum, ShapeT()); +  SmallVector<Register, 8> Phys2Virt(AMXRegNum, 0);    for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {      Register VirtReg = Register::index2VirtReg(I);      if (MRI.reg_nodbg_empty(VirtReg))        continue; -    if (!isAMXRegClass(&MRI, VirtReg)) +    if (!TRI->isTileRegisterClass(MRI.getRegClass(VirtReg))) +      continue; +    MCRegister PhysReg = VRM.getPhys(VirtReg); +    if (!PhysReg)        continue; -    collectVirtRegShapes(&MRI, VRM, VirtReg, Phys2Shapes); +    unsigned Index = PhysReg - X86::TMM0; +    if (!Phys2Virt[Index]) +      Phys2Virt[Index] = VirtReg;    }    // Fill in the shape of each tile physical register.    for (unsigned I = 0; I < AMXRegNum; ++I) { -    ShapeT Shape = Phys2Shapes[I]; -    if (!Shape.isValid()) +    if (!Phys2Virt[I])        continue;      DebugLoc DL;      bool IsRow = true;      MachineInstr *NewMI = nullptr; +    ShapeT Shape = VRM.getShape(Phys2Virt[I]);      for (auto &R : {Shape.getRow()->getReg(), Shape.getCol()->getReg()}) {        // Here is the data format for the tile config.        // 0      palette @@ -222,14 +170,7 @@ bool X86TileConfig::runOnMachineFunction(MachineFunction &MF) {                     "Cannot initialize with different shapes");              continue;            } -          if (DefMI.getOperand(1).isImm()) { -            Imm = DefMI.getOperand(1).getImm(); -          } else { -            assert(DefMI.getOpcode() == X86::MOV32r0 && -                   "The opcode is assumed to be MOV32r0 if the operand is not " -                   "immediate."); -            Imm = 0; -          } +          Imm = DefMI.getOperand(1).getImm();            NewMI = addFrameReference(                        BuildMI(MF.front(), ++ConstMI->getIterator(), DL, | 
