Diffstat (limited to 'llvm/lib/Target')
49 files changed, 463 insertions, 1420 deletions
| diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index fede586..47c1ac4 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -1032,6 +1032,13 @@ AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,      }      break;    } +  case Intrinsic::experimental_vector_extract_last_active: +    if (ST->isSVEorStreamingSVEAvailable()) { +      auto [LegalCost, _] = getTypeLegalizationCost(ICA.getArgTypes()[0]); +      // This should turn into chained clastb instructions. +      return LegalCost; +    } +    break;    default:      break;    } diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalize.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalize.cpp index e187959..907f830 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalize.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalize.cpp @@ -24,6 +24,7 @@  #include "llvm/CodeGen/GlobalISel/CSEInfo.h"  #include "llvm/CodeGen/GlobalISel/CSEMIRBuilder.h"  #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h" +#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"  #include "llvm/CodeGen/GlobalISel/Utils.h"  #include "llvm/CodeGen/MachineFunctionPass.h"  #include "llvm/CodeGen/MachineUniformityAnalysis.h" @@ -34,9 +35,17 @@  using namespace llvm;  using namespace AMDGPU; +using namespace llvm::MIPatternMatch;  namespace { +// AMDGPU-specific pattern matchers +template <typename SrcTy> +inline UnaryOp_match<SrcTy, AMDGPU::G_AMDGPU_READANYLANE> +m_GAMDGPUReadAnyLane(const SrcTy &Src) { +  return UnaryOp_match<SrcTy, AMDGPU::G_AMDGPU_READANYLANE>(Src); +} +  class AMDGPURegBankLegalize : public MachineFunctionPass {  public:    static char ID; @@ -160,10 +169,18 @@ AMDGPURegBankLegalizeCombiner::tryMatchRALFromUnmerge(Register Src) {  Register AMDGPURegBankLegalizeCombiner::getReadAnyLaneSrc(Register Src) {    // Src = G_AMDGPU_READANYLANE RALSrc -  auto [RAL, RALSrc] = tryMatch(Src, AMDGPU::G_AMDGPU_READANYLANE); -  if (RAL) +  Register RALSrc; +  if (mi_match(Src, MRI, m_GAMDGPUReadAnyLane(m_Reg(RALSrc))))      return RALSrc; +  // TruncSrc = G_AMDGPU_READANYLANE RALSrc +  // AextSrc = G_TRUNC TruncSrc +  // Src = G_ANYEXT AextSrc +  if (mi_match(Src, MRI, +               m_GAnyExt(m_GTrunc(m_GAMDGPUReadAnyLane(m_Reg(RALSrc)))))) { +    return RALSrc; +  } +    // LoVgpr, HiVgpr = G_UNMERGE_VALUES UnmergeSrc    // LoSgpr = G_AMDGPU_READANYLANE LoVgpr    // HiSgpr = G_AMDGPU_READANYLANE HiVgpr diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp index b84c30e..dc8fa7f 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp @@ -626,6 +626,23 @@ void RegBankLegalizeHelper::lowerSplitTo32(MachineInstr &MI) {    MI.eraseFromParent();  } +void RegBankLegalizeHelper::lowerSplitTo16(MachineInstr &MI) { +  Register Dst = MI.getOperand(0).getReg(); +  assert(MRI.getType(Dst) == V2S16); +  auto [Op1Lo32, Op1Hi32] = unpackAExt(MI.getOperand(1).getReg()); +  auto [Op2Lo32, Op2Hi32] = unpackAExt(MI.getOperand(2).getReg()); +  unsigned Opc = MI.getOpcode(); +  auto Flags = MI.getFlags(); +  auto Op1Lo = B.buildTrunc(SgprRB_S16, Op1Lo32); +  auto Op1Hi = B.buildTrunc(SgprRB_S16, Op1Hi32); +  auto Op2Lo = B.buildTrunc(SgprRB_S16, Op2Lo32); +  auto Op2Hi = B.buildTrunc(SgprRB_S16, Op2Hi32); +  auto Lo = B.buildInstr(Opc, {SgprRB_S16}, {Op1Lo, Op2Lo}, 
Flags); +  auto Hi = B.buildInstr(Opc, {SgprRB_S16}, {Op1Hi, Op2Hi}, Flags); +  B.buildMergeLikeInstr(Dst, {Lo, Hi}); +  MI.eraseFromParent(); +} +  void RegBankLegalizeHelper::lowerSplitTo32Select(MachineInstr &MI) {    Register Dst = MI.getOperand(0).getReg();    LLT DstTy = MRI.getType(Dst); @@ -698,6 +715,8 @@ void RegBankLegalizeHelper::lower(MachineInstr &MI,      return lowerUnpackBitShift(MI);    case UnpackMinMax:      return lowerUnpackMinMax(MI); +  case ScalarizeToS16: +    return lowerSplitTo16(MI);    case Ext32To64: {      const RegisterBank *RB = MRI.getRegBank(MI.getOperand(0).getReg());      MachineInstrBuilder Hi; @@ -849,6 +868,7 @@ LLT RegBankLegalizeHelper::getTyFromID(RegBankLLTMappingApplyID ID) {      return LLT::scalar(32);    case Sgpr64:    case Vgpr64: +  case UniInVgprS64:      return LLT::scalar(64);    case Sgpr128:    case Vgpr128: @@ -972,6 +992,7 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {    case UniInVcc:    case UniInVgprS16:    case UniInVgprS32: +  case UniInVgprS64:    case UniInVgprV2S16:    case UniInVgprV4S32:    case UniInVgprB32: @@ -1104,6 +1125,7 @@ void RegBankLegalizeHelper::applyMappingDst(        break;      }      case UniInVgprS32: +    case UniInVgprS64:      case UniInVgprV2S16:      case UniInVgprV4S32: {        assert(Ty == getTyFromID(MethodIDs[OpIdx])); diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.h b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.h index ad3ff1d..e7598f8 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.h +++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.h @@ -72,6 +72,7 @@ class RegBankLegalizeHelper {    static constexpr LLT P6 = LLT::pointer(6, 32);    MachineRegisterInfo::VRegAttrs SgprRB_S32 = {SgprRB, S32}; +  MachineRegisterInfo::VRegAttrs SgprRB_S16 = {SgprRB, S16};    MachineRegisterInfo::VRegAttrs VgprRB_S32 = {VgprRB, S32};    MachineRegisterInfo::VRegAttrs VccRB_S1 = {VccRB, S1}; @@ -121,6 +122,7 @@ private:    void lowerV_BFE(MachineInstr &MI);    void lowerS_BFE(MachineInstr &MI);    void lowerSplitTo32(MachineInstr &MI); +  void lowerSplitTo16(MachineInstr &MI);    void lowerSplitTo32Select(MachineInstr &MI);    void lowerSplitTo32SExtInReg(MachineInstr &MI);    void lowerUnpackMinMax(MachineInstr &MI); diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp index 01abd35..b22e9bd 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp @@ -918,9 +918,20 @@ RegBankLegalizeRules::RegBankLegalizeRules(const GCNSubtarget &_ST,    bool hasSALUFloat = ST->hasSALUFloatInsts();    addRulesForGOpcs({G_FADD}, Standard) +      .Uni(S16, {{UniInVgprS16}, {Vgpr16, Vgpr16}}, !hasSALUFloat) +      .Uni(S16, {{Sgpr16}, {Sgpr16, Sgpr16}}, hasSALUFloat) +      .Div(S16, {{Vgpr16}, {Vgpr16, Vgpr16}})        .Uni(S32, {{Sgpr32}, {Sgpr32, Sgpr32}}, hasSALUFloat)        .Uni(S32, {{UniInVgprS32}, {Vgpr32, Vgpr32}}, !hasSALUFloat) -      .Div(S32, {{Vgpr32}, {Vgpr32, Vgpr32}}); +      .Div(S32, {{Vgpr32}, {Vgpr32, Vgpr32}}) +      .Uni(S64, {{UniInVgprS64}, {Vgpr64, Vgpr64}}) +      .Div(S64, {{Vgpr64}, {Vgpr64, Vgpr64}}) +      .Uni(V2S16, {{UniInVgprV2S16}, {VgprV2S16, VgprV2S16}}, !hasSALUFloat) +      .Uni(V2S16, {{SgprV2S16}, {SgprV2S16, SgprV2S16}, ScalarizeToS16}, +           hasSALUFloat) +      .Div(V2S16, {{VgprV2S16}, {VgprV2S16, VgprV2S16}}) +      .Any({{UniV2S32}, {{UniInVgprV2S32}, {VgprV2S32, 
VgprV2S32}}}) +      .Any({{DivV2S32}, {{VgprV2S32}, {VgprV2S32, VgprV2S32}}});    addRulesForGOpcs({G_FPTOUI})        .Any({{UniS32, S32}, {{Sgpr32}, {Sgpr32}}}, hasSALUFloat) diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h index 030bd75..e6df5d8 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h +++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h @@ -92,8 +92,10 @@ enum UniformityLLTOpPredicateID {    V4S32,    UniV2S16, +  UniV2S32,    DivV2S16, +  DivV2S32,    // B types    B32, @@ -178,7 +180,9 @@ enum RegBankLLTMappingApplyID {    UniInVcc,    UniInVgprS16,    UniInVgprS32, +  UniInVgprS64,    UniInVgprV2S16, +  UniInVgprV2S32,    UniInVgprV4S32,    UniInVgprB32,    UniInVgprB64, @@ -217,6 +221,7 @@ enum LoweringMethodID {    V_BFE,    VgprToVccCopy,    SplitTo32, +  ScalarizeToS16,    SplitTo32Select,    SplitTo32SExtInReg,    Ext32To64, diff --git a/llvm/lib/Target/Hexagon/HexagonPatternsHVX.td b/llvm/lib/Target/Hexagon/HexagonPatternsHVX.td index 1637b91..d19920c 100644 --- a/llvm/lib/Target/Hexagon/HexagonPatternsHVX.td +++ b/llvm/lib/Target/Hexagon/HexagonPatternsHVX.td @@ -612,6 +612,9 @@ let Predicates = [UseHVX] in {             (V6_vandvrt HvxVR:$Vs, (ToI32 0x01010101))>;    def: Pat<(VecQ32 (trunc HVI32:$Vs)),             (V6_vandvrt HvxVR:$Vs, (ToI32 0x01010101))>; +  def: Pat<(VecQ16 (trunc HWI32:$Vss)), +           (Combineq(VecQ32(V6_vandvrt (HiVec $Vss), (ToI32 0x01010101))), +           (VecQ32 (V6_vandvrt (LoVec $Vss), (ToI32 0x01010101))))>;  }  let Predicates = [UseHVX] in { diff --git a/llvm/lib/Target/LoongArch/LoongArchFloat32InstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchFloat32InstrInfo.td index 690dd73..e86b21c 100644 --- a/llvm/lib/Target/LoongArch/LoongArchFloat32InstrInfo.td +++ b/llvm/lib/Target/LoongArch/LoongArchFloat32InstrInfo.td @@ -365,6 +365,7 @@ def : Pat<(f32 (uint_to_fp (i64 (sexti32 (i64 GPR:$src))))),  // FP Rounding  let Predicates = [HasBasicF, IsLA64] in {  def : PatFpr<frint, FRINT_S, FPR32>; +def : PatFpr<flog2, FLOGB_S, FPR32>;  } // Predicates = [HasBasicF, IsLA64]  let Predicates = [HasBasicF, IsLA32] in { diff --git a/llvm/lib/Target/LoongArch/LoongArchFloat64InstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchFloat64InstrInfo.td index daefbaa..2e88254 100644 --- a/llvm/lib/Target/LoongArch/LoongArchFloat64InstrInfo.td +++ b/llvm/lib/Target/LoongArch/LoongArchFloat64InstrInfo.td @@ -348,6 +348,7 @@ def : Pat<(bitconvert FPR64:$src), (MOVFR2GR_D FPR64:$src)>;  // FP Rounding  let Predicates = [HasBasicD, IsLA64] in {  def : PatFpr<frint, FRINT_D, FPR64>; +def : PatFpr<flog2, FLOGB_D, FPR64>;  } // Predicates = [HasBasicD, IsLA64]  /// Pseudo-instructions needed for the soft-float ABI with LA32D diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp index 80c96c6..a6de839 100644 --- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp +++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp @@ -244,8 +244,10 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,      setOperationAction(ISD::FP_TO_BF16, MVT::f32,                         Subtarget.isSoftFPABI() ? 
LibCall : Custom); -    if (Subtarget.is64Bit()) +    if (Subtarget.is64Bit()) {        setOperationAction(ISD::FRINT, MVT::f32, Legal); +      setOperationAction(ISD::FLOG2, MVT::f32, Legal); +    }      if (!Subtarget.hasBasicD()) {        setOperationAction(ISD::FP_TO_UINT, MVT::i32, Custom); @@ -291,8 +293,10 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,      setOperationAction(ISD::FP_TO_BF16, MVT::f64,                         Subtarget.isSoftFPABI() ? LibCall : Custom); -    if (Subtarget.is64Bit()) +    if (Subtarget.is64Bit()) {        setOperationAction(ISD::FRINT, MVT::f64, Legal); +      setOperationAction(ISD::FLOG2, MVT::f64, Legal); +    }    }    // Set operations for 'LSX' feature. @@ -362,6 +366,7 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,        setOperationAction(ISD::FMA, VT, Legal);        setOperationAction(ISD::FSQRT, VT, Legal);        setOperationAction(ISD::FNEG, VT, Legal); +      setOperationAction(ISD::FLOG2, VT, Legal);        setCondCodeAction({ISD::SETGE, ISD::SETGT, ISD::SETOGE, ISD::SETOGT,                           ISD::SETUGE, ISD::SETUGT},                          VT, Expand); @@ -443,6 +448,7 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,        setOperationAction(ISD::FMA, VT, Legal);        setOperationAction(ISD::FSQRT, VT, Legal);        setOperationAction(ISD::FNEG, VT, Legal); +      setOperationAction(ISD::FLOG2, VT, Legal);        setCondCodeAction({ISD::SETGE, ISD::SETGT, ISD::SETOGE, ISD::SETOGT,                           ISD::SETUGE, ISD::SETUGT},                          VT, Expand); diff --git a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td index 613dea6..ca4ee5f 100644 --- a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td +++ b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td @@ -1593,6 +1593,9 @@ def : Pat<(fma_nsz (fneg v4f64:$xj), v4f64:$xk, v4f64:$xa),  // XVFSQRT_{S/D}  defm : PatXrF<fsqrt, "XVFSQRT">; +// XVFLOGB_{S/D} +defm : PatXrF<flog2, "XVFLOGB">; +  // XVRECIP_{S/D}  def : Pat<(fdiv vsplatf32_fpimm_eq_1, v8f32:$xj),            (XVFRECIP_S v8f32:$xj)>; @@ -2024,6 +2027,24 @@ def : Pat<(v4i32(fp_to_uint v4f64:$vj)),                 (XVFTINTRZ_LU_D v4f64:$vj)),                sub_128)>; +// XVAVG_{B/H/W/D/BU/HU/WU/DU}, XVAVGR_{B/H/W/D/BU/HU/WU/DU} +defm : VAvgPat<sra, "XVAVG_B", v32i8>; +defm : VAvgPat<sra, "XVAVG_H", v16i16>; +defm : VAvgPat<sra, "XVAVG_W", v8i32>; +defm : VAvgPat<sra, "XVAVG_D", v4i64>; +defm : VAvgPat<srl, "XVAVG_BU", v32i8>; +defm : VAvgPat<srl, "XVAVG_HU", v16i16>; +defm : VAvgPat<srl, "XVAVG_WU", v8i32>; +defm : VAvgPat<srl, "XVAVG_DU", v4i64>; +defm : VAvgrPat<sra, "XVAVGR_B", v32i8>; +defm : VAvgrPat<sra, "XVAVGR_H", v16i16>; +defm : VAvgrPat<sra, "XVAVGR_W", v8i32>; +defm : VAvgrPat<sra, "XVAVGR_D", v4i64>; +defm : VAvgrPat<srl, "XVAVGR_BU", v32i8>; +defm : VAvgrPat<srl, "XVAVGR_HU", v16i16>; +defm : VAvgrPat<srl, "XVAVGR_WU", v8i32>; +defm : VAvgrPat<srl, "XVAVGR_DU", v4i64>; +  // abs  def : Pat<(abs v32i8:$xj), (XVSIGNCOV_B v32i8:$xj, v32i8:$xj)>;  def : Pat<(abs v16i16:$xj), (XVSIGNCOV_H v16i16:$xj, v16i16:$xj)>; diff --git a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td index 4619c6b..92402ba 100644 --- a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td +++ b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td @@ -1518,6 +1518,18 @@ multiclass InsertExtractPatV2<ValueType vecty, 
ValueType elemty> {    }  } +multiclass VAvgPat<SDPatternOperator OpNode, string Inst, ValueType vt> { +  def : Pat<(OpNode (vt (add vt:$vj, vt:$vk)), (vt (vsplat_imm_eq_1))), +            (!cast<LAInst>(Inst) vt:$vj, vt:$vk)>; +} + +multiclass VAvgrPat<SDPatternOperator OpNode, string Inst, ValueType vt> { +  def : Pat<(OpNode (vt (add (vt (add vt:$vj, vt:$vk)), +                             (vt (vsplat_imm_eq_1)))), +                    (vt (vsplat_imm_eq_1))), +            (!cast<LAInst>(Inst) vt:$vj, vt:$vk)>; +} +  let Predicates = [HasExtLSX] in {  // VADD_{B/H/W/D} @@ -1783,6 +1795,9 @@ def : Pat<(fma_nsz (fneg v2f64:$vj), v2f64:$vk, v2f64:$va),  // VFSQRT_{S/D}  defm : PatVrF<fsqrt, "VFSQRT">; +// VFLOGB_{S/D} +defm : PatVrF<flog2, "VFLOGB">; +  // VFRECIP_{S/D}  def : Pat<(fdiv vsplatf32_fpimm_eq_1, v4f32:$vj),            (VFRECIP_S v4f32:$vj)>; @@ -2154,6 +2169,24 @@ def : Pat<(f32 f32imm_vldi:$in),  def : Pat<(f64 f64imm_vldi:$in),            (f64 (EXTRACT_SUBREG (VLDI (to_f64imm_vldi f64imm_vldi:$in)), sub_64))>; +// VAVG_{B/H/W/D/BU/HU/WU/DU}, VAVGR_{B/H/W/D/BU/HU/WU/DU} +defm : VAvgPat<sra, "VAVG_B", v16i8>; +defm : VAvgPat<sra, "VAVG_H", v8i16>; +defm : VAvgPat<sra, "VAVG_W", v4i32>; +defm : VAvgPat<sra, "VAVG_D", v2i64>; +defm : VAvgPat<srl, "VAVG_BU", v16i8>; +defm : VAvgPat<srl, "VAVG_HU", v8i16>; +defm : VAvgPat<srl, "VAVG_WU", v4i32>; +defm : VAvgPat<srl, "VAVG_DU", v2i64>; +defm : VAvgrPat<sra, "VAVGR_B", v16i8>; +defm : VAvgrPat<sra, "VAVGR_H", v8i16>; +defm : VAvgrPat<sra, "VAVGR_W", v4i32>; +defm : VAvgrPat<sra, "VAVGR_D", v2i64>; +defm : VAvgrPat<srl, "VAVGR_BU", v16i8>; +defm : VAvgrPat<srl, "VAVGR_HU", v8i16>; +defm : VAvgrPat<srl, "VAVGR_WU", v4i32>; +defm : VAvgrPat<srl, "VAVGR_DU", v2i64>; +  // abs  def : Pat<(abs v16i8:$vj), (VSIGNCOV_B v16i8:$vj, v16i8:$vj)>;  def : Pat<(abs v8i16:$vj), (VSIGNCOV_H v8i16:$vj, v8i16:$vj)>; diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp index 7e7ee75..c667a09 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp @@ -1871,17 +1871,6 @@ bool NVPTXScopes::empty() const { return Scopes.size() == 0; }    (is_ch ? 
(CP_ASYNC_BULK_TENSOR_OPCODE(RED, dim, mode, is_s32, _CH))          \           : (CP_ASYNC_BULK_TENSOR_OPCODE(RED, dim, mode, is_s32, ))) -#define GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(dim, mode, is_mc, is_ch, is_s32)   \ -  [&]() -> auto {                                                              \ -    if (is_mc && is_ch)                                                        \ -      return CP_ASYNC_BULK_TENSOR_OPCODE(G2S, dim, mode, is_s32, _MC_CH);      \ -    if (is_ch)                                                                 \ -      return CP_ASYNC_BULK_TENSOR_OPCODE(G2S, dim, mode, is_s32, _CH);         \ -    if (is_mc)                                                                 \ -      return CP_ASYNC_BULK_TENSOR_OPCODE(G2S, dim, mode, is_s32, _MC);         \ -    return CP_ASYNC_BULK_TENSOR_OPCODE(G2S, dim, mode, is_s32, );              \ -  }() -  static unsigned GetCpAsyncBulkTensorS2GReductionOpcode(size_t Dim,                                                         bool IsShared32,                                                         bool IsCacheHint, @@ -1925,112 +1914,6 @@ static unsigned GetCpAsyncBulkTensorS2GReductionOpcode(size_t Dim,    }  } -static unsigned GetCpAsyncBulkTensorG2SOpcode(size_t Dim, bool IsShared32, -                                              bool IsMultiCast, -                                              bool IsCacheHint, bool IsIm2Col) { -  if (IsIm2Col) { -    switch (Dim) { -    case 3: -      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(3D, IM2COL, IsMultiCast, -                                                 IsCacheHint, IsShared32); -    case 4: -      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(4D, IM2COL, IsMultiCast, -                                                 IsCacheHint, IsShared32); -    case 5: -      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(5D, IM2COL, IsMultiCast, -                                                 IsCacheHint, IsShared32); -    default: -      llvm_unreachable("Invalid Dimension in im2col mode for " -                       "GetCpAsyncBulkTensorG2SOpcode."); -    } -  } else { -    switch (Dim) { -    case 1: -      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(1D, TILE, IsMultiCast, -                                                 IsCacheHint, IsShared32); -    case 2: -      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(2D, TILE, IsMultiCast, -                                                 IsCacheHint, IsShared32); -    case 3: -      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(3D, TILE, IsMultiCast, -                                                 IsCacheHint, IsShared32); -    case 4: -      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(4D, TILE, IsMultiCast, -                                                 IsCacheHint, IsShared32); -    case 5: -      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(5D, TILE, IsMultiCast, -                                                 IsCacheHint, IsShared32); -    default: -      llvm_unreachable( -          "Invalid Dimension in tile mode for GetCpAsyncBulkTensorG2SOpcode."); -    } -  } -} - -static size_t GetDimsFromIntrinsic(unsigned IID) { -  switch (IID) { -  case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d: -  case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d: -    return 3; -  case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d: -  case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d: -    return 4; -  case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d: -  case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d: -  
  return 5; -  default: -    llvm_unreachable("Invalid im2col intrinsic in GetDimsFromIntrinsic."); -  } -} - -void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorG2SCommon(SDNode *N, -                                                         bool IsIm2Col) { -  // We have {Chain, Intrinsic-ID} followed by the actual intrisic args: -  // {dst, mbar, src, dims{d0...dN}, im2col_offsets{dims-2} -  // multicast, cache_hint, -  // multicast_flag, cache_hint_flag, cta_group_flag} -  // NumOperands = {Chain, IID} + {Actual intrinsic args} -  //             = {2}          + {8 + dims + im2col_offsets} -  size_t NumOps = N->getNumOperands(); -  size_t NumDims = IsIm2Col ? GetDimsFromIntrinsic(N->getConstantOperandVal(1)) -                            : (NumOps - 10); -  // Offsets is always 'NumDims - 2' and only for im2col mode -  size_t NumOffsets = IsIm2Col ? (NumDims - 2) : 0; -  bool IsCacheHint = N->getConstantOperandVal(NumOps - 2) == 1; -  bool IsMultiCast = N->getConstantOperandVal(NumOps - 3) == 1; -  size_t NumBaseArgs = NumDims + NumOffsets + 3; // for {dst, mbar, src} -  size_t MultiCastIdx = NumBaseArgs + 2;         // for Chain and IID - -  unsigned CTAGroupVal = N->getConstantOperandVal(NumOps - 1); -  if ((CTAGroupVal > 0) && !Subtarget->hasCpAsyncBulkTensorCTAGroupSupport()) -    report_fatal_error( -        formatv("CpAsyncBulkTensorG2S cta_group::1/2 is not supported on sm_{}", -                Subtarget->getSmVersion())); - -  SDLoc DL(N); -  SmallVector<SDValue, 8> Ops(N->ops().slice(2, NumBaseArgs)); - -  // Push MultiCast operand, if available -  if (IsMultiCast) -    Ops.push_back(N->getOperand(MultiCastIdx)); - -  // Push CacheHint operand, if available -  if (IsCacheHint) -    Ops.push_back(N->getOperand(MultiCastIdx + 1)); - -  // Flag for CTA Group -  Ops.push_back(getI32Imm(CTAGroupVal, DL)); - -  // Finally, the chain operand -  Ops.push_back(N->getOperand(0)); - -  bool IsShared32 = -      CurDAG->getDataLayout().getPointerSizeInBits(ADDRESS_SPACE_SHARED) == 32; -  unsigned Opcode = GetCpAsyncBulkTensorG2SOpcode( -      NumDims, IsShared32, IsMultiCast, IsCacheHint, IsIm2Col); -  ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops)); -} -  void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorReduceCommon(SDNode *N,                                                              unsigned RedOp,                                                              bool IsIm2Col) { @@ -2175,18 +2058,6 @@ bool NVPTXDAGToDAGISel::tryIntrinsicVoid(SDNode *N) {    switch (IID) {    default:      return false; -  case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d: -  case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d: -  case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d: -  case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d: -  case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d: -    SelectCpAsyncBulkTensorG2SCommon(N); -    return true; -  case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d: -  case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d: -  case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d: -    SelectCpAsyncBulkTensorG2SCommon(N, /*IsIm2Col=*/true); -    return true;    case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d:    case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d:    case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d: diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h index c912e70..1cb579b 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h +++ 
b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h @@ -86,7 +86,6 @@ private:    bool tryEXTRACT_VECTOR_ELEMENT(SDNode *N);    void SelectV2I64toI128(SDNode *N);    void SelectI128toV2I64(SDNode *N); -  void SelectCpAsyncBulkTensorG2SCommon(SDNode *N, bool IsIm2Col = false);    void SelectCpAsyncBulkTensorReduceCommon(SDNode *N, unsigned RedOp,                                             bool IsIm2Col = false);    void SelectTcgen05Ld(SDNode *N, bool hasOffset = false); diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index dfde0cc..b260221 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -139,7 +139,6 @@ def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">;  def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;  def hasTcgen05Instructions : Predicate<"Subtarget->hasTcgen05Instructions()">;  def hasTcgen05MMAScaleInputDImm : Predicate<"Subtarget->hasTcgen05MMAScaleInputDImm()">; -def hasTMACTAGroupSupport  : Predicate<"Subtarget->hasCpAsyncBulkTensorCTAGroupSupport()">;  def hasF32x2Instructions : Predicate<"Subtarget->hasF32x2Instructions()">;  class hasPTX<int version>: Predicate<"Subtarget->getPTXVersion() >= " # version>; diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index c923f0e..e8758aa 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -599,75 +599,15 @@ class TMA_IM2COL_UTIL<int dim, string mode> {    string base_str = !interleave(!foreach(i, !range(offsets), "$im2col" # i), ", ");  } -// From Global to Shared memory (G2S) -class G2S_STRINGS<int dim, string mode, bit mc, bit ch, bit is_shared32 = 0> { -  string prefix = "cp.async.bulk.tensor"; -  string dir = "shared::cluster.global"; -  string completion = "mbarrier::complete_tx::bytes"; -  string inst_name = prefix -                     # "." # dim # "d" -                     # "." # dir -                     # "." # mode -                     # "." 
# completion -                     # !if(mc, ".multicast::cluster", "") -                     # !if(ch, ".L2::cache_hint", ""); -  string intr_name = "CP_ASYNC_BULK_TENSOR_G2S_" -                     # dim # "D" -                     # !if(is_shared32, "_SHARED32", "") -                     # !if(!eq(mode, "tile"), "_TILE", "_IM2COL"); -} -  def CTAGroupFlags : Operand<i32> {    let PrintMethod = "printCTAGroup";  } -multiclass CP_ASYNC_BULK_TENSOR_G2S_INTR<int dim, bit is_shared32, string mode> { -  defvar dims_dag = TMA_DIMS_UTIL<dim>.ins_dag; -  defvar dims_str = TMA_DIMS_UTIL<dim>.base_str; -  defvar asm_str_default = "$cg [$dst], [$tmap, {{" # dims_str # "}}], [$mbar]"; -  defvar rc = !if(is_shared32, B32, B64); - -  defvar num_im2col = !if(!ge(dim, 3), !add(dim, -2), 0); -  defvar im2col_dag = !if(!eq(mode, "im2col"), -    !dag(ins, !listsplat(B16, num_im2col), !foreach(i, !range(num_im2col), "im2col" # i)), -    (ins)); -  defvar im2col_str = !interleave(!foreach(i, !range(num_im2col), "$im2col" # i), ", "); -  defvar im2col_asm_str = ", {{" # im2col_str # "}}"; - -  defvar asm_str = !if(!eq(mode, "im2col"), -    !strconcat(asm_str_default, im2col_asm_str), asm_str_default); +def tma_cta_group_imm0 : TImmLeaf<i32, [{return Imm == 0;}]>; +def tma_cta_group_imm_any : TImmLeaf<i32, [{return Imm >= 0;}]>; -  def "" : NVPTXInst<(outs), -            !con((ins rc:$dst, rc:$mbar, B64:$tmap), dims_dag, im2col_dag, (ins CTAGroupFlags:$cg)), -            !strconcat(G2S_STRINGS<dim, mode, 0, 0>.inst_name, asm_str, ";")>, -            Requires<[hasPTX<80>, hasSM<90>]>; -  def _MC : NVPTXInst<(outs), -                  !con((ins rc:$dst, rc:$mbar, B64:$tmap), dims_dag, im2col_dag, -                       (ins B16:$mc, CTAGroupFlags:$cg)), -                  !strconcat(G2S_STRINGS<dim, mode, 1, 0>.inst_name, asm_str, ", $mc;")>, -                  Requires<[hasPTX<80>, hasSM<90>]>; -  def _CH : NVPTXInst<(outs), -                  !con((ins rc:$dst, rc:$mbar, B64:$tmap), dims_dag, im2col_dag, -                       (ins B64:$ch, CTAGroupFlags:$cg)), -                  !strconcat(G2S_STRINGS<dim, mode, 0, 1>.inst_name, asm_str, ", $ch;")>, -                  Requires<[hasPTX<80>, hasSM<90>]>; -  def _MC_CH : NVPTXInst<(outs), -                     !con((ins rc:$dst, rc:$mbar, B64:$tmap), dims_dag, im2col_dag, -                          (ins B16:$mc, B64:$ch, CTAGroupFlags:$cg)), -                     !strconcat(G2S_STRINGS<dim, mode, 1, 1>.inst_name, asm_str, ", $mc, $ch;")>, -                     Requires<[hasPTX<80>, hasSM<90>]>; -} - -foreach dim = [1, 2, 3, 4, 5] in { -  foreach shared32 = [true, false] in { -    foreach mode = !if(!ge(dim, 3), ["tile", "im2col"], ["tile"]) in { -      defm G2S_STRINGS<dim, mode, 0, 0, shared32>.intr_name : -        CP_ASYNC_BULK_TENSOR_G2S_INTR<dim, shared32, mode>; -    } -  } -} - -multiclass TMA_TENSOR_G2S_INTR<int dim, string mode, list<Predicate> pred = []> { +multiclass TMA_TENSOR_G2S_INTR<int dim, string mode, list<Predicate> pred, +                               TImmLeaf cta_group_type = tma_cta_group_imm_any> {    defvar dims_dag = TMA_DIMS_UTIL<dim>.ins_dag;    defvar dims_str = TMA_DIMS_UTIL<dim>.base_str;    defvar asm_str_base = "$cg [$dst], [$tmap, {{" # dims_str # "}}], [$mbar]"; @@ -697,10 +637,10 @@ multiclass TMA_TENSOR_G2S_INTR<int dim, string mode, list<Predicate> pred = []>                           !setdagop(dims_dag, intr),                           !setdagop(im2col_dag, intr),                           (intr B16:$mc, B64:$ch)); -  
defvar intr_dag_no_hints   = !con(intr_dag_base, (intr 0,  0,  timm:$cg)); -  defvar intr_dag_with_mc    = !con(intr_dag_base, (intr -1, 0,  timm:$cg)); -  defvar intr_dag_with_ch    = !con(intr_dag_base, (intr 0, -1,  timm:$cg)); -  defvar intr_dag_with_mc_ch = !con(intr_dag_base, (intr -1, -1, timm:$cg)); +  defvar intr_dag_no_hints   = !con(intr_dag_base, (intr 0,  0,  cta_group_type:$cg)); +  defvar intr_dag_with_mc    = !con(intr_dag_base, (intr -1, 0,  cta_group_type:$cg)); +  defvar intr_dag_with_ch    = !con(intr_dag_base, (intr 0, -1,  cta_group_type:$cg)); +  defvar intr_dag_with_mc_ch = !con(intr_dag_base, (intr -1, -1, cta_group_type:$cg));    def "" : NVPTXInst<(outs), ins_dag,               inst_name # asm_str # ";", @@ -719,14 +659,30 @@ multiclass TMA_TENSOR_G2S_INTR<int dim, string mode, list<Predicate> pred = []>                   [intr_dag_with_mc_ch]>,                   Requires<pred>;  } + +foreach dim = 1...5 in { +  defm TMA_G2S_TILE_CG0_ # dim # "D" +      : TMA_TENSOR_G2S_INTR<dim, "tile", [hasPTX<80>, hasSM<90>], +                            tma_cta_group_imm0>; +  defm TMA_G2S_TILE_ # dim # "D" +      : TMA_TENSOR_G2S_INTR<dim, "tile", +                            [callSubtarget<"hasTMABlackwellSupport">]>; +}  foreach dim = 3...5 in { +  defm TMA_G2S_IM2COL_CG0_ # dim # "D" +      : TMA_TENSOR_G2S_INTR<dim, "im2col", [hasPTX<80>, hasSM<90>], +                            tma_cta_group_imm0>; +  defm TMA_G2S_IM2COL_ # dim # "D" +      : TMA_TENSOR_G2S_INTR<dim, "im2col", +                            [callSubtarget<"hasTMABlackwellSupport">]>;    foreach mode = ["im2col_w", "im2col_w_128"] in {      defm TMA_G2S_ # !toupper(mode) # "_" # dim # "D" -      : TMA_TENSOR_G2S_INTR<dim, mode, [hasTMACTAGroupSupport]>; +        : TMA_TENSOR_G2S_INTR<dim, mode, +                              [callSubtarget<"hasTMABlackwellSupport">]>;    }  }  defm TMA_G2S_TILE_GATHER4_2D : TMA_TENSOR_G2S_INTR<5, "tile_gather4", -                               [hasTMACTAGroupSupport]>; +                               [callSubtarget<"hasTMABlackwellSupport">]>;  multiclass TMA_TENSOR_G2S_CTA_INTR<int dim, string mode, list<Predicate> pred = []> {    defvar dims_dag = TMA_DIMS_UTIL<dim>.ins_dag; @@ -784,7 +740,8 @@ foreach dim = 3...5 in {      : TMA_TENSOR_G2S_CTA_INTR<dim, "im2col_w", [hasPTX<86>, hasSM<100>]>;    defm TMA_G2S_CTA_IM2COL_W_128_ # dim # "D" -    : TMA_TENSOR_G2S_CTA_INTR<dim, "im2col_w_128", [hasTMACTAGroupSupport]>; +    : TMA_TENSOR_G2S_CTA_INTR<dim, "im2col_w_128", +                              [callSubtarget<"hasTMABlackwellSupport">]>;  }  defm TMA_G2S_CTA_TILE_GATHER4_2D : TMA_TENSOR_G2S_CTA_INTR<5, "tile_gather4",                                     [hasPTX<86>, hasSM<100>]>; @@ -835,7 +792,7 @@ foreach dim = 1...5 in {    }  }  defm TMA_S2G_TILE_SCATTER4_2D : TMA_TENSOR_S2G_INTR<5, "tile_scatter4", -                                [hasTMACTAGroupSupport]>; +                                [callSubtarget<"hasTMABlackwellSupport">]>;  def TMAReductionFlags : Operand<i32> {    let PrintMethod = "printTmaReductionMode"; @@ -930,11 +887,11 @@ foreach dim = 3...5 in {    foreach mode = ["im2col_w", "im2col_w_128"] in {      defvar suffix = !toupper(mode) # "_" # dim # "D";      defm TMA_TENSOR_PF_ # suffix : TMA_TENSOR_PREFETCH_INTR<dim, mode, -                                   [hasTMACTAGroupSupport]>; +                                   [callSubtarget<"hasTMABlackwellSupport">]>;    }  }  defm TMA_TENSOR_PF_TILE_GATHER4_2D : TMA_TENSOR_PREFETCH_INTR<5, "tile_gather4", 
-                                     [hasTMACTAGroupSupport]>; +                                     [callSubtarget<"hasTMABlackwellSupport">]>;  //Prefetchu and Prefetch diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h index 194dbdc..021b1f6 100644 --- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h +++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h @@ -166,18 +166,15 @@ public:    // f32x2 instructions in Blackwell family    bool hasF32x2Instructions() const; -  // TMA G2S copy with cta_group::1/2 support -  bool hasCpAsyncBulkTensorCTAGroupSupport() const { -    // TODO: Update/tidy-up after the family-conditional support arrives -    switch (FullSmVersion) { -    case 1003: -    case 1013: -      return PTXVersion >= 86; -    case 1033: -      return PTXVersion >= 88; -    default: -      return false; -    } +  // Checks support for following in TMA: +  //  - cta_group::1/2 support +  //  - im2col_w/w_128 mode support +  //  - tile_gather4 mode support +  //  - tile_scatter4 mode support +  bool hasTMABlackwellSupport() const { +    return hasPTXWithFamilySMs(90, {100, 110}) || +           hasPTXWithFamilySMs(88, {100, 101}) || +           hasPTXWithAccelSMs(86, {100, 101});    }    // Prior to CUDA 12.3 ptxas did not recognize that the trap instruction diff --git a/llvm/lib/Target/PowerPC/PPCTargetMachine.cpp b/llvm/lib/Target/PowerPC/PPCTargetMachine.cpp index 000d296..4ff489d 100644 --- a/llvm/lib/Target/PowerPC/PPCTargetMachine.cpp +++ b/llvm/lib/Target/PowerPC/PPCTargetMachine.cpp @@ -296,8 +296,9 @@ PPCTargetMachine::PPCTargetMachine(const Target &T, const Triple &TT,                                     std::optional<Reloc::Model> RM,                                     std::optional<CodeModel::Model> CM,                                     CodeGenOptLevel OL, bool JIT) -    : CodeGenTargetMachineImpl(T, TT.computeDataLayout(), TT, CPU, -                               computeFSAdditions(FS, OL, TT), Options, +    : CodeGenTargetMachineImpl(T, +                               TT.computeDataLayout(Options.MCOptions.ABIName), +                               TT, CPU, computeFSAdditions(FS, OL, TT), Options,                                 getEffectiveRelocModel(TT, RM),                                 getEffectivePPCCodeModel(TT, CM, JIT), OL),        TLOF(createTLOF(getTargetTriple())), diff --git a/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp b/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp index 8198173..282cf5d 100644 --- a/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp +++ b/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp @@ -92,6 +92,10 @@ private:    void emitFence(AtomicOrdering FenceOrdering, SyncScope::ID FenceSSID,                   MachineIRBuilder &MIB) const;    bool selectUnmergeValues(MachineInstr &MI, MachineIRBuilder &MIB) const; +  void addVectorLoadStoreOperands(MachineInstr &I, +                                  SmallVectorImpl<SrcOp> &SrcOps, +                                  unsigned &CurOp, bool IsMasked, +                                  bool IsStrided) const;    bool selectIntrinsicWithSideEffects(MachineInstr &I,                                        MachineIRBuilder &MIB) const; @@ -716,6 +720,26 @@ static unsigned selectRegImmLoadStoreOp(unsigned GenericOpc, unsigned OpSize) {    return GenericOpc;  } +void RISCVInstructionSelector::addVectorLoadStoreOperands( +    MachineInstr &I, SmallVectorImpl<SrcOp> &SrcOps, unsigned &CurOp, +    bool IsMasked, bool IsStrided) const { +  // 
Base Pointer +  auto PtrReg = I.getOperand(CurOp++).getReg(); +  SrcOps.push_back(PtrReg); + +  // Stride +  if (IsStrided) { +    auto StrideReg = I.getOperand(CurOp++).getReg(); +    SrcOps.push_back(StrideReg); +  } + +  // Mask +  if (IsMasked) { +    auto MaskReg = I.getOperand(CurOp++).getReg(); +    SrcOps.push_back(MaskReg); +  } +} +  bool RISCVInstructionSelector::selectIntrinsicWithSideEffects(      MachineInstr &I, MachineIRBuilder &MIB) const {    // Find the intrinsic ID. @@ -752,21 +776,7 @@ bool RISCVInstructionSelector::selectIntrinsicWithSideEffects(        SrcOps.push_back(Register(RISCV::NoRegister));      } -    // Base Pointer -    auto PtrReg = I.getOperand(CurOp++).getReg(); -    SrcOps.push_back(PtrReg); - -    // Stride -    if (IsStrided) { -      auto StrideReg = I.getOperand(CurOp++).getReg(); -      SrcOps.push_back(StrideReg); -    } - -    // Mask -    if (IsMasked) { -      auto MaskReg = I.getOperand(CurOp++).getReg(); -      SrcOps.push_back(MaskReg); -    } +    addVectorLoadStoreOperands(I, SrcOps, CurOp, IsMasked, IsStrided);      RISCVVType::VLMUL LMUL = RISCVTargetLowering::getLMUL(getMVTForLLT(VT));      const RISCV::VLEPseudo *P = @@ -795,6 +805,48 @@ bool RISCVInstructionSelector::selectIntrinsicWithSideEffects(      I.eraseFromParent();      return constrainSelectedInstRegOperands(*PseudoMI, TII, TRI, RBI);    } +  case Intrinsic::riscv_vsm: +  case Intrinsic::riscv_vse: +  case Intrinsic::riscv_vse_mask: +  case Intrinsic::riscv_vsse: +  case Intrinsic::riscv_vsse_mask: { +    bool IsMasked = IntrinID == Intrinsic::riscv_vse_mask || +                    IntrinID == Intrinsic::riscv_vsse_mask; +    bool IsStrided = IntrinID == Intrinsic::riscv_vsse || +                     IntrinID == Intrinsic::riscv_vsse_mask; +    LLT VT = MRI->getType(I.getOperand(1).getReg()); +    unsigned Log2SEW = Log2_32(VT.getScalarSizeInBits()); + +    // Sources +    unsigned CurOp = 1; +    SmallVector<SrcOp, 4> SrcOps; // Source registers. 
+ +    // Store value +    auto PassthruReg = I.getOperand(CurOp++).getReg(); +    SrcOps.push_back(PassthruReg); + +    addVectorLoadStoreOperands(I, SrcOps, CurOp, IsMasked, IsStrided); + +    RISCVVType::VLMUL LMUL = RISCVTargetLowering::getLMUL(getMVTForLLT(VT)); +    const RISCV::VSEPseudo *P = RISCV::getVSEPseudo( +        IsMasked, IsStrided, Log2SEW, static_cast<unsigned>(LMUL)); + +    auto PseudoMI = MIB.buildInstr(P->Pseudo, {}, SrcOps); + +    // Select VL +    auto VLOpFn = renderVLOp(I.getOperand(CurOp++)); +    for (auto &RenderFn : *VLOpFn) +      RenderFn(PseudoMI); + +    // SEW +    PseudoMI.addImm(Log2SEW); + +    // Memref +    PseudoMI.cloneMemRefs(I); + +    I.eraseFromParent(); +    return constrainSelectedInstRegOperands(*PseudoMI, TII, TRI, RBI); +  }    }  } diff --git a/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp b/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp index 4105618..526675a 100644 --- a/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp +++ b/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp @@ -127,6 +127,10 @@ bool RISCVExpandPseudo::expandMI(MachineBasicBlock &MBB,    case RISCV::PseudoCCAND:    case RISCV::PseudoCCOR:    case RISCV::PseudoCCXOR: +  case RISCV::PseudoCCMAX: +  case RISCV::PseudoCCMAXU: +  case RISCV::PseudoCCMIN: +  case RISCV::PseudoCCMINU:    case RISCV::PseudoCCADDW:    case RISCV::PseudoCCSUBW:    case RISCV::PseudoCCSLL: @@ -217,6 +221,7 @@ bool RISCVExpandPseudo::expandCCOp(MachineBasicBlock &MBB,          .addImm(0);    } else {      unsigned NewOpc; +    // clang-format off      switch (MI.getOpcode()) {      default:        llvm_unreachable("Unexpected opcode!"); @@ -228,6 +233,10 @@ bool RISCVExpandPseudo::expandCCOp(MachineBasicBlock &MBB,      case RISCV::PseudoCCAND:   NewOpc = RISCV::AND;   break;      case RISCV::PseudoCCOR:    NewOpc = RISCV::OR;    break;      case RISCV::PseudoCCXOR:   NewOpc = RISCV::XOR;   break; +    case RISCV::PseudoCCMAX:   NewOpc = RISCV::MAX;   break; +    case RISCV::PseudoCCMIN:   NewOpc = RISCV::MIN;   break; +    case RISCV::PseudoCCMAXU:  NewOpc = RISCV::MAXU;  break; +    case RISCV::PseudoCCMINU:  NewOpc = RISCV::MINU;  break;      case RISCV::PseudoCCADDI:  NewOpc = RISCV::ADDI;  break;      case RISCV::PseudoCCSLLI:  NewOpc = RISCV::SLLI;  break;      case RISCV::PseudoCCSRLI:  NewOpc = RISCV::SRLI;  break; @@ -250,6 +259,7 @@ bool RISCVExpandPseudo::expandCCOp(MachineBasicBlock &MBB,      case RISCV::PseudoCCNDS_BFOS: NewOpc = RISCV::NDS_BFOS; break;      case RISCV::PseudoCCNDS_BFOZ: NewOpc = RISCV::NDS_BFOZ; break;      } +    // clang-format on      if (NewOpc == RISCV::NDS_BFOZ || NewOpc == RISCV::NDS_BFOS) {        BuildMI(TrueBB, DL, TII->get(NewOpc), DestReg) diff --git a/llvm/lib/Target/RISCV/RISCVFeatures.td b/llvm/lib/Target/RISCV/RISCVFeatures.td index b4556f6..cfee6ab 100644 --- a/llvm/lib/Target/RISCV/RISCVFeatures.td +++ b/llvm/lib/Target/RISCV/RISCVFeatures.td @@ -1851,6 +1851,11 @@ def TuneShortForwardBranchOpt  def HasShortForwardBranchOpt : Predicate<"Subtarget->hasShortForwardBranchOpt()">;  def NoShortForwardBranchOpt : Predicate<"!Subtarget->hasShortForwardBranchOpt()">; +def TuneShortForwardBranchIMinMax +    : SubtargetFeature<"short-forward-branch-i-minmax", "HasShortForwardBranchIMinMax", +                       "true", "Enable short forward branch optimization for min,max instructions in Zbb", +                       [TuneShortForwardBranchOpt]>; +  // Some subtargets require a S2V transfer buffer to move scalars into vectors.  
// FIXME: Forming .vx/.vf/.wx/.wf can reduce register pressure.  def TuneNoSinkSplatOperands diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp index 912b82d..c9df787 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp @@ -869,7 +869,7 @@ std::optional<unsigned> getFoldedOpcode(MachineFunction &MF, MachineInstr &MI,    }  } -// This is the version used during inline spilling +// This is the version used during InlineSpiller::spillAroundUses  MachineInstr *RISCVInstrInfo::foldMemoryOperandImpl(      MachineFunction &MF, MachineInstr &MI, ArrayRef<unsigned> Ops,      MachineBasicBlock::iterator InsertPt, int FrameIndex, LiveIntervals *LIS, @@ -1699,6 +1699,10 @@ unsigned getPredicatedOpcode(unsigned Opcode) {    case RISCV::AND:   return RISCV::PseudoCCAND;    case RISCV::OR:    return RISCV::PseudoCCOR;    case RISCV::XOR:   return RISCV::PseudoCCXOR; +  case RISCV::MAX:   return RISCV::PseudoCCMAX; +  case RISCV::MAXU:  return RISCV::PseudoCCMAXU; +  case RISCV::MIN:   return RISCV::PseudoCCMIN; +  case RISCV::MINU:  return RISCV::PseudoCCMINU;    case RISCV::ADDI:  return RISCV::PseudoCCADDI;    case RISCV::SLLI:  return RISCV::PseudoCCSLLI; @@ -1735,7 +1739,8 @@ unsigned getPredicatedOpcode(unsigned Opcode) {  /// return the defining instruction.  static MachineInstr *canFoldAsPredicatedOp(Register Reg,                                             const MachineRegisterInfo &MRI, -                                           const TargetInstrInfo *TII) { +                                           const TargetInstrInfo *TII, +                                           const RISCVSubtarget &STI) {    if (!Reg.isVirtual())      return nullptr;    if (!MRI.hasOneNonDBGUse(Reg)) @@ -1743,6 +1748,12 @@ static MachineInstr *canFoldAsPredicatedOp(Register Reg,    MachineInstr *MI = MRI.getVRegDef(Reg);    if (!MI)      return nullptr; + +  if (!STI.hasShortForwardBranchIMinMax() && +      (MI->getOpcode() == RISCV::MAX || MI->getOpcode() == RISCV::MIN || +       MI->getOpcode() == RISCV::MINU || MI->getOpcode() == RISCV::MAXU)) +    return nullptr; +    // Check if MI can be predicated and folded into the CCMOV.    
if (getPredicatedOpcode(MI->getOpcode()) == RISCV::INSTRUCTION_LIST_END)      return nullptr; @@ -1806,10 +1817,10 @@ RISCVInstrInfo::optimizeSelect(MachineInstr &MI,    MachineRegisterInfo &MRI = MI.getParent()->getParent()->getRegInfo();    MachineInstr *DefMI = -      canFoldAsPredicatedOp(MI.getOperand(5).getReg(), MRI, this); +      canFoldAsPredicatedOp(MI.getOperand(5).getReg(), MRI, this, STI);    bool Invert = !DefMI;    if (!DefMI) -    DefMI = canFoldAsPredicatedOp(MI.getOperand(4).getReg(), MRI, this); +    DefMI = canFoldAsPredicatedOp(MI.getOperand(4).getReg(), MRI, this, STI);    if (!DefMI)      return nullptr; diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.td b/llvm/lib/Target/RISCV/RISCVInstrInfo.td index 7c89686..9cb53fb 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.td @@ -768,7 +768,7 @@ def BGE  : BranchCC_rri<0b101, "bge">;  def BLTU : BranchCC_rri<0b110, "bltu">;  def BGEU : BranchCC_rri<0b111, "bgeu">; -let IsSignExtendingOpW = 1 in { +let IsSignExtendingOpW = 1, canFoldAsLoad = 1 in {  def LB  : Load_ri<0b000, "lb">, Sched<[WriteLDB, ReadMemBase]>;  def LH  : Load_ri<0b001, "lh">, Sched<[WriteLDH, ReadMemBase]>;  def LW  : Load_ri<0b010, "lw">, Sched<[WriteLDW, ReadMemBase]>; @@ -889,8 +889,10 @@ def CSRRCI : CSR_ii<0b111, "csrrci">;  /// RV64I instructions  let Predicates = [IsRV64] in { +let canFoldAsLoad = 1 in {  def LWU   : Load_ri<0b110, "lwu">, Sched<[WriteLDW, ReadMemBase]>;  def LD    : Load_ri<0b011, "ld">, Sched<[WriteLDD, ReadMemBase]>; +}  def SD    : Store_rri<0b011, "sd">, Sched<[WriteSTD, ReadStoreData, ReadMemBase]>;  let IsSignExtendingOpW = 1 in { diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoD.td b/llvm/lib/Target/RISCV/RISCVInstrInfoD.td index afac37d..4ffe3e6 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoD.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoD.td @@ -71,6 +71,7 @@ defvar DExtsRV64 = [DExt, ZdinxExt];  //===----------------------------------------------------------------------===//  let Predicates = [HasStdExtD] in { +let canFoldAsLoad = 1 in  def FLD : FPLoad_r<0b011, "fld", FPR64, WriteFLD64>;  // Operands for stores are in the order srcreg, base, offset rather than diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoF.td b/llvm/lib/Target/RISCV/RISCVInstrInfoF.td index 6571d99..b30f8ec 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoF.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoF.td @@ -330,6 +330,7 @@ class PseudoFROUND<DAGOperand Ty, ValueType vt, ValueType intvt = XLenVT>  //===----------------------------------------------------------------------===//  let Predicates = [HasStdExtF] in { +let canFoldAsLoad = 1 in  def FLW : FPLoad_r<0b010, "flw", FPR32, WriteFLD32>;  // Operands for stores are in the order srcreg, base, offset rather than diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoSFB.td b/llvm/lib/Target/RISCV/RISCVInstrInfoSFB.td index 0114fbd..5a67a5a 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoSFB.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoSFB.td @@ -106,6 +106,10 @@ def PseudoCCSRA : SFBALU_rr;  def PseudoCCAND : SFBALU_rr;  def PseudoCCOR  : SFBALU_rr;  def PseudoCCXOR : SFBALU_rr; +def PseudoCCMAX : SFBALU_rr; +def PseudoCCMIN : SFBALU_rr; +def PseudoCCMAXU : SFBALU_rr; +def PseudoCCMINU : SFBALU_rr;  def PseudoCCADDI : SFBALU_ri;  def PseudoCCANDI : SFBALU_ri; diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 3fea21e..3f0424f 100644 --- 
a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -3151,6 +3151,14 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,      return selectInsertElt(ResVReg, ResType, I);    case Intrinsic::spv_gep:      return selectGEP(ResVReg, ResType, I); +  case Intrinsic::spv_bitcast: { +    Register OpReg = I.getOperand(2).getReg(); +    SPIRVType *OpType = +        OpReg.isValid() ? GR.getSPIRVTypeForVReg(OpReg) : nullptr; +    if (!GR.isBitcastCompatible(ResType, OpType)) +      report_fatal_error("incompatible result and operand types in a bitcast"); +    return selectOpWithSrcs(ResVReg, ResType, I, {OpReg}, SPIRV::OpBitcast); +  }    case Intrinsic::spv_unref_global:    case Intrinsic::spv_init_global: {      MachineInstr *MI = MRI->getVRegDef(I.getOperand(1).getReg()); diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp index 6e444c9..65dffc7 100644 --- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp @@ -73,16 +73,23 @@ class SPIRVLegalizePointerCast : public FunctionPass {    // Returns the loaded value.    Value *loadVectorFromVector(IRBuilder<> &B, FixedVectorType *SourceType,                                FixedVectorType *TargetType, Value *Source) { -    assert(TargetType->getNumElements() <= SourceType->getNumElements());      LoadInst *NewLoad = B.CreateLoad(SourceType, Source);      buildAssignType(B, SourceType, NewLoad);      Value *AssignValue = NewLoad;      if (TargetType->getElementType() != SourceType->getElementType()) { +      const DataLayout &DL = B.GetInsertBlock()->getModule()->getDataLayout(); +      [[maybe_unused]] TypeSize TargetTypeSize = +          DL.getTypeSizeInBits(TargetType); +      [[maybe_unused]] TypeSize SourceTypeSize = +          DL.getTypeSizeInBits(SourceType); +      assert(TargetTypeSize == SourceTypeSize);        AssignValue = B.CreateIntrinsic(Intrinsic::spv_bitcast,                                        {TargetType, SourceType}, {NewLoad});        buildAssignType(B, TargetType, AssignValue); +      return AssignValue;      } +    assert(TargetType->getNumElements() < SourceType->getNumElements());      SmallVector<int> Mask(/* Size= */ TargetType->getNumElements());      for (unsigned I = 0; I < TargetType->getNumElements(); ++I)        Mask[I] = I; diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp index db6f2d6..d538009 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp @@ -192,31 +192,43 @@ static void buildOpBitcast(SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,          .addUse(OpReg);  } -// We do instruction selections early instead of calling MIB.buildBitcast() -// generating the general op code G_BITCAST. When MachineVerifier validates -// G_BITCAST we see a check of a kind: if Source Type is equal to Destination -// Type then report error "bitcast must change the type". This doesn't take into -// account the notion of a typed pointer that is important for SPIR-V where a -// user may and should use bitcast between pointers with different pointee types -// (https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpBitcast). 
-// It's important for correct lowering in SPIR-V, because interpretation of the -// data type is not left to instructions that utilize the pointer, but encoded -// by the pointer declaration, and the SPIRV target can and must handle the -// declaration and use of pointers that specify the type of data they point to. -// It's not feasible to improve validation of G_BITCAST using just information -// provided by low level types of source and destination. Therefore we don't -// produce G_BITCAST as the general op code with semantics different from -// OpBitcast, but rather lower to OpBitcast immediately. As for now, the only -// difference would be that CombinerHelper couldn't transform known patterns -// around G_BUILD_VECTOR. See discussion -// in https://github.com/llvm/llvm-project/pull/110270 for even more context. -static void selectOpBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR, -                             MachineIRBuilder MIB) { +// We lower G_BITCAST to OpBitcast here to avoid a MachineVerifier error. +// The verifier checks if the source and destination LLTs of a G_BITCAST are +// different, but this check is too strict for SPIR-V's typed pointers, which +// may have the same LLT but different SPIRVType (e.g. pointers to different +// pointee types). By lowering to OpBitcast here, we bypass the verifier's +// check. See discussion in https://github.com/llvm/llvm-project/pull/110270 +// for more context. +// +// We also handle the llvm.spv.bitcast intrinsic here. If the source and +// destination SPIR-V types are the same, we lower it to a COPY to enable +// further optimizations like copy propagation. +static void lowerBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR, +                          MachineIRBuilder MIB) {    SmallVector<MachineInstr *, 16> ToErase;    for (MachineBasicBlock &MBB : MF) {      for (MachineInstr &MI : MBB) { +      if (isSpvIntrinsic(MI, Intrinsic::spv_bitcast)) { +        Register DstReg = MI.getOperand(0).getReg(); +        Register SrcReg = MI.getOperand(2).getReg(); +        SPIRVType *DstType = GR->getSPIRVTypeForVReg(DstReg); +        assert( +            DstType && +            "Expected destination SPIR-V type to have been assigned already."); +        SPIRVType *SrcType = GR->getSPIRVTypeForVReg(SrcReg); +        assert(SrcType && +               "Expected source SPIR-V type to have been assigned already."); +        if (DstType == SrcType) { +          MIB.setInsertPt(*MI.getParent(), MI); +          MIB.buildCopy(DstReg, SrcReg); +          ToErase.push_back(&MI); +          continue; +        } +      } +        if (MI.getOpcode() != TargetOpcode::G_BITCAST)          continue; +        MIB.setInsertPt(*MI.getParent(), MI);        buildOpBitcast(GR, MIB, MI.getOperand(0).getReg(),                       MI.getOperand(1).getReg()); @@ -237,16 +249,11 @@ static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,    SmallVector<MachineInstr *, 10> ToErase;    for (MachineBasicBlock &MBB : MF) {      for (MachineInstr &MI : MBB) { -      if (!isSpvIntrinsic(MI, Intrinsic::spv_bitcast) && -          !isSpvIntrinsic(MI, Intrinsic::spv_ptrcast)) +      if (!isSpvIntrinsic(MI, Intrinsic::spv_ptrcast))          continue;        assert(MI.getOperand(2).isReg());        MIB.setInsertPt(*MI.getParent(), MI);        ToErase.push_back(&MI); -      if (isSpvIntrinsic(MI, Intrinsic::spv_bitcast)) { -        MIB.buildBitcast(MI.getOperand(0).getReg(), MI.getOperand(2).getReg()); -        continue; -      }        Register Def = 
MI.getOperand(0).getReg();        Register Source = MI.getOperand(2).getReg();        Type *ElemTy = getMDOperandAsType(MI.getOperand(3).getMetadata(), 0); @@ -1089,7 +1096,7 @@ bool SPIRVPreLegalizer::runOnMachineFunction(MachineFunction &MF) {    removeImplicitFallthroughs(MF, MIB);    insertSpirvDecorations(MF, GR, MIB);    insertInlineAsm(MF, GR, ST, MIB); -  selectOpBitcasts(MF, GR, MIB); +  lowerBitcasts(MF, GR, MIB);    return true;  } diff --git a/llvm/lib/Target/X86/AsmParser/X86Operand.h b/llvm/lib/Target/X86/AsmParser/X86Operand.h index 89ac53e..a922725 100644 --- a/llvm/lib/Target/X86/AsmParser/X86Operand.h +++ b/llvm/lib/Target/X86/AsmParser/X86Operand.h @@ -620,37 +620,6 @@ struct X86Operand final : public MCParsedAsmOperand {      Inst.addOperand(MCOperand::createReg(Reg));    } -  bool isTILEPair() const { -    return Kind == Register && -           X86MCRegisterClasses[X86::TILERegClassID].contains(getReg()); -  } - -  void addTILEPairOperands(MCInst &Inst, unsigned N) const { -    assert(N == 1 && "Invalid number of operands!"); -    MCRegister Reg = getReg(); -    switch (Reg.id()) { -    default: -      llvm_unreachable("Invalid tile register!"); -    case X86::TMM0: -    case X86::TMM1: -      Reg = X86::TMM0_TMM1; -      break; -    case X86::TMM2: -    case X86::TMM3: -      Reg = X86::TMM2_TMM3; -      break; -    case X86::TMM4: -    case X86::TMM5: -      Reg = X86::TMM4_TMM5; -      break; -    case X86::TMM6: -    case X86::TMM7: -      Reg = X86::TMM6_TMM7; -      break; -    } -    Inst.addOperand(MCOperand::createReg(Reg)); -  } -    void addMemOperands(MCInst &Inst, unsigned N) const {      assert((N == 5) && "Invalid number of operands!");      if (getMemBaseReg()) diff --git a/llvm/lib/Target/X86/Disassembler/X86Disassembler.cpp b/llvm/lib/Target/X86/Disassembler/X86Disassembler.cpp index 4927b45..7d2b5eb 100644 --- a/llvm/lib/Target/X86/Disassembler/X86Disassembler.cpp +++ b/llvm/lib/Target/X86/Disassembler/X86Disassembler.cpp @@ -810,10 +810,6 @@ static int readModRM(struct InternalInstruction *insn) {        if (index > 7)                                                           \          *valid = 0;                                                            \        return prefix##_TMM0 + index;                                            \ -    case TYPE_TMM_PAIR:                                                        \ -      if (index > 7)                                                           \ -        *valid = 0;                                                            \ -      return prefix##_TMM0_TMM1 + (index / 2);                                 \      case TYPE_VK:                                                              \        index &= 0xf;                                                            \        if (index > 7)                                                           \ @@ -2323,7 +2319,6 @@ static bool translateRM(MCInst &mcInst, const OperandSpecifier &operand,    case TYPE_YMM:    case TYPE_ZMM:    case TYPE_TMM: -  case TYPE_TMM_PAIR:    case TYPE_VK_PAIR:    case TYPE_VK:    case TYPE_DEBUGREG: diff --git a/llvm/lib/Target/X86/Disassembler/X86DisassemblerDecoder.h b/llvm/lib/Target/X86/Disassembler/X86DisassemblerDecoder.h index dc9af2c..b0aa70b 100644 --- a/llvm/lib/Target/X86/Disassembler/X86DisassemblerDecoder.h +++ b/llvm/lib/Target/X86/Disassembler/X86DisassemblerDecoder.h @@ -535,12 +535,6 @@ namespace X86Disassembler {    ENTRY(TMM6)                                                                  \    
ENTRY(TMM7) -#define REGS_TMM_PAIRS                                                         \ -  ENTRY(TMM0_TMM1)                                                             \ -  ENTRY(TMM2_TMM3)                                                             \ -  ENTRY(TMM4_TMM5)                                                             \ -  ENTRY(TMM6_TMM7) -  #define ALL_EA_BASES                                                           \    EA_BASES_16BIT                                                               \    EA_BASES_32BIT                                                               \ @@ -565,7 +559,6 @@ namespace X86Disassembler {    REGS_DEBUG                                                                   \    REGS_CONTROL                                                                 \    REGS_TMM                                                                     \ -  REGS_TMM_PAIRS                                                               \    ENTRY(RIP)  /// All possible values of the base field for effective-address diff --git a/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.cpp b/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.cpp index 1c5f166..759d95e 100644 --- a/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.cpp +++ b/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.cpp @@ -467,22 +467,3 @@ void X86InstPrinterCommon::printVKPair(const MCInst *MI, unsigned OpNo,    }    llvm_unreachable("Unknown mask pair register name");  } - -void X86InstPrinterCommon::printTILEPair(const MCInst *MI, unsigned OpNo, -                                         raw_ostream &OS) { -  switch (MI->getOperand(OpNo).getReg()) { -  case X86::TMM0_TMM1: -    printRegName(OS, X86::TMM0); -    return; -  case X86::TMM2_TMM3: -    printRegName(OS, X86::TMM2); -    return; -  case X86::TMM4_TMM5: -    printRegName(OS, X86::TMM4); -    return; -  case X86::TMM6_TMM7: -    printRegName(OS, X86::TMM6); -    return; -  } -  llvm_unreachable("Unknown mask pair register name"); -} diff --git a/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.h b/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.h index 2c9467c..cb55f2f 100644 --- a/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.h +++ b/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.h @@ -40,7 +40,6 @@ protected:                        const MCSubtargetInfo &STI);    void printOptionalSegReg(const MCInst *MI, unsigned OpNo, raw_ostream &O);    void printVKPair(const MCInst *MI, unsigned OpNo, raw_ostream &OS); -  void printTILEPair(const MCInst *MI, unsigned OpNo, raw_ostream &OS);  };  } // end namespace llvm diff --git a/llvm/lib/Target/X86/X86.td b/llvm/lib/Target/X86/X86.td index a1fd366..9e291a6 100644 --- a/llvm/lib/Target/X86/X86.td +++ b/llvm/lib/Target/X86/X86.td @@ -274,9 +274,6 @@ def FeatureAMXFP8 : SubtargetFeature<"amx-fp8", "HasAMXFP8", "true",  def FeatureAMXMOVRS : SubtargetFeature<"amx-movrs", "HasAMXMOVRS", "true",                                         "Support AMX-MOVRS instructions",                                         [FeatureAMXTILE]>; -def FeatureAMXTRANSPOSE : SubtargetFeature<"amx-transpose", "HasAMXTRANSPOSE", "true", -                                           "Support AMX amx-transpose instructions", -                                           [FeatureAMXTILE]>;  def FeatureAMXAVX512 : SubtargetFeature<"amx-avx512",                                          "HasAMXAVX512", "true",                                          "Support AMX-AVX512 instructions", @@ 
-1177,8 +1174,7 @@ def ProcessorFeatures {                                                    FeatureAMXMOVRS,                                                    FeatureAMXAVX512,                                                    FeatureAMXFP8, -                                                  FeatureAMXTF32, -                                                  FeatureAMXTRANSPOSE]; +                                                  FeatureAMXTF32];    list<SubtargetFeature> DMRFeatures =      !listconcat(GNRDFeatures, DMRAdditionalFeatures); diff --git a/llvm/lib/Target/X86/X86ExpandPseudo.cpp b/llvm/lib/Target/X86/X86ExpandPseudo.cpp index 4a9b824..e3c44c0 100644 --- a/llvm/lib/Target/X86/X86ExpandPseudo.cpp +++ b/llvm/lib/Target/X86/X86ExpandPseudo.cpp @@ -649,149 +649,6 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB,      MI.setDesc(TII->get(Opc));      return true;    } -  // TILEPAIRLOAD is just for TILEPair spill, we don't have corresponding -  // AMX instruction to support it. So, split it to 2 load instructions: -  // "TILEPAIRLOAD TMM0:TMM1, Base, Scale, Index, Offset, Segment" --> -  // "TILELOAD TMM0, Base, Scale, Index, Offset, Segment" + -  // "TILELOAD TMM1, Base, Scale, Index, Offset + TMM_SIZE, Segment" -  case X86::PTILEPAIRLOAD: { -    int64_t Disp = MBBI->getOperand(1 + X86::AddrDisp).getImm(); -    Register TReg = MBBI->getOperand(0).getReg(); -    bool DstIsDead = MBBI->getOperand(0).isDead(); -    Register TReg0 = TRI->getSubReg(TReg, X86::sub_t0); -    Register TReg1 = TRI->getSubReg(TReg, X86::sub_t1); -    unsigned TmmSize = TRI->getRegSizeInBits(X86::TILERegClass) / 8; - -    MachineInstrBuilder MIBLo = -        BuildMI(MBB, MBBI, DL, TII->get(X86::TILELOADD)) -            .addReg(TReg0, RegState::Define | getDeadRegState(DstIsDead)); -    MachineInstrBuilder MIBHi = -        BuildMI(MBB, MBBI, DL, TII->get(X86::TILELOADD)) -            .addReg(TReg1, RegState::Define | getDeadRegState(DstIsDead)); - -    for (int i = 0; i < X86::AddrNumOperands; ++i) { -      MIBLo.add(MBBI->getOperand(1 + i)); -      if (i == X86::AddrDisp) -        MIBHi.addImm(Disp + TmmSize); -      else -        MIBHi.add(MBBI->getOperand(1 + i)); -    } - -    // Make sure the first stride reg used in first tileload is alive. -    MachineOperand &Stride = -        MIBLo.getInstr()->getOperand(1 + X86::AddrIndexReg); -    Stride.setIsKill(false); - -    // Split the memory operand, adjusting the offset and size for the halves. -    MachineMemOperand *OldMMO = MBBI->memoperands().front(); -    MachineFunction *MF = MBB.getParent(); -    MachineMemOperand *MMOLo = MF->getMachineMemOperand(OldMMO, 0, TmmSize); -    MachineMemOperand *MMOHi = -        MF->getMachineMemOperand(OldMMO, TmmSize, TmmSize); - -    MIBLo.setMemRefs(MMOLo); -    MIBHi.setMemRefs(MMOHi); - -    // Delete the pseudo. -    MBB.erase(MBBI); -    return true; -  } -  // Similar with TILEPAIRLOAD, TILEPAIRSTORE is just for TILEPair spill, no -  // corresponding AMX instruction to support it. 
So, split it too: -  // "TILEPAIRSTORE Base, Scale, Index, Offset, Segment, TMM0:TMM1" --> -  // "TILESTORE Base, Scale, Index, Offset, Segment, TMM0" + -  // "TILESTORE Base, Scale, Index, Offset + TMM_SIZE, Segment, TMM1" -  case X86::PTILEPAIRSTORE: { -    int64_t Disp = MBBI->getOperand(X86::AddrDisp).getImm(); -    Register TReg = MBBI->getOperand(X86::AddrNumOperands).getReg(); -    bool SrcIsKill = MBBI->getOperand(X86::AddrNumOperands).isKill(); -    Register TReg0 = TRI->getSubReg(TReg, X86::sub_t0); -    Register TReg1 = TRI->getSubReg(TReg, X86::sub_t1); -    unsigned TmmSize = TRI->getRegSizeInBits(X86::TILERegClass) / 8; - -    MachineInstrBuilder MIBLo = -        BuildMI(MBB, MBBI, DL, TII->get(X86::TILESTORED)); -    MachineInstrBuilder MIBHi = -        BuildMI(MBB, MBBI, DL, TII->get(X86::TILESTORED)); - -    for (int i = 0; i < X86::AddrNumOperands; ++i) { -      MIBLo.add(MBBI->getOperand(i)); -      if (i == X86::AddrDisp) -        MIBHi.addImm(Disp + TmmSize); -      else -        MIBHi.add(MBBI->getOperand(i)); -    } -    MIBLo.addReg(TReg0, getKillRegState(SrcIsKill)); -    MIBHi.addReg(TReg1, getKillRegState(SrcIsKill)); - -    // Make sure the first stride reg used in first tilestore is alive. -    MachineOperand &Stride = MIBLo.getInstr()->getOperand(X86::AddrIndexReg); -    Stride.setIsKill(false); - -    // Split the memory operand, adjusting the offset and size for the halves. -    MachineMemOperand *OldMMO = MBBI->memoperands().front(); -    MachineFunction *MF = MBB.getParent(); -    MachineMemOperand *MMOLo = MF->getMachineMemOperand(OldMMO, 0, TmmSize); -    MachineMemOperand *MMOHi = -        MF->getMachineMemOperand(OldMMO, TmmSize, TmmSize); - -    MIBLo.setMemRefs(MMOLo); -    MIBHi.setMemRefs(MMOHi); - -    // Delete the pseudo. -    MBB.erase(MBBI); -    return true; -  } -  case X86::PT2RPNTLVWZ0V: -  case X86::PT2RPNTLVWZ0T1V: -  case X86::PT2RPNTLVWZ1V: -  case X86::PT2RPNTLVWZ1T1V: -  case X86::PT2RPNTLVWZ0RSV: -  case X86::PT2RPNTLVWZ0RST1V: -  case X86::PT2RPNTLVWZ1RSV: -  case X86::PT2RPNTLVWZ1RST1V: { -    for (unsigned i = 3; i > 0; --i) -      MI.removeOperand(i); -    unsigned Opc; -    switch (Opcode) { -    case X86::PT2RPNTLVWZ0V: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0); -      break; -    case X86::PT2RPNTLVWZ0T1V: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0T1); -      break; -    case X86::PT2RPNTLVWZ1V: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1); -      break; -    case X86::PT2RPNTLVWZ1T1V: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1T1); -      break; -    case X86::PT2RPNTLVWZ0RSV: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0RS); -      break; -    case X86::PT2RPNTLVWZ0RST1V: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0RST1); -      break; -    case X86::PT2RPNTLVWZ1RSV: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1RS); -      break; -    case X86::PT2RPNTLVWZ1RST1V: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1RST1); -      break; -    default: -      llvm_unreachable("Impossible Opcode!"); -    } -    MI.setDesc(TII->get(Opc)); -    return true; -  } -  case X86::PTTRANSPOSEDV: -  case X86::PTCONJTFP16V: { -    for (int i = 2; i > 0; --i) -      MI.removeOperand(i); -    MI.setDesc(TII->get(Opcode == X86::PTTRANSPOSEDV ? 
X86::TTRANSPOSED -                                                     : X86::TCONJTFP16)); -    return true; -  }    case X86::PTCMMIMFP16PSV:    case X86::PTCMMRLFP16PSV:    case X86::PTDPBSSDV: @@ -800,13 +657,7 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB,    case X86::PTDPBUUDV:    case X86::PTDPBF16PSV:    case X86::PTDPFP16PSV: -  case X86::PTTDPBF16PSV: -  case X86::PTTDPFP16PSV: -  case X86::PTTCMMIMFP16PSV: -  case X86::PTTCMMRLFP16PSV: -  case X86::PTCONJTCMMIMFP16PSV:    case X86::PTMMULTF32PSV: -  case X86::PTTMMULTF32PSV:    case X86::PTDPBF8PSV:    case X86::PTDPBHF8PSV:    case X86::PTDPHBF8PSV: @@ -816,6 +667,7 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB,        MI.removeOperand(i);      unsigned Opc;      switch (Opcode) { +      // clang-format off      case X86::PTCMMIMFP16PSV:  Opc = X86::TCMMIMFP16PS; break;      case X86::PTCMMRLFP16PSV:  Opc = X86::TCMMRLFP16PS; break;      case X86::PTDPBSSDV:   Opc = X86::TDPBSSD; break; @@ -824,40 +676,12 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB,      case X86::PTDPBUUDV:   Opc = X86::TDPBUUD; break;      case X86::PTDPBF16PSV: Opc = X86::TDPBF16PS; break;      case X86::PTDPFP16PSV: Opc = X86::TDPFP16PS; break; -    case X86::PTTDPBF16PSV: -      Opc = X86::TTDPBF16PS; -      break; -    case X86::PTTDPFP16PSV: -      Opc = X86::TTDPFP16PS; -      break; -    case X86::PTTCMMIMFP16PSV: -      Opc = X86::TTCMMIMFP16PS; -      break; -    case X86::PTTCMMRLFP16PSV: -      Opc = X86::TTCMMRLFP16PS; -      break; -    case X86::PTCONJTCMMIMFP16PSV: -      Opc = X86::TCONJTCMMIMFP16PS; -      break; -    case X86::PTMMULTF32PSV: -      Opc = X86::TMMULTF32PS; -      break; -    case X86::PTTMMULTF32PSV: -      Opc = X86::TTMMULTF32PS; -      break; -    case X86::PTDPBF8PSV: -      Opc = X86::TDPBF8PS; -      break; -    case X86::PTDPBHF8PSV: -      Opc = X86::TDPBHF8PS; -      break; -    case X86::PTDPHBF8PSV: -      Opc = X86::TDPHBF8PS; -      break; -    case X86::PTDPHF8PSV: -      Opc = X86::TDPHF8PS; -      break; - +    case X86::PTMMULTF32PSV: Opc = X86::TMMULTF32PS; break; +    case X86::PTDPBF8PSV: Opc = X86::TDPBF8PS; break; +    case X86::PTDPBHF8PSV: Opc = X86::TDPBHF8PS; break; +    case X86::PTDPHBF8PSV: Opc = X86::TDPHBF8PS; break; +    case X86::PTDPHF8PSV: Opc = X86::TDPHF8PS; break; +    // clang-format on      default:        llvm_unreachable("Unexpected Opcode");      } diff --git a/llvm/lib/Target/X86/X86FastPreTileConfig.cpp b/llvm/lib/Target/X86/X86FastPreTileConfig.cpp index 787b71d..06f729a 100644 --- a/llvm/lib/Target/X86/X86FastPreTileConfig.cpp +++ b/llvm/lib/Target/X86/X86FastPreTileConfig.cpp @@ -267,24 +267,16 @@ void X86FastPreTileConfig::reload(MachineBasicBlock::iterator UseMI,                      << printReg(TileReg, TRI) << '\n');  } -static unsigned getTileDefNum(MachineRegisterInfo *MRI, Register Reg) { -  if (Reg.isVirtual()) { -    unsigned RegClassID = MRI->getRegClass(Reg)->getID(); -    if (RegClassID == X86::TILERegClassID) -      return 1; -    if (RegClassID == X86::TILEPAIRRegClassID) -      return 2; -  } else { -    if (Reg >= X86::TMM0 && Reg <= X86::TMM7) -      return 1; -    if (Reg >= X86::TMM0_TMM1 && Reg <= X86::TMM6_TMM7) -      return 2; +static bool isTileRegister(MachineRegisterInfo *MRI, Register Reg) { +  if (Reg.isVirtual() && +      (MRI->getRegClass(Reg)->getID() == X86::TILERegClassID)) { +    return true;    } -  return 0; -} -static bool isTileRegister(MachineRegisterInfo *MRI, Register VirtReg) { -  return 
getTileDefNum(MRI, VirtReg) > 0; +  if (Reg >= X86::TMM0 && Reg <= X86::TMM7) +    return true; + +  return false;  }  static bool isTileDef(MachineRegisterInfo *MRI, MachineInstr &MI) { @@ -296,7 +288,7 @@ static bool isTileDef(MachineRegisterInfo *MRI, MachineInstr &MI) {    if (!MO.isReg())      return false; -  return getTileDefNum(MRI, MO.getReg()) > 0; +  return isTileRegister(MRI, MO.getReg());  }  static ShapeT getShape(MachineRegisterInfo *MRI, Register TileReg) { @@ -636,19 +628,7 @@ bool X86FastPreTileConfig::configBasicBlock(MachineBasicBlock &MBB) {        else if (dominates(MBB, LastShapeMI, ColMI))          LastShapeMI = ColMI;      } -    unsigned TileDefNum = getTileDefNum(MRI, MI.getOperand(0).getReg()); -    if (TileDefNum > 1) { -      for (unsigned I = 1; I < TileDefNum; I++) { -        MachineOperand *ColxMO = &MI.getOperand(2 + I); -        MachineInstr *ColxMI = MRI->getVRegDef(ColxMO->getReg()); -        if (ColxMI->getParent() == &MBB) { -          if (!LastShapeMI) -            LastShapeMI = ColxMI; -          else if (dominates(MBB, LastShapeMI, ColxMI)) -            LastShapeMI = ColxMI; -        } -      } -    } +      // If there is user live out of the tilecfg, spill it and reload in      // before the user.      Register TileReg = MI.getOperand(0).getReg(); diff --git a/llvm/lib/Target/X86/X86FastTileConfig.cpp b/llvm/lib/Target/X86/X86FastTileConfig.cpp index 11d331b..d86ae36 100644 --- a/llvm/lib/Target/X86/X86FastTileConfig.cpp +++ b/llvm/lib/Target/X86/X86FastTileConfig.cpp @@ -77,14 +77,14 @@ INITIALIZE_PASS_BEGIN(X86FastTileConfig, DEBUG_TYPE,  INITIALIZE_PASS_END(X86FastTileConfig, DEBUG_TYPE,                      "Fast Tile Register Configure", false, false) -static unsigned getNumDefTiles(MachineRegisterInfo *MRI, MachineInstr &MI) { +static bool isTileDef(MachineRegisterInfo *MRI, MachineInstr &MI) {    // There is no phi instruction after register allocation.    assert(MI.isPHI() == false);    // The instruction must have 3 operands: tile def, row, col.    // It should be AMX pseudo instruction that have shape operand.    if (MI.isDebugInstr() || MI.isCopy() || MI.getNumOperands() < 3 ||        !MI.isPseudo()) -    return 0; +    return false;    MachineOperand &MO = MI.getOperand(0);    if (MO.isReg()) { @@ -93,24 +93,18 @@ static unsigned getNumDefTiles(MachineRegisterInfo *MRI, MachineInstr &MI) {      // register is not rewritten yet.      
if (Reg.isVirtual()) {        if (MRI->getRegClass(Reg)->getID() == X86::TILERegClassID) -        return 1; -      if (MRI->getRegClass(Reg)->getID() == X86::TILEPAIRRegClassID) -        return 2; +        return true;      }      if (Reg >= X86::TMM0 && Reg <= X86::TMM7) -      return 1; -    if (Reg >= X86::TMM0_TMM1 && Reg <= X86::TMM6_TMM7) -      return 2; +      return true;    } -  return 0; +  return false;  }  static unsigned getTMMIndex(Register Reg) {    if (Reg >= X86::TMM0 && Reg <= X86::TMM7)      return Reg - X86::TMM0; -  if (Reg >= X86::TMM0_TMM1 && Reg <= X86::TMM6_TMM7) -    return (Reg - X86::TMM0_TMM1) * 2;    llvm_unreachable("Invalid Tmm Reg!");  } @@ -120,17 +114,14 @@ bool X86FastTileConfig::configBasicBlock(MachineBasicBlock &MBB) {    bool Change = false;    SmallVector<std::pair<unsigned, ShapeT>, 6> ShapeInfos;    for (MachineInstr &MI : reverse(MBB)) { -    unsigned DefNum = getNumDefTiles(MRI, MI); -    if (DefNum == 0 && MI.getOpcode() != X86::PLDTILECFGV) +    if (!isTileDef(MRI, MI) && MI.getOpcode() != X86::PLDTILECFGV)        continue;      // AMX instructions that define tile register.      if (MI.getOpcode() != X86::PLDTILECFGV) {        MachineOperand &Row = MI.getOperand(1);        unsigned TMMIdx = getTMMIndex(MI.getOperand(0).getReg()); -      for (unsigned I = 0; I < DefNum; I++) { -        MachineOperand &Col = MI.getOperand(2 + I); -        ShapeInfos.push_back({TMMIdx + I, ShapeT(&Row, &Col)}); -      } +      MachineOperand &Col = MI.getOperand(2); +      ShapeInfos.push_back({TMMIdx, ShapeT(&Row, &Col)});      } else { // PLDTILECFGV        // Rewrite the shape information to memory. Stack slot should have        // been initialized to zero in pre config. diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp index 4393f6e..d4418c8 100644 --- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp +++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp @@ -337,23 +337,8 @@ namespace {      // lowering but before ISEL.      bool isAMXSDNode(SDNode *N) const {        // Check if N is AMX SDNode: -      // 1. check specific opcode since these carry MVT::Untyped instead of -      // x86amx_type; -      // 2. check result type; -      // 3. check operand type; -      switch (N->getOpcode()) { -      default: -        break; -      case X86::PT2RPNTLVWZ0V: -      case X86::PT2RPNTLVWZ0T1V: -      case X86::PT2RPNTLVWZ1V: -      case X86::PT2RPNTLVWZ1T1V: -      case X86::PT2RPNTLVWZ0RSV: -      case X86::PT2RPNTLVWZ0RST1V: -      case X86::PT2RPNTLVWZ1RSV: -      case X86::PT2RPNTLVWZ1RST1V: -        return true; -      } +      // 1. check result type; +      // 2. 
check operand type;        for (unsigned Idx = 0, E = N->getNumValues(); Idx != E; ++Idx) {          if (N->getValueType(Idx) == MVT::x86amx)            return true; @@ -5398,65 +5383,6 @@ void X86DAGToDAGISel::Select(SDNode *Node) {        ReplaceNode(Node, CNode);        return;      } -    case Intrinsic::x86_t2rpntlvwz0rs: -    case Intrinsic::x86_t2rpntlvwz0rst1: -    case Intrinsic::x86_t2rpntlvwz1rs: -    case Intrinsic::x86_t2rpntlvwz1rst1: -      if (!Subtarget->hasAMXMOVRS()) -        break; -      [[fallthrough]]; -    case Intrinsic::x86_t2rpntlvwz0: -    case Intrinsic::x86_t2rpntlvwz0t1: -    case Intrinsic::x86_t2rpntlvwz1: -    case Intrinsic::x86_t2rpntlvwz1t1: { -      if (!Subtarget->hasAMXTRANSPOSE()) -        break; -      auto *MFI = -          CurDAG->getMachineFunction().getInfo<X86MachineFunctionInfo>(); -      MFI->setAMXProgModel(AMXProgModelEnum::DirectReg); -      unsigned Opc; -      switch (IntNo) { -      default: -        llvm_unreachable("Unexpected intrinsic!"); -      case Intrinsic::x86_t2rpntlvwz0: -        Opc = X86::PT2RPNTLVWZ0; -        break; -      case Intrinsic::x86_t2rpntlvwz0t1: -        Opc = X86::PT2RPNTLVWZ0T1; -        break; -      case Intrinsic::x86_t2rpntlvwz1: -        Opc = X86::PT2RPNTLVWZ1; -        break; -      case Intrinsic::x86_t2rpntlvwz1t1: -        Opc = X86::PT2RPNTLVWZ1T1; -        break; -      case Intrinsic::x86_t2rpntlvwz0rs: -        Opc = X86::PT2RPNTLVWZ0RS; -        break; -      case Intrinsic::x86_t2rpntlvwz0rst1: -        Opc = X86::PT2RPNTLVWZ0RST1; -        break; -      case Intrinsic::x86_t2rpntlvwz1rs: -        Opc = X86::PT2RPNTLVWZ1RS; -        break; -      case Intrinsic::x86_t2rpntlvwz1rst1: -        Opc = X86::PT2RPNTLVWZ1RST1; -        break; -      } -      // FIXME: Match displacement and scale. 
-      unsigned TIndex = Node->getConstantOperandVal(2); -      SDValue TReg = getI8Imm(TIndex, dl); -      SDValue Base = Node->getOperand(3); -      SDValue Scale = getI8Imm(1, dl); -      SDValue Index = Node->getOperand(4); -      SDValue Disp = CurDAG->getTargetConstant(0, dl, MVT::i32); -      SDValue Segment = CurDAG->getRegister(0, MVT::i16); -      SDValue Chain = Node->getOperand(0); -      SDValue Ops[] = {TReg, Base, Scale, Index, Disp, Segment, Chain}; -      MachineSDNode *CNode = CurDAG->getMachineNode(Opc, dl, MVT::Other, Ops); -      ReplaceNode(Node, CNode); -      return; -    }      }      break;    } diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 49beada..007074c 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -27946,67 +27946,6 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget,        return DAG.getNode(ISD::MERGE_VALUES, dl, Op->getVTList(), SetCC,                           Operation.getValue(1));      } -    case Intrinsic::x86_t2rpntlvwz0rs_internal: -    case Intrinsic::x86_t2rpntlvwz0rst1_internal: -    case Intrinsic::x86_t2rpntlvwz1rs_internal: -    case Intrinsic::x86_t2rpntlvwz1rst1_internal: -    case Intrinsic::x86_t2rpntlvwz0_internal: -    case Intrinsic::x86_t2rpntlvwz0t1_internal: -    case Intrinsic::x86_t2rpntlvwz1_internal: -    case Intrinsic::x86_t2rpntlvwz1t1_internal: { -      auto *X86MFI = DAG.getMachineFunction().getInfo<X86MachineFunctionInfo>(); -      X86MFI->setAMXProgModel(AMXProgModelEnum::ManagedRA); -      unsigned IntNo = Op.getConstantOperandVal(1); -      unsigned Opc = 0; -      switch (IntNo) { -      default: -        llvm_unreachable("Unexpected intrinsic!"); -      case Intrinsic::x86_t2rpntlvwz0_internal: -        Opc = X86::PT2RPNTLVWZ0V; -        break; -      case Intrinsic::x86_t2rpntlvwz0t1_internal: -        Opc = X86::PT2RPNTLVWZ0T1V; -        break; -      case Intrinsic::x86_t2rpntlvwz1_internal: -        Opc = X86::PT2RPNTLVWZ1V; -        break; -      case Intrinsic::x86_t2rpntlvwz1t1_internal: -        Opc = X86::PT2RPNTLVWZ1T1V; -        break; -      case Intrinsic::x86_t2rpntlvwz0rs_internal: -        Opc = X86::PT2RPNTLVWZ0RSV; -        break; -      case Intrinsic::x86_t2rpntlvwz0rst1_internal: -        Opc = X86::PT2RPNTLVWZ0RST1V; -        break; -      case Intrinsic::x86_t2rpntlvwz1rs_internal: -        Opc = X86::PT2RPNTLVWZ1RSV; -        break; -      case Intrinsic::x86_t2rpntlvwz1rst1_internal: -        Opc = X86::PT2RPNTLVWZ1RST1V; -        break; -      } - -      SDLoc DL(Op); -      SDVTList VTs = DAG.getVTList(MVT::Untyped, MVT::Other); - -      SDValue Ops[] = {Op.getOperand(2),                       // Row -                       Op.getOperand(3),                       // Col0 -                       Op.getOperand(4),                       // Col1 -                       Op.getOperand(5),                       // Base -                       DAG.getTargetConstant(1, DL, MVT::i8),  // Scale -                       Op.getOperand(6),                       // Index -                       DAG.getTargetConstant(0, DL, MVT::i32), // Disp -                       DAG.getRegister(0, MVT::i16),           // Segment -                       Op.getOperand(0)};                      // Chain - -      MachineSDNode *Res = DAG.getMachineNode(Opc, DL, VTs, Ops); -      SDValue Res0 = DAG.getTargetExtractSubreg(X86::sub_t0, DL, MVT::x86amx, -                                   
             SDValue(Res, 0)); -      SDValue Res1 = DAG.getTargetExtractSubreg(X86::sub_t1, DL, MVT::x86amx, -                                                SDValue(Res, 0)); -      return DAG.getMergeValues({Res0, Res1, SDValue(Res, 1)}, DL); -    }      case Intrinsic::x86_atomic_bts_rm:      case Intrinsic::x86_atomic_btc_rm:      case Intrinsic::x86_atomic_btr_rm: { @@ -37745,10 +37684,6 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,      assert (Imm < 8 && "Illegal tmm index");      return X86::TMM0 + Imm;    }; -  auto TMMImmToTMMPair = [](unsigned Imm) { -    assert(Imm < 8 && "Illegal tmm pair index."); -    return X86::TMM0_TMM1 + Imm / 2; -  };    switch (MI.getOpcode()) {    default:      llvm_unreachable("Unexpected instr type to insert"); @@ -38129,53 +38064,25 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,    case X86::PTDPBHF8PS:    case X86::PTDPHBF8PS:    case X86::PTDPHF8PS: -  case X86::PTTDPBF16PS: -  case X86::PTTDPFP16PS: -  case X86::PTTCMMIMFP16PS: -  case X86::PTTCMMRLFP16PS: -  case X86::PTCONJTCMMIMFP16PS: -  case X86::PTMMULTF32PS: -  case X86::PTTMMULTF32PS: { +  case X86::PTMMULTF32PS: {      unsigned Opc;      switch (MI.getOpcode()) {      default: llvm_unreachable("illegal opcode!"); +      // clang-format off      case X86::PTDPBSSD: Opc = X86::TDPBSSD; break;      case X86::PTDPBSUD: Opc = X86::TDPBSUD; break;      case X86::PTDPBUSD: Opc = X86::TDPBUSD; break;      case X86::PTDPBUUD: Opc = X86::TDPBUUD; break;      case X86::PTDPBF16PS: Opc = X86::TDPBF16PS; break;      case X86::PTDPFP16PS: Opc = X86::TDPFP16PS; break; -    case X86::PTCMMIMFP16PS: -      Opc = X86::TCMMIMFP16PS; -      break; -    case X86::PTCMMRLFP16PS: -      Opc = X86::TCMMRLFP16PS; -      break; +    case X86::PTCMMIMFP16PS: Opc = X86::TCMMIMFP16PS; break; +    case X86::PTCMMRLFP16PS: Opc = X86::TCMMRLFP16PS; break;      case X86::PTDPBF8PS: Opc = X86::TDPBF8PS; break;      case X86::PTDPBHF8PS: Opc = X86::TDPBHF8PS; break;      case X86::PTDPHBF8PS: Opc = X86::TDPHBF8PS; break;      case X86::PTDPHF8PS: Opc = X86::TDPHF8PS; break; -    case X86::PTTDPBF16PS: -      Opc = X86::TTDPBF16PS; -      break; -    case X86::PTTDPFP16PS: -      Opc = X86::TTDPFP16PS; -      break; -    case X86::PTTCMMIMFP16PS: -      Opc = X86::TTCMMIMFP16PS; -      break; -    case X86::PTTCMMRLFP16PS: -      Opc = X86::TTCMMRLFP16PS; -      break; -    case X86::PTCONJTCMMIMFP16PS: -      Opc = X86::TCONJTCMMIMFP16PS; -      break; -    case X86::PTMMULTF32PS: -      Opc = X86::TMMULTF32PS; -      break; -    case X86::PTTMMULTF32PS: -      Opc = X86::TTMMULTF32PS; -      break; +    case X86::PTMMULTF32PS: Opc = X86::TMMULTF32PS; break; +      // clang-format on      }      MachineInstrBuilder MIB = BuildMI(*BB, MI, MIMD, TII->get(Opc)); @@ -38246,70 +38153,6 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,      MI.eraseFromParent(); // The pseudo is gone now.      return BB;    } -  case X86::PT2RPNTLVWZ0: -  case X86::PT2RPNTLVWZ0T1: -  case X86::PT2RPNTLVWZ1: -  case X86::PT2RPNTLVWZ1T1: -  case X86::PT2RPNTLVWZ0RS: -  case X86::PT2RPNTLVWZ0RST1: -  case X86::PT2RPNTLVWZ1RS: -  case X86::PT2RPNTLVWZ1RST1: { -    const DebugLoc &DL = MI.getDebugLoc(); -    unsigned Opc; -#define GET_EGPR_IF_ENABLED(OPC) (Subtarget.hasEGPR() ? 
OPC##_EVEX : OPC) -    switch (MI.getOpcode()) { -    default: -      llvm_unreachable("Unexpected instruction!"); -    case X86::PT2RPNTLVWZ0: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0); -      break; -    case X86::PT2RPNTLVWZ0T1: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0T1); -      break; -    case X86::PT2RPNTLVWZ1: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1); -      break; -    case X86::PT2RPNTLVWZ1T1: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1T1); -      break; -    case X86::PT2RPNTLVWZ0RS: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0RS); -      break; -    case X86::PT2RPNTLVWZ0RST1: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0RST1); -      break; -    case X86::PT2RPNTLVWZ1RS: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1RS); -      break; -    case X86::PT2RPNTLVWZ1RST1: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1RST1); -      break; -    } -#undef GET_EGPR_IF_ENABLED -    MachineInstrBuilder MIB = BuildMI(*BB, MI, DL, TII->get(Opc)); -    MIB.addReg(TMMImmToTMMPair(MI.getOperand(0).getImm()), RegState::Define); - -    MIB.add(MI.getOperand(1)); // base -    MIB.add(MI.getOperand(2)); // scale -    MIB.add(MI.getOperand(3)); // index -    MIB.add(MI.getOperand(4)); // displacement -    MIB.add(MI.getOperand(5)); // segment -    MI.eraseFromParent();      // The pseudo is gone now. -    return BB; -  } -  case X86::PTTRANSPOSED: -  case X86::PTCONJTFP16: { -    const DebugLoc &DL = MI.getDebugLoc(); -    unsigned Opc = MI.getOpcode() == X86::PTTRANSPOSED ? X86::TTRANSPOSED -                                                       : X86::TCONJTFP16; - -    MachineInstrBuilder MIB = BuildMI(*BB, MI, DL, TII->get(Opc)); -    MIB.addReg(TMMImmToTMMReg(MI.getOperand(0).getImm()), RegState::Define); -    MIB.addReg(TMMImmToTMMReg(MI.getOperand(1).getImm()), RegState::Undef); - -    MI.eraseFromParent(); // The pseudo is gone now. -    return BB; -  }    case X86::PTCVTROWPS2BF16Hrri:    case X86::PTCVTROWPS2BF16Lrri:    case X86::PTCVTROWPS2PHHrri: @@ -53502,7 +53345,8 @@ static SDValue combineMaskedStore(SDNode *N, SelectionDAG &DAG,  }  // Look for a RMW operation that only touches one bit of a larger than legal -// type and fold it to a BTC/BTR/BTS pattern acting on a single i32 sub value. +// type and fold it to a BTC/BTR/BTS or bit insertion pattern acting on a single +// i32 sub value.  
static SDValue narrowBitOpRMW(StoreSDNode *St, const SDLoc &DL,                                SelectionDAG &DAG,                                const X86Subtarget &Subtarget) { @@ -53528,14 +53372,20 @@ static SDValue narrowBitOpRMW(StoreSDNode *St, const SDLoc &DL,    // BTR: X & ~(1 << ShAmt)    // BTS: X | (1 << ShAmt)    // BTC: X ^ (1 << ShAmt) -  SDValue ShAmt; +  // +  // BitInsert: (X & ~(1 << ShAmt)) | (InsertBit << ShAmt) +  SDValue InsertBit, ShAmt;    if (!StoredVal.hasOneUse() ||        !(sd_match(StoredVal, m_And(m_Specific(LoadVal),                                    m_Not(m_Shl(m_One(), m_Value(ShAmt))))) ||          sd_match(StoredVal,                   m_Or(m_Specific(LoadVal), m_Shl(m_One(), m_Value(ShAmt)))) ||          sd_match(StoredVal, -                 m_Xor(m_Specific(LoadVal), m_Shl(m_One(), m_Value(ShAmt)))))) +                 m_Xor(m_Specific(LoadVal), m_Shl(m_One(), m_Value(ShAmt)))) || +        sd_match(StoredVal, +                 m_Or(m_And(m_Specific(LoadVal), +                            m_Not(m_Shl(m_One(), m_Value(ShAmt)))), +                      m_Shl(m_Value(InsertBit), m_Deferred(ShAmt))))))      return SDValue();    // Ensure the shift amount is in bounds. @@ -53543,6 +53393,13 @@ static SDValue narrowBitOpRMW(StoreSDNode *St, const SDLoc &DL,    if (KnownAmt.getMaxValue().uge(VT.getSizeInBits()))      return SDValue(); +  // If we're inserting a bit then it must be the LSB. +  if (InsertBit) { +    KnownBits KnownInsert = DAG.computeKnownBits(InsertBit); +    if (KnownInsert.countMinLeadingZeros() < (VT.getSizeInBits() - 1)) +      return SDValue(); +  } +    // Split the shift into an alignment shift that moves the active i32 block to    // the bottom bits for truncation and a modulo shift that can act on the i32.    EVT AmtVT = ShAmt.getValueType(); @@ -53550,6 +53407,7 @@ static SDValue narrowBitOpRMW(StoreSDNode *St, const SDLoc &DL,                                   DAG.getSignedConstant(-32LL, DL, AmtVT));    SDValue ModuloAmt =        DAG.getNode(ISD::AND, DL, AmtVT, ShAmt, DAG.getConstant(31, DL, AmtVT)); +  ModuloAmt = DAG.getZExtOrTrunc(ModuloAmt, DL, MVT::i8);    // Compute the byte offset for the i32 block that is changed by the RMW.    // combineTruncate will adjust the load for us in a similar way. 
@@ -53564,13 +53422,23 @@ static SDValue narrowBitOpRMW(StoreSDNode *St, const SDLoc &DL,    SDValue X = DAG.getNode(ISD::SRL, DL, VT, LoadVal, AlignAmt);    X = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, X); -  SDValue Mask = -      DAG.getNode(ISD::SHL, DL, MVT::i32, DAG.getConstant(1, DL, MVT::i32), -                  DAG.getZExtOrTrunc(ModuloAmt, DL, MVT::i8)); -  if (StoredVal.getOpcode() == ISD::AND) -    Mask = DAG.getNOT(DL, Mask, MVT::i32); +  SDValue Mask = DAG.getNode(ISD::SHL, DL, MVT::i32, +                             DAG.getConstant(1, DL, MVT::i32), ModuloAmt); + +  SDValue Res; +  if (InsertBit) { +    SDValue BitMask = +        DAG.getNode(ISD::SHL, DL, MVT::i32, +                    DAG.getZExtOrTrunc(InsertBit, DL, MVT::i32), ModuloAmt); +    Res = +        DAG.getNode(ISD::AND, DL, MVT::i32, X, DAG.getNOT(DL, Mask, MVT::i32)); +    Res = DAG.getNode(ISD::OR, DL, MVT::i32, Res, BitMask); +  } else { +    if (StoredVal.getOpcode() == ISD::AND) +      Mask = DAG.getNOT(DL, Mask, MVT::i32); +    Res = DAG.getNode(StoredVal.getOpcode(), DL, MVT::i32, X, Mask); +  } -  SDValue Res = DAG.getNode(StoredVal.getOpcode(), DL, MVT::i32, X, Mask);    return DAG.getStore(St->getChain(), DL, Res, NewPtr, St->getPointerInfo(),                        Align(), St->getMemOperand()->getFlags());  } @@ -54591,6 +54459,7 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,  static SDValue detectPMADDUBSW(SDValue In, EVT VT, SelectionDAG &DAG,                                 const X86Subtarget &Subtarget,                                 const SDLoc &DL) { +  using namespace SDPatternMatch;    if (!VT.isVector() || !Subtarget.hasSSSE3())      return SDValue(); @@ -54600,42 +54469,19 @@ static SDValue detectPMADDUBSW(SDValue In, EVT VT, SelectionDAG &DAG,      return SDValue();    SDValue SSatVal = detectSSatPattern(In, VT); -  if (!SSatVal || SSatVal.getOpcode() != ISD::ADD) +  if (!SSatVal)      return SDValue(); -  // Ok this is a signed saturation of an ADD. See if this ADD is adding pairs -  // of multiplies from even/odd elements. -  SDValue N0 = SSatVal.getOperand(0); -  SDValue N1 = SSatVal.getOperand(1); - -  if (N0.getOpcode() != ISD::MUL || N1.getOpcode() != ISD::MUL) -    return SDValue(); - -  SDValue N00 = N0.getOperand(0); -  SDValue N01 = N0.getOperand(1); -  SDValue N10 = N1.getOperand(0); -  SDValue N11 = N1.getOperand(1); - +  // See if this is a signed saturation of an ADD, adding pairs of multiplies +  // from even/odd elements, from zero_extend/sign_extend operands. +  //    // TODO: Handle constant vectors and use knownbits/computenumsignbits? -  // Canonicalize zero_extend to LHS. -  if (N01.getOpcode() == ISD::ZERO_EXTEND) -    std::swap(N00, N01); -  if (N11.getOpcode() == ISD::ZERO_EXTEND) -    std::swap(N10, N11); - -  // Ensure we have a zero_extend and a sign_extend. -  if (N00.getOpcode() != ISD::ZERO_EXTEND || -      N01.getOpcode() != ISD::SIGN_EXTEND || -      N10.getOpcode() != ISD::ZERO_EXTEND || -      N11.getOpcode() != ISD::SIGN_EXTEND) +  SDValue N00, N01, N10, N11; +  if (!sd_match(SSatVal, +                m_Add(m_Mul(m_ZExt(m_Value(N00)), m_SExt(m_Value(N01))), +                      m_Mul(m_ZExt(m_Value(N10)), m_SExt(m_Value(N11))))))      return SDValue(); -  // Peek through the extends. -  N00 = N00.getOperand(0); -  N01 = N01.getOperand(0); -  N10 = N10.getOperand(0); -  N11 = N11.getOperand(0); -    // Ensure the extend is from vXi8.    
if (N00.getValueType().getVectorElementType() != MVT::i8 ||        N01.getValueType().getVectorElementType() != MVT::i8 || @@ -54768,9 +54614,11 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,        KnownBits KnownAmt = DAG.computeKnownBits(ShAmt);        // Check the shift amount is byte aligned.        // Check the truncation doesn't use any shifted in (zero) top bits. +      // Check the shift amount doesn't depend on the original load.        if (KnownAmt.countMinTrailingZeros() >= 3 &&            KnownAmt.getMaxValue().ule(SrcVT.getSizeInBits() - -                                     VT.getSizeInBits())) { +                                     VT.getSizeInBits()) && +          !Ld->isPredecessorOf(ShAmt.getNode())) {          EVT PtrVT = Ld->getBasePtr().getValueType();          SDValue PtrBitOfs = DAG.getZExtOrTrunc(ShAmt, DL, PtrVT);          SDValue PtrByteOfs = diff --git a/llvm/lib/Target/X86/X86InstrAMX.td b/llvm/lib/Target/X86/X86InstrAMX.td index 69a5115..522782a 100644 --- a/llvm/lib/Target/X86/X86InstrAMX.td +++ b/llvm/lib/Target/X86/X86InstrAMX.td @@ -338,188 +338,6 @@ let Predicates = [HasAMXFP8, In64BitMode] in {    }  } -let Predicates = [HasAMXTILE, In64BitMode], isPseudo = true, SchedRW = [WriteSystem] in { -  let mayStore = 1 in -  def PTILEPAIRSTORE : PseudoI<(outs), (ins opaquemem:$src1, TILEPair:$src2), []>; -  let mayLoad = 1 in -  def PTILEPAIRLOAD : PseudoI<(outs TILEPair:$dst), (ins opaquemem:$src), []>; -} - -multiclass T2RPNTLVW_Base<bits<8> op1, bits<8> op2, string rs, string suffix> { -  def Z0#rs#suffix    : I<op1, MRMSrcMemFSIB, (outs TILEPair:$dst), (ins sibmem:$src), -                          "t2rpntlvwz0" #!tolower(rs)# "\t{$src, $dst|$dst, $src}", []>, PS; -  def Z0#rs#T1#suffix : I<op2, MRMSrcMemFSIB, (outs TILEPair:$dst), (ins sibmem:$src), -                          "t2rpntlvwz0" #!tolower(rs)# "t1\t{$src, $dst|$dst, $src}", []>, PS; -  def Z1#rs#suffix    : I<op1, MRMSrcMemFSIB, (outs TILEPair:$dst), (ins sibmem:$src), -                          "t2rpntlvwz1" #!tolower(rs)# "\t{$src, $dst|$dst, $src}", []>, PD; -  def Z1#rs#T1#suffix : I<op2, MRMSrcMemFSIB, (outs TILEPair:$dst), (ins sibmem:$src), -                          "t2rpntlvwz1" #!tolower(rs)# "t1\t{$src, $dst|$dst, $src}", []>, PD; -} - -let Predicates = [HasAMXTRANSPOSE, In64BitMode], SchedRW = [WriteSystem] in -  defm T2RPNTLVW : T2RPNTLVW_Base<0x6e, 0x6f, "", "">, T8, VEX; - -let Predicates = [HasAMXTRANSPOSE, HasEGPR, In64BitMode], SchedRW = [WriteSystem] in -  defm T2RPNTLVW : T2RPNTLVW_Base<0x6e, 0x6f, "", "_EVEX">, T8, EVEX, NoCD8; - -let Predicates = [HasAMXMOVRS, HasAMXTRANSPOSE, In64BitMode], SchedRW = [WriteSystem] in -  defm T2RPNTLVW : T2RPNTLVW_Base<0xf8, 0xf9, "RS", "">, T_MAP5, VEX; - -let Predicates = [HasAMXMOVRS, HasAMXTRANSPOSE, HasEGPR, In64BitMode], SchedRW = [WriteSystem] in -  defm T2RPNTLVW : T2RPNTLVW_Base<0xf8, 0xf9, "RS", "_EVEX">, T_MAP5, EVEX, NoCD8; - -let Predicates = [HasAMXTRANSPOSE, In64BitMode] in { -  let SchedRW = [WriteSystem] in { -    def TTRANSPOSED : I<0x5f, MRMSrcReg, (outs TILE:$dst), (ins TILE:$src), -                        "ttransposed\t{$src, $dst|$dst, $src}", []>, VEX, T8, XS; -    let isPseudo = true in { -      def PT2RPNTLVWZ0V : PseudoI<(outs TILEPair:$dst), -                                  (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4), -                                  []>; -      def PT2RPNTLVWZ0T1V : PseudoI<(outs TILEPair:$dst), -                                  (ins GR16:$src1, GR16:$src2, 
GR16:$src3, opaquemem:$src4), -                                  []>; -      def PT2RPNTLVWZ1V : PseudoI<(outs TILEPair:$dst), -                                  (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4), -                                  []>; -      def PT2RPNTLVWZ1T1V : PseudoI<(outs TILEPair:$dst), -                                  (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4), -                                  []>; -    } - -    def PTTRANSPOSEDV : PseudoI<(outs TILE:$dst), -                                (ins GR16:$src1, GR16:$src2, TILE:$src), -                                [(set TILE: $dst, -                                 (int_x86_ttransposed_internal GR16:$src1, GR16:$src2, -                                  TILE:$src))]>; - -    let usesCustomInserter = 1 in { -      def PT2RPNTLVWZ0 : PseudoI<(outs), (ins u8imm:$dst, -                                 sibmem:$src1), []>; -      def PT2RPNTLVWZ0T1 : PseudoI<(outs), (ins u8imm:$dst, -                                   sibmem:$src1), []>; -      def PT2RPNTLVWZ1 : PseudoI<(outs), (ins u8imm:$dst, -                                 sibmem:$src1), []>; -      def PT2RPNTLVWZ1T1 : PseudoI<(outs), (ins u8imm:$dst, -                                   sibmem:$src1), []>; -      def PTTRANSPOSED : PseudoI<(outs), (ins u8imm:$dst, u8imm:$src), -                                 [(int_x86_ttransposed timm:$dst, timm:$src)]>; -    } -  } -} // HasAMXTILE, HasAMXTRANSPOSE - -let Predicates = [HasAMXBF16, HasAMXTRANSPOSE, In64BitMode], SchedRW = [WriteSystem] in { -  let Constraints = "$src1 = $dst" in -    def TTDPBF16PS : I<0x6c, MRMSrcReg4VOp3, (outs TILE:$dst), -                       (ins TILE:$src1, TILE:$src2, TILE:$src3), -                       "ttdpbf16ps\t{$src3, $src2, $dst|$dst, $src2, $src3}", -                       []>, VEX, VVVV, T8,XS; -  let Constraints = "$src4 = $dst" in -    def PTTDPBF16PSV : PseudoI<(outs TILE:$dst), (ins GR16:$src1, -                                GR16:$src2, GR16:$src3, TILE:$src4, -                                TILE:$src5, TILE:$src6), -                                [(set TILE: $dst, -                                  (int_x86_ttdpbf16ps_internal GR16:$src1, GR16:$src2, -                                   GR16:$src3, TILE:$src4, TILE:$src5, TILE:$src6))]>; -  let usesCustomInserter = 1 in -    def PTTDPBF16PS : PseudoI<(outs), (ins u8imm:$src1, u8imm:$src2, u8imm:$src3), -                              [(int_x86_ttdpbf16ps timm:$src1, timm:$src2, timm:$src3)]>; -} - -let Predicates = [HasAMXFP16, HasAMXTRANSPOSE, In64BitMode], SchedRW = [WriteSystem] in { -  let Constraints = "$src1 = $dst" in -    def TTDPFP16PS : I<0x6c, MRMSrcReg4VOp3, (outs TILE:$dst), -                       (ins TILE:$src1, TILE:$src2, TILE:$src3), -                       "ttdpfp16ps\t{$src3, $src2, $dst|$dst, $src2, $src3}", -                       []>, VEX, VVVV, T8,XD; -  let Constraints = "$src4 = $dst" in -    def PTTDPFP16PSV : PseudoI<(outs TILE:$dst), (ins GR16:$src1, -                                GR16:$src2, GR16:$src3, TILE:$src4, -                                TILE:$src5, TILE:$src6), -                                [(set TILE: $dst, -                                  (int_x86_ttdpfp16ps_internal GR16:$src1, GR16:$src2, -                                   GR16:$src3, TILE:$src4, TILE:$src5, TILE:$src6))]>; -  let usesCustomInserter = 1 in -    def PTTDPFP16PS : PseudoI<(outs), (ins u8imm:$src1, u8imm:$src2, u8imm:$src3), -                              [(int_x86_ttdpfp16ps 
timm:$src1, timm:$src2, timm:$src3)]>; -} - -let Predicates = [HasAMXCOMPLEX, HasAMXTRANSPOSE, In64BitMode], SchedRW = [WriteSystem] in { -  let Constraints = "$src1 = $dst" in { -    def TTCMMIMFP16PS : I<0x6b, MRMSrcReg4VOp3, (outs TILE:$dst), -                          (ins TILE:$src1, TILE:$src2, TILE:$src3), -                          "ttcmmimfp16ps\t{$src3, $src2, $src1|$src1, $src2, $src3}", -                          []>, VEX, VVVV, T8,XD; -    def TTCMMRLFP16PS: I<0x6b, MRMSrcReg4VOp3, (outs TILE:$dst), -                         (ins TILE:$src1, TILE:$src2, TILE:$src3), -                         "ttcmmrlfp16ps\t{$src3, $src2, $src1|$src1, $src2, $src3}", -                         []>, VEX, VVVV, T8,XS; -    def TCONJTCMMIMFP16PS : I<0x6b, MRMSrcReg4VOp3, (outs TILE:$dst), -                          (ins TILE:$src1, TILE:$src2, TILE:$src3), -                          "tconjtcmmimfp16ps\t{$src3, $src2, $src1|$src1, $src2, $src3}", -                          []>, VEX, VVVV, WIG, T8,PS; -  } -  def TCONJTFP16 : I<0x6b, MRMSrcReg, (outs TILE:$dst), (ins TILE:$src), -                     "tconjtfp16\t{$src, $dst|$dst, $src}", []>, VEX, T8,PD; - -  let Constraints = "$src4 = $dst" in { -    def PTTCMMIMFP16PSV : PseudoI<(outs TILE:$dst), (ins GR16:$src1, -                                  GR16:$src2, GR16:$src3, TILE:$src4, -                                  TILE:$src5, TILE:$src6), -                                  [(set TILE: $dst, -                                    (int_x86_ttcmmimfp16ps_internal GR16:$src1, GR16:$src2, -                                     GR16:$src3, TILE:$src4, TILE:$src5, TILE:$src6))]>; -    def PTTCMMRLFP16PSV : PseudoI<(outs TILE:$dst), (ins GR16:$src1, -                                  GR16:$src2, GR16:$src3, TILE:$src4, -                                  TILE:$src5, TILE:$src6), -                                  [(set TILE: $dst, -                                    (int_x86_ttcmmrlfp16ps_internal GR16:$src1, GR16:$src2, -                                     GR16:$src3, TILE:$src4, TILE:$src5, TILE:$src6))]>; -    def PTCONJTCMMIMFP16PSV : PseudoI<(outs TILE:$dst), (ins GR16:$src1, -                                      GR16:$src2, GR16:$src3, TILE:$src4, -                                      TILE:$src5, TILE:$src6), -                                      [(set TILE: $dst, -                                        (int_x86_tconjtcmmimfp16ps_internal GR16:$src1, GR16:$src2, -                                         GR16:$src3, TILE:$src4, TILE:$src5, TILE:$src6))]>; -  } -  def PTCONJTFP16V : PseudoI<(outs TILE:$dst), (ins GR16:$src1, GR16:$src2, TILE:$src3), -                             [(set TILE: $dst, (int_x86_tconjtfp16_internal GR16:$src1, GR16:$src2, TILE:$src3))]>; - -  let usesCustomInserter = 1 in { -    def PTTCMMIMFP16PS : PseudoI<(outs), (ins u8imm:$src1, u8imm:$src2, u8imm:$src3), -                                 [(int_x86_ttcmmimfp16ps timm:$src1, timm:$src2, timm:$src3)]>; -    def PTTCMMRLFP16PS : PseudoI<(outs), (ins u8imm:$src1, u8imm:$src2, u8imm:$src3), -                                 [(int_x86_ttcmmrlfp16ps timm:$src1, timm:$src2, timm:$src3)]>; -    def PTCONJTCMMIMFP16PS : PseudoI<(outs), (ins u8imm:$src1, u8imm:$src2, u8imm:$src3), -                                     [(int_x86_tconjtcmmimfp16ps timm:$src1, timm:$src2, timm:$src3)]>; -    def PTCONJTFP16 : PseudoI<(outs), (ins u8imm:$dst, u8imm:$src), -                              [(int_x86_tconjtfp16 timm:$dst, timm:$src)]>; -  } -} - -let Predicates = [HasAMXMOVRS, 
HasAMXTRANSPOSE, In64BitMode], SchedRW = [WriteSystem] in { -  let isPseudo = true in { -    def PT2RPNTLVWZ0RSV   : PseudoI<(outs TILEPair:$dst), -                              (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4), -                              []>; -    def PT2RPNTLVWZ0RST1V : PseudoI<(outs TILEPair:$dst), -                              (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4), -                              []>; -    def PT2RPNTLVWZ1RSV   : PseudoI<(outs TILEPair:$dst), -                              (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4), -                              []>; -    def PT2RPNTLVWZ1RST1V : PseudoI<(outs TILEPair:$dst), -                              (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4), -                              []>; -  } -  let  usesCustomInserter = 1 in { -    def PT2RPNTLVWZ0RS   : PseudoI<(outs), (ins u8imm:$dst, sibmem:$src1), []>; -    def PT2RPNTLVWZ0RST1 : PseudoI<(outs), (ins u8imm:$dst, sibmem:$src1), []>; -    def PT2RPNTLVWZ1RS   : PseudoI<(outs), (ins u8imm:$dst, sibmem:$src1), []>; -    def PT2RPNTLVWZ1RST1 : PseudoI<(outs), (ins u8imm:$dst, sibmem:$src1), []>; -  } -} // HasAMXMOVRS, HasAMXTRANSPOSE -  multiclass TILELOADDRS_Base<string suffix> {    def suffix    : I<0x4a, MRMSrcMemFSIB, (outs TILE:$dst), (ins sibmem:$src1),                      "tileloaddrs\t{$src1, $dst|$dst, $src1}", []>, T8, XD; @@ -721,29 +539,3 @@ let Predicates = [HasAMXTF32, In64BitMode] in {      }    } // SchedRW = [WriteSystem]  } // HasAMXTF32 - -let Predicates = [HasAMXTF32, HasAMXTRANSPOSE, In64BitMode] in { -  let SchedRW = [WriteSystem] in { -    let Constraints = "$src1 = $dst" in { -      def TTMMULTF32PS: I<0x48, MRMSrcReg4VOp3, (outs TILE:$dst), -                         (ins TILE:$src1, TILE:$src2, TILE:$src3), -                         "ttmmultf32ps\t{$src3, $src2, $dst|$dst, $src2, $src3}", -                         []>, VEX, VVVV, T8, PS; -    } -    let Constraints = "$src4 = $dst" in { -      def PTTMMULTF32PSV : PseudoI<(outs TILE:$dst), -                                   (ins GR16:$src1, GR16:$src2, GR16:$src3, -                                    TILE:$src4, TILE:$src5, TILE:$src6), -                                   [(set TILE:$dst, -                                     (int_x86_ttmmultf32ps_internal GR16:$src1, -                                      GR16:$src2, GR16:$src3, TILE:$src4, -                                      TILE:$src5, TILE:$src6))]>; -    } -    let usesCustomInserter = 1 in { -      def PTTMMULTF32PS : PseudoI<(outs), -                                  (ins u8imm:$src1, u8imm:$src2, u8imm:$src3), -                                  [(int_x86_ttmmultf32ps timm:$src1, timm:$src2, -                                    timm:$src3)]>; -    } -  } // SchedRW = [WriteSystem] -} // HasAMXTF32, HasAMXTRANSPOSE diff --git a/llvm/lib/Target/X86/X86InstrInfo.cpp b/llvm/lib/Target/X86/X86InstrInfo.cpp index 5c23f91..6b2a7a4 100644 --- a/llvm/lib/Target/X86/X86InstrInfo.cpp +++ b/llvm/lib/Target/X86/X86InstrInfo.cpp @@ -4544,11 +4544,6 @@ static unsigned getLoadStoreRegOpcode(Register Reg,      return Load ? GET_EGPR_IF_ENABLED(X86::TILELOADD)                  : GET_EGPR_IF_ENABLED(X86::TILESTORED);  #undef GET_EGPR_IF_ENABLED -  case 2048: -    assert(X86::TILEPAIRRegClass.hasSubClassEq(RC) && -           "Unknown 2048-byte regclass"); -    assert(STI.hasAMXTILE() && "Using 2048-bit register requires AMX-TILE"); -    return Load ? 
X86::PTILEPAIRLOAD : X86::PTILEPAIRSTORE;    }  } @@ -4743,8 +4738,6 @@ static bool isAMXOpcode(unsigned Opc) {    case X86::TILESTORED:    case X86::TILELOADD_EVEX:    case X86::TILESTORED_EVEX: -  case X86::PTILEPAIRLOAD: -  case X86::PTILEPAIRSTORE:      return true;    }  } @@ -4757,8 +4750,7 @@ void X86InstrInfo::loadStoreTileReg(MachineBasicBlock &MBB,    default:      llvm_unreachable("Unexpected special opcode!");    case X86::TILESTORED: -  case X86::TILESTORED_EVEX: -  case X86::PTILEPAIRSTORE: { +  case X86::TILESTORED_EVEX: {      // tilestored %tmm, (%sp, %idx)      MachineRegisterInfo &RegInfo = MBB.getParent()->getRegInfo();      Register VirtReg = RegInfo.createVirtualRegister(&X86::GR64_NOSPRegClass); @@ -4772,8 +4764,7 @@ void X86InstrInfo::loadStoreTileReg(MachineBasicBlock &MBB,      break;    }    case X86::TILELOADD: -  case X86::TILELOADD_EVEX: -  case X86::PTILEPAIRLOAD: { +  case X86::TILELOADD_EVEX: {      // tileloadd (%sp, %idx), %tmm      MachineRegisterInfo &RegInfo = MBB.getParent()->getRegInfo();      Register VirtReg = RegInfo.createVirtualRegister(&X86::GR64_NOSPRegClass); diff --git a/llvm/lib/Target/X86/X86InstrOperands.td b/llvm/lib/Target/X86/X86InstrOperands.td index 5207eca..6ba07f7 100644 --- a/llvm/lib/Target/X86/X86InstrOperands.td +++ b/llvm/lib/Target/X86/X86InstrOperands.td @@ -536,10 +536,3 @@ def VK8Pair : RegisterOperand<VK8PAIR, "printVKPair"> {  def VK16Pair : RegisterOperand<VK16PAIR, "printVKPair"> {    let ParserMatchClass = VK16PairAsmOperand;  } - -let RenderMethod = "addTILEPairOperands" in -  def TILEPairAsmOperand : AsmOperandClass { let Name = "TILEPair"; } - -def TILEPair : RegisterOperand<TILEPAIR, "printTILEPair"> { -  let ParserMatchClass = TILEPairAsmOperand; -} diff --git a/llvm/lib/Target/X86/X86InstrPredicates.td b/llvm/lib/Target/X86/X86InstrPredicates.td index c20bb05..98104a6f 100644 --- a/llvm/lib/Target/X86/X86InstrPredicates.td +++ b/llvm/lib/Target/X86/X86InstrPredicates.td @@ -183,7 +183,6 @@ def HasAMXINT8   : Predicate<"Subtarget->hasAMXINT8()">;  def HasAMXCOMPLEX : Predicate<"Subtarget->hasAMXCOMPLEX()">;  def HasAMXFP8    : Predicate<"Subtarget->hasAMXFP8()">;  def HasAMXMOVRS  : Predicate<"Subtarget->hasAMXMOVRS()">; -def HasAMXTRANSPOSE : Predicate<"Subtarget->hasAMXTRANSPOSE()">;  def HasAMXAVX512 : Predicate<"Subtarget->hasAMXAVX512()">;  def HasAMXTF32   : Predicate<"Subtarget->hasAMXTF32()">;  def HasUINTR     : Predicate<"Subtarget->hasUINTR()">; diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp index 8ffd454..2fc5d38 100644 --- a/llvm/lib/Target/X86/X86LowerAMXType.cpp +++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp @@ -74,22 +74,6 @@ static bool isAMXCast(Instruction *II) {           match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));  } -// Some instructions may return more than one tiles. -// e.g: call { x86_amx, x86_amx } @llvm.x86.t2rpntlvwz0.internal -static unsigned getNumDefTiles(IntrinsicInst *II) { -  Type *Ty = II->getType(); -  if (Ty->isX86_AMXTy()) -    return 1; - -  unsigned Num = 0; -  for (unsigned i = 0; i < Ty->getNumContainedTypes(); i++) { -    Type *STy = Ty->getContainedType(i); -    if (STy->isX86_AMXTy()) -      Num++; -  } -  return Num; -} -  static bool isAMXIntrinsic(Value *I) {    auto *II = dyn_cast<IntrinsicInst>(I);    if (!II) @@ -98,7 +82,7 @@ static bool isAMXIntrinsic(Value *I) {      return false;    // Check if return type or parameter is x86_amx. 
If it is x86_amx    // the intrinsic must be x86 amx intrinsics. -  if (getNumDefTiles(II) > 0) +  if (II->getType()->isX86_AMXTy())      return true;    for (Value *V : II->args()) {      if (V->getType()->isX86_AMXTy()) @@ -137,27 +121,7 @@ static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {    llvm_unreachable("No terminator in the entry block!");  } -class ShapeCalculator { -private: -  const TargetMachine *TM = nullptr; - -  // In AMX intrinsics we let Shape = {Row, Col}, but the -  // RealCol = Col / ElementSize. We may use the RealCol -  // as a new Row for other new created AMX intrinsics. -  std::map<Value *, Value *> Col2Row, Row2Col; - -public: -  ShapeCalculator(const TargetMachine *TargetM) : TM(TargetM) {} -  std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo); -  std::pair<Value *, Value *> getShape(PHINode *Phi); -  Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity); -  Value *getColFromRow(Instruction *II, Value *V, unsigned Granularity); -}; - -Value *ShapeCalculator::getRowFromCol(Instruction *II, Value *V, -                                      unsigned Granularity) { -  if (auto It = Col2Row.find(V); It != Col2Row.end()) -    return It->second; +static Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity) {    IRBuilder<> Builder(II);    Value *RealRow = nullptr;    if (isa<ConstantInt>(V)) @@ -186,47 +150,16 @@ Value *ShapeCalculator::getRowFromCol(Instruction *II, Value *V,          getFirstNonAllocaInTheEntryBlock(*II->getFunction()));      RealRow = NewBuilder.CreateUDiv(V, NewBuilder.getInt16(Granularity));    } -  Col2Row[V] = RealRow;    return RealRow;  } -Value *ShapeCalculator::getColFromRow(Instruction *II, Value *V, -                                      unsigned Granularity) { -  if (auto It = Row2Col.find(V); It != Row2Col.end()) -    return It->second; -  IRBuilder<> Builder(II); -  Value *RealCol = nullptr; -  if (isa<ConstantInt>(V)) -    RealCol = -        Builder.getInt16((cast<ConstantInt>(V)->getSExtValue()) * Granularity); -  else if (isa<Instruction>(V)) { -    Builder.SetInsertPoint(cast<Instruction>(V)); -    RealCol = Builder.CreateNUWMul(V, Builder.getInt16(Granularity)); -    cast<Instruction>(RealCol)->moveAfter(cast<Instruction>(V)); -  } else { -    // When it is not a const value and it is a function argument, we create -    // Row at the entry bb. -    IRBuilder<> NewBuilder( -        getFirstNonAllocaInTheEntryBlock(*II->getFunction())); -    RealCol = NewBuilder.CreateNUWMul(V, NewBuilder.getInt16(Granularity)); -  } -  Row2Col[V] = RealCol; -  return RealCol; -} -  // TODO: Refine the row and col-in-bytes of tile to row and col of matrix. 
-std::pair<Value *, Value *> ShapeCalculator::getShape(IntrinsicInst *II, -                                                      unsigned OpNo) { -  (void)TM; +std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {    IRBuilder<> Builder(II);    Value *Row = nullptr, *Col = nullptr;    switch (II->getIntrinsicID()) {    default:      llvm_unreachable("Expect amx intrinsics"); -  case Intrinsic::x86_t2rpntlvwz0_internal: -  case Intrinsic::x86_t2rpntlvwz0t1_internal: -  case Intrinsic::x86_t2rpntlvwz1_internal: -  case Intrinsic::x86_t2rpntlvwz1t1_internal:    case Intrinsic::x86_tileloadd64_internal:    case Intrinsic::x86_tileloaddt164_internal:    case Intrinsic::x86_tilestored64_internal: @@ -271,13 +204,6 @@ std::pair<Value *, Value *> ShapeCalculator::getShape(IntrinsicInst *II,      }      break;    } -  case Intrinsic::x86_ttransposed_internal: -  case Intrinsic::x86_tconjtfp16_internal: { -    assert((OpNo == 2) && "Illegal Operand Number."); -    Row = getRowFromCol(II, II->getArgOperand(1), 4); -    Col = getColFromRow(II, II->getArgOperand(0), 4); -    break; -  }    case Intrinsic::x86_tcvtrowd2ps_internal:    case Intrinsic::x86_tcvtrowps2bf16h_internal:    case Intrinsic::x86_tcvtrowps2bf16l_internal: @@ -289,34 +215,12 @@ std::pair<Value *, Value *> ShapeCalculator::getShape(IntrinsicInst *II,      Col = II->getArgOperand(1);      break;    } -  case Intrinsic::x86_ttdpbf16ps_internal: -  case Intrinsic::x86_ttdpfp16ps_internal: -  case Intrinsic::x86_ttcmmimfp16ps_internal: -  case Intrinsic::x86_ttcmmrlfp16ps_internal: -  case Intrinsic::x86_tconjtcmmimfp16ps_internal: -  case Intrinsic::x86_ttmmultf32ps_internal: { -    switch (OpNo) { -    case 3: -      Row = II->getArgOperand(0); -      Col = II->getArgOperand(1); -      break; -    case 4: -      Row = getRowFromCol(II, II->getArgOperand(2), 4); -      Col = getColFromRow(II, II->getArgOperand(0), 4); -      break; -    case 5: -      Row = getRowFromCol(II, II->getArgOperand(2), 4); -      Col = II->getArgOperand(1); -      break; -    } -    break; -  }    }    return std::make_pair(Row, Col);  } -std::pair<Value *, Value *> ShapeCalculator::getShape(PHINode *Phi) { +static std::pair<Value *, Value *> getShape(PHINode *Phi) {    Use &U = *(Phi->use_begin());    unsigned OpNo = U.getOperandNo();    User *V = U.getUser(); @@ -349,15 +253,14 @@ std::pair<Value *, Value *> ShapeCalculator::getShape(PHINode *Phi) {  namespace {  class X86LowerAMXType {    Function &Func; -  ShapeCalculator *SC;    // In AMX intrinsics we let Shape = {Row, Col}, but the    // RealCol = Col / ElementSize. We may use the RealCol    // as a new Row for other new created AMX intrinsics. -  std::map<Value *, Value *> Col2Row, Row2Col; +  std::map<Value *, Value *> Col2Row;  public: -  X86LowerAMXType(Function &F, ShapeCalculator *ShapeC) : Func(F), SC(ShapeC) {} +  X86LowerAMXType(Function &F) : Func(F) {}    bool visit();    void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);    void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST); @@ -374,7 +277,7 @@ void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {    Use &U = *(Bitcast->use_begin());    unsigned OpNo = U.getOperandNo();    auto *II = cast<IntrinsicInst>(U.getUser()); -  std::tie(Row, Col) = SC->getShape(II, OpNo); +  std::tie(Row, Col) = getShape(II, OpNo);    IRBuilder<> Builder(Bitcast);    // Use the maximun column as stride.    
   Value *Stride = Builder.getInt64(64);
@@ -454,7 +357,7 @@ bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
     Builder.CreateStore(Src, AllocaAddr);
     // TODO we can pick an constant operand for the shape.
     Value *Row = nullptr, *Col = nullptr;
-    std::tie(Row, Col) = SC->getShape(II, OpNo);
+    std::tie(Row, Col) = getShape(II, OpNo);
     std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
     Value *NewInst =
         Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, Args);
@@ -594,18 +497,11 @@ static Value *getAllocaPos(BasicBlock *BB) {
 
 static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
   assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
-  auto *II = dyn_cast<IntrinsicInst>(TileDef);
-  unsigned Idx = 0;
-  // Extract tile from multiple tiles' def.
-  if (auto *Extr = dyn_cast<ExtractValueInst>(TileDef)) {
-    assert(Extr->hasIndices() && "Tile extract miss index!");
-    Idx = Extr->getIndices()[0];
-    II = cast<IntrinsicInst>(Extr->getOperand(0));
-  }
+  auto *II = cast<IntrinsicInst>(TileDef);
 
   assert(II && "Not tile intrinsic!");
-  Value *Row = II->getOperand(Idx);
-  Value *Col = II->getOperand(Idx + 1);
+  Value *Row = II->getOperand(0);
+  Value *Col = II->getOperand(1);
 
   BasicBlock *BB = TileDef->getParent();
   BasicBlock::iterator Iter = TileDef->getIterator();
@@ -624,20 +520,14 @@ static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
 
   // Get tile shape.
   IntrinsicInst *II = nullptr;
-  unsigned Idx = 0;
   if (IsPHI) {
     Value *PhiOp = cast<PHINode>(V)->getIncomingValue(0);
     II = cast<IntrinsicInst>(PhiOp);
-  } else if (auto *Extr = dyn_cast<ExtractValueInst>(V)) {
-    // Extract tile from multiple tiles' def.
-    assert(Extr->hasIndices() && "Tile extract miss index!");
-    Idx = Extr->getIndices()[0];
-    II = cast<IntrinsicInst>(Extr->getOperand(0));
   } else {
     II = cast<IntrinsicInst>(V);
   }
-  Value *Row = II->getOperand(Idx);
-  Value *Col = II->getOperand(Idx + 1);
+  Value *Row = II->getOperand(0);
+  Value *Col = II->getOperand(1);
 
   Instruction *UserI = cast<Instruction>(U.getUser());
   IRBuilder<> Builder(UserI);
@@ -848,12 +738,10 @@ namespace {
 
 class X86LowerAMXCast {
   Function &Func;
-  ShapeCalculator *SC;
   std::unique_ptr<DominatorTree> DT;
 
 public:
-  X86LowerAMXCast(Function &F, ShapeCalculator *ShapeC)
-      : Func(F), SC(ShapeC), DT(nullptr) {}
+  X86LowerAMXCast(Function &F) : Func(F), DT(nullptr) {}
   bool combineCastStore(IntrinsicInst *Cast, StoreInst *ST);
   bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
   bool combineTilezero(IntrinsicInst *Cast);
@@ -932,7 +820,7 @@ bool X86LowerAMXCast::optimizeAMXCastFromPhi(
         if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())
           return false;
         Value *Row = nullptr, *Col = nullptr;
-        std::tie(Row, Col) = SC->getShape(OldPN);
+        std::tie(Row, Col) = getShape(OldPN);
         // TODO: If it is not constant the Row and Col must domoniate tilezero
         // that we are going to create.
         if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))
@@ -1063,19 +951,6 @@ bool X86LowerAMXCast::optimizeAMXCastFromPhi(
   return true;
 }
 
-static Value *getShapeFromAMXIntrinsic(Value *Inst, unsigned ShapeIdx,
-                                       bool IsRow) {
-  if (!isAMXIntrinsic(Inst))
-    return nullptr;
-
-  auto *II = cast<IntrinsicInst>(Inst);
-  if (IsRow)
-    return II->getOperand(0);
-
-  assert(ShapeIdx < 2 && "Currently 2 shapes in 1 instruction at most!");
-  return II->getOperand(ShapeIdx + 1);
-}
-
 // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42)
 // store <256 x i32> %43, <256 x i32>* %p, align 64
 // -->
@@ -1090,38 +965,13 @@ bool X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
   if (!Tile->hasOneUse())
     return false;
 
-  // We don't fetch shape from tilestore, we only get shape from tiledef,
-  // so we can set the max tile shape to tilestore for special cases.
+  auto *II = cast<IntrinsicInst>(Tile);
+  // Tile is output from AMX intrinsic. The first operand of the
+  // intrinsic is row, the second operand of the intrinsic is column.
+  Value *Row = II->getOperand(0);
+  Value *Col = II->getOperand(1);
+
   IRBuilder<> Builder(ST);
-  Value *Row = nullptr;
-  Value *Col = nullptr;
-
-  if (isAMXIntrinsic(Tile)) {
-    auto *II = cast<IntrinsicInst>(Tile);
-    // Tile is output from AMX intrinsic. The first operand of the
-    // intrinsic is row, the second operand of the intrinsic is column.
-    Row = II->getOperand(0);
-    Col = II->getOperand(1);
-  } else {
-    // Now we supported multi-tiles value in structure, so we may get tile
-    // from extracting multi-tiles structure.
-    // For example:
-    // %6 = call { x86_amx, x86_amx } @llvm.x86.t2rpntlvwz0.internal(i16 %1,
-    //      i16 %2, i16 %3, i8* %4, i64 %5)
-    // %7 = extractvalue { x86_amx, x86_amx } %6, 0
-    // %8 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %7)
-    // store <256 x i32> %8, <256 x i32>* %0, align 1024
-    //
-    // TODO: Currently we only handle extractvalue case, enhance me for other
-    // cases if possible.
-    auto *II = cast<ExtractValueInst>(Tile);
-    assert(II && "We meet unhandle source in fetching tile value!");
-    unsigned ShapeIdx = II->getIndices()[0];
-    Value *Tiles = II->getOperand(0);
-    Row = getShapeFromAMXIntrinsic(Tiles, ShapeIdx, true);
-    Col = getShapeFromAMXIntrinsic(Tiles, ShapeIdx, false);
-  }
-  assert(Row && Col && "Shape got failed!");
 
   // Stride should be equal to col(measured by bytes)
   Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
@@ -1146,7 +996,7 @@ bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
   // shape information through def-use chain.
   if (!isAMXIntrinsic(II))
     return false;
-  std::tie(Row, Col) = SC->getShape(II, OpNo);
+  std::tie(Row, Col) = getShape(II, OpNo);
   IRBuilder<> Builder(LD);
   // Stride should be equal to col(measured by bytes)
   Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
@@ -1189,7 +1039,7 @@ bool X86LowerAMXCast::combineTilezero(IntrinsicInst *Cast) {
   if (!isAMXIntrinsic(II))
     return false;
 
-  std::tie(Row, Col) = SC->getShape(II, OpNo);
+  std::tie(Row, Col) = getShape(II, OpNo);
 
   IRBuilder<> Builder(Cast);
   Value *NewInst =
@@ -1384,7 +1234,7 @@ bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
     Builder.CreateStore(Src, AllocaAddr);
     // TODO we can pick an constant operand for the shape.
     Value *Row = nullptr, *Col = nullptr;
-    std::tie(Row, Col) = SC->getShape(II, OpNo);
+    std::tie(Row, Col) = getShape(II, OpNo);
     std::array<Value *, 4> Args = {
         Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
     Value *NewInst =
@@ -1445,14 +1295,13 @@ bool lowerAmxType(Function &F, const TargetMachine *TM,
     return false;
 
   bool C = false;
-  ShapeCalculator SC(TM);
-  X86LowerAMXCast LAC(F, &SC);
+  X86LowerAMXCast LAC(F);
   C |= LAC.combineAMXcast(TLI);
   // There might be remaining AMXcast after combineAMXcast and they should be
   // handled elegantly.
   C |= LAC.transformAllAMXCast();
 
-  X86LowerAMXType LAT(F, &SC);
+  X86LowerAMXType LAT(F);
   C |= LAT.visit();
 
   // Prepare for fast register allocation at O0.
diff --git a/llvm/lib/Target/X86/X86PreTileConfig.cpp b/llvm/lib/Target/X86/X86PreTileConfig.cpp
index 2a1c499..8a1d00d 100644
--- a/llvm/lib/Target/X86/X86PreTileConfig.cpp
+++ b/llvm/lib/Target/X86/X86PreTileConfig.cpp
@@ -141,15 +141,10 @@ class X86PreTileConfig : public MachineFunctionPass {
     if (!MO.isReg() || !MO.getReg().isVirtual())
       return false;
 
-    unsigned Shapes = 0;
-    if (MRI->getRegClass(MO.getReg())->getID() == X86::TILERegClassID)
-      Shapes = 1;
-    if (MRI->getRegClass(MO.getReg())->getID() == X86::TILEPAIRRegClassID)
-      Shapes = 2;
-    if (!Shapes)
+    if (MRI->getRegClass(MO.getReg())->getID() != X86::TILERegClassID)
       return false;
 
-    collectShapeInfo(MI, Shapes);
+    collectShapeInfo(MI);
     return true;
   }
 
@@ -165,7 +160,7 @@ class X86PreTileConfig : public MachineFunctionPass {
   }
 
   /// Collect the shape def information for later use.
-  void collectShapeInfo(MachineInstr &MI, unsigned Shapes);
+  void collectShapeInfo(MachineInstr &MI);
 
   /// Try to hoist shapes definded below AMX instructions.
   bool hoistShapesInBB(MachineBasicBlock *MBB, SmallVectorImpl<MIRef> &Shapes) {
@@ -231,7 +226,7 @@ INITIALIZE_PASS_DEPENDENCY(MachineLoopInfoWrapperPass)
 INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig",
                     "Tile Register Pre-configure", false, false)
 
-void X86PreTileConfig::collectShapeInfo(MachineInstr &MI, unsigned Shapes) {
+void X86PreTileConfig::collectShapeInfo(MachineInstr &MI) {
   auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) {
     MIRef MIR(MI, MBB);
     auto &Refs = ShapeBBs[MBB];
@@ -240,10 +235,8 @@ void X86PreTileConfig::collectShapeInfo(MachineInstr &MI, unsigned Shapes) {
       Refs.insert(I, MIR);
   };
 
-  // All shapes have same row in multi-tile operand.
-  SmallVector<Register, 8> WorkList;
-  for (unsigned I = 1; I < Shapes + 2; ++I)
-    WorkList.push_back(MI.getOperand(I).getReg());
+  SmallVector<Register, 8> WorkList(
+      {MI.getOperand(1).getReg(), MI.getOperand(2).getReg()});
   while (!WorkList.empty()) {
     Register R = WorkList.pop_back_val();
     MachineInstr *DefMI = MRI->getVRegDef(R);
@@ -252,13 +245,6 @@ void X86PreTileConfig::collectShapeInfo(MachineInstr &MI, unsigned Shapes) {
     if (DefMI->isMoveImmediate() || !DefVisited.insert(DefMI).second)
       continue;
 
-    // This happens when column = 0 in multi-tile operand.
-    if (DefMI->getOpcode() == X86::COPY) {
-      MachineInstr *MI = MRI->getVRegDef(DefMI->getOperand(1).getReg());
-      if (MI && MI->isMoveImmediate())
-        continue;
-    }
-
     if (DefMI->isPHI()) {
       for (unsigned I = 1; I < DefMI->getNumOperands(); I += 2)
         if (isLoopBackEdge(DefMBB, DefMI->getOperand(I + 1).getMBB()))
diff --git a/llvm/lib/Target/X86/X86RegisterInfo.cpp b/llvm/lib/Target/X86/X86RegisterInfo.cpp
index 76979e3..72f3813 100644
--- a/llvm/lib/Target/X86/X86RegisterInfo.cpp
+++ b/llvm/lib/Target/X86/X86RegisterInfo.cpp
@@ -597,10 +597,6 @@ BitVector X86RegisterInfo::getReservedRegs(const MachineFunction &MF) const {
       Reserved.set(*AI);
   }
 
-  // Reserve low half pair registers in case they are used by RA aggressively.
-  Reserved.set(X86::TMM0_TMM1);
-  Reserved.set(X86::TMM2_TMM3);
-
   assert(checkAllSuperRegsMarked(Reserved,
                                  {X86::SIL, X86::DIL, X86::BPL, X86::SPL,
                                   X86::SIH, X86::DIH, X86::BPH, X86::SPH}));
@@ -621,7 +617,7 @@ unsigned X86RegisterInfo::getNumSupportedRegs(const MachineFunction &MF) const {
   // and try to return the minimum number of registers supported by the target.
   static_assert((X86::R15WH + 1 == X86::YMM0) && (X86::YMM15 + 1 == X86::K0) &&
                     (X86::K6_K7 + 1 == X86::TMMCFG) &&
-                    (X86::TMM6_TMM7 + 1 == X86::R16) &&
+                    (X86::TMM7 + 1 == X86::R16) &&
                     (X86::R31WH + 1 == X86::NUM_TARGET_REGS),
                 "Register number may be incorrect");
 
@@ -694,8 +690,7 @@ bool X86RegisterInfo::isFixedRegister(const MachineFunction &MF,
 }
 
 bool X86RegisterInfo::isTileRegisterClass(const TargetRegisterClass *RC) const {
-  return RC->getID() == X86::TILERegClassID ||
-         RC->getID() == X86::TILEPAIRRegClassID;
+  return RC->getID() == X86::TILERegClassID;
 }
 
 void X86RegisterInfo::adjustStackMapLiveOutMask(uint32_t *Mask) const {
@@ -1062,17 +1057,9 @@ static ShapeT getTileShape(Register VirtReg, VirtRegMap *VRM,
   case X86::PTDPFP16PSV:
   case X86::PTCMMIMFP16PSV:
   case X86::PTCMMRLFP16PSV:
-  case X86::PTTRANSPOSEDV:
-  case X86::PTTDPBF16PSV:
-  case X86::PTTDPFP16PSV:
-  case X86::PTTCMMIMFP16PSV:
-  case X86::PTTCMMRLFP16PSV:
-  case X86::PTCONJTCMMIMFP16PSV:
-  case X86::PTCONJTFP16V:
   case X86::PTILELOADDRSV:
   case X86::PTILELOADDRST1V:
   case X86::PTMMULTF32PSV:
-  case X86::PTTMMULTF32PSV:
   case X86::PTDPBF8PSV:
   case X86::PTDPBHF8PSV:
   case X86::PTDPHBF8PSV:
@@ -1083,56 +1070,7 @@ static ShapeT getTileShape(Register VirtReg, VirtRegMap *VRM,
     VRM->assignVirt2Shape(VirtReg, Shape);
     return Shape;
   }
-  case X86::PT2RPNTLVWZ0V:
-  case X86::PT2RPNTLVWZ0T1V:
-  case X86::PT2RPNTLVWZ1V:
-  case X86::PT2RPNTLVWZ1T1V:
-  case X86::PT2RPNTLVWZ0RSV:
-  case X86::PT2RPNTLVWZ0RST1V:
-  case X86::PT2RPNTLVWZ1RSV:
-  case X86::PT2RPNTLVWZ1RST1V: {
-    MachineOperand &MO1 = MI->getOperand(1);
-    MachineOperand &MO2 = MI->getOperand(2);
-    MachineOperand &MO3 = MI->getOperand(3);
-    ShapeT Shape({&MO1, &MO2, &MO1, &MO3}, MRI);
-    VRM->assignVirt2Shape(VirtReg, Shape);
-    return Shape;
-  }
-  }
-}
-
-static bool canHintShape(ShapeT &PhysShape, ShapeT &VirtShape) {
-  unsigned PhysShapeNum = PhysShape.getShapeNum();
-  unsigned VirtShapeNum = VirtShape.getShapeNum();
-
-  if (PhysShapeNum < VirtShapeNum)
-    return false;
-
-  if (PhysShapeNum == VirtShapeNum) {
-    if (PhysShapeNum == 1)
-      return PhysShape == VirtShape;
-
-    for (unsigned I = 0; I < PhysShapeNum; I++) {
-      ShapeT PShape(PhysShape.getRow(I), PhysShape.getCol(I));
-      ShapeT VShape(VirtShape.getRow(I), VirtShape.getCol(I));
-      if (VShape != PShape)
-        return false;
-    }
-    return true;
-  }
-
-  // Hint subreg of mult-tile reg to single tile reg.
-  if (VirtShapeNum == 1) {
-    for (unsigned I = 0; I < PhysShapeNum; I++) {
-      ShapeT PShape(PhysShape.getRow(I), PhysShape.getCol(I));
-      if (VirtShape == PShape)
-        return true;
-    }
   }
-
-  // Note: Currently we have no requirement for case of
-  // (VirtShapeNum > 1 and PhysShapeNum > VirtShapeNum)
-  return false;
 }
 
 bool X86RegisterInfo::getRegAllocationHints(Register VirtReg,
@@ -1153,7 +1091,7 @@ bool X86RegisterInfo::getRegAllocationHints(Register VirtReg,
   if (!VRM)
     return BaseImplRetVal;
 
-  if (ID != X86::TILERegClassID && ID != X86::TILEPAIRRegClassID) {
+  if (ID != X86::TILERegClassID) {
     if (DisableRegAllocNDDHints || !ST.hasNDD() ||
         !TRI.isGeneralPurposeRegisterClass(&RC))
       return BaseImplRetVal;
@@ -1204,7 +1142,7 @@ bool X86RegisterInfo::getRegAllocationHints(Register VirtReg,
       return;
     }
     ShapeT PhysShape = getTileShape(VReg, const_cast<VirtRegMap *>(VRM), MRI);
-    if (canHintShape(PhysShape, VirtShape))
+    if (PhysShape == VirtShape)
       Hints.push_back(PhysReg);
   };
 
diff --git a/llvm/lib/Target/X86/X86RegisterInfo.td b/llvm/lib/Target/X86/X86RegisterInfo.td
index 99b7910..692e42a 100644
--- a/llvm/lib/Target/X86/X86RegisterInfo.td
+++ b/llvm/lib/Target/X86/X86RegisterInfo.td
@@ -30,8 +30,6 @@ let Namespace = "X86" in {
   def sub_ymm      : SubRegIndex<256>;
   def sub_mask_0   : SubRegIndex<-1>;
   def sub_mask_1   : SubRegIndex<-1, -1>;
-  def sub_t0       : SubRegIndex<8192>;
-  def sub_t1       : SubRegIndex<8192, 8192>;
 }
 
 //===----------------------------------------------------------------------===//
@@ -432,10 +430,6 @@ def TMM4:  X86Reg<"tmm4",   4>;
 def TMM5:  X86Reg<"tmm5",   5>;
 def TMM6:  X86Reg<"tmm6",   6>;
 def TMM7:  X86Reg<"tmm7",   7>;
-// TMM register pairs
-def TPAIRS : RegisterTuples<[sub_t0, sub_t1],
-                            [(add TMM0, TMM2, TMM4, TMM6),
-                             (add TMM1, TMM3, TMM5, TMM7)]>;
 }
 
 // Floating point stack registers. These don't map one-to-one to the FP
@@ -862,9 +856,6 @@ def VK64WM  : RegisterClass<"X86", [v64i1], 64, (add VK32WM)> {let Size = 64;}
 let CopyCost = -1 in // Don't allow copying of tile registers
 def TILE : RegisterClass<"X86", [x86amx], 8192,
                          (sequence "TMM%u", 0, 7)> {let Size = 8192;}
-// Need check alignment 3rd operand size=1024*2*8
-let isAllocatable = 1 in
-def TILEPAIR : RegisterClass<"X86", [untyped], 512, (add TPAIRS)> {let Size = 16384;}
 
 //===----------------------------------------------------------------------===//
 // Register categories.
diff --git a/llvm/lib/Target/X86/X86TileConfig.cpp b/llvm/lib/Target/X86/X86TileConfig.cpp
index 17a44dd..09ef8fb 100644
--- a/llvm/lib/Target/X86/X86TileConfig.cpp
+++ b/llvm/lib/Target/X86/X86TileConfig.cpp
@@ -74,63 +74,6 @@ INITIALIZE_PASS_DEPENDENCY(VirtRegMapWrapperLegacy)
 INITIALIZE_PASS_END(X86TileConfig, DEBUG_TYPE, "Tile Register Configure", false,
                     false)
 
-unsigned getAMXRegNum(MachineRegisterInfo *MRI, Register Reg) {
-  if (Reg.isVirtual()) {
-    unsigned RegClassID = MRI->getRegClass(Reg)->getID();
-    if (RegClassID == X86::TILERegClassID)
-      return 1;
-    if (RegClassID == X86::TILEPAIRRegClassID)
-      return 2;
-  } else {
-    if (Reg >= X86::TMM0 && Reg <= X86::TMM7)
-      return 1;
-    if (Reg >= X86::TMM0_TMM1 && Reg <= X86::TMM6_TMM7)
-      return 2;
-  }
-  return 0;
-}
-
-static void collectVirtRegShapes(MachineRegisterInfo *MRI, VirtRegMap &VRM,
-                                 Register VirtReg,
-                                 SmallVector<ShapeT, 8> &Phys2Shapes) {
-  unsigned Num = getAMXRegNum(MRI, VirtReg);
-  MCRegister PhysReg = VRM.getPhys(VirtReg);
-  if (!PhysReg)
-    return;
-
-  if (Num == 1) {
-    unsigned Index = PhysReg - X86::TMM0;
-    if (!Phys2Shapes[Index].isValid()) {
-      ShapeT Shape = VRM.getShape(VirtReg);
-      Phys2Shapes[Index] = std::move(Shape);
-      return;
-    }
-  }
-  // Split tile pair shape info to 2 single tile shape info. e.g:
-  // Put TMM0_TMM1's Shape to TMM0's shape + TMM1's Shape in Phys2Shapes.
-  if (Num == 2) {
-    unsigned Index0 = (PhysReg - X86::TMM0_TMM1) * 2;
-    unsigned Index1 = (PhysReg - X86::TMM0_TMM1) * 2 + 1;
-
-    ShapeT Shape = VRM.getShape(VirtReg);
-    assert(Shape.getShapeNum() == 2 && "Unexpected shape number!");
-
-    if (!Phys2Shapes[Index0].isValid()) {
-      ShapeT Shape0(Shape.getRow(0), Shape.getCol(0), MRI);
-      Phys2Shapes[Index0] = std::move(Shape0);
-    }
-
-    if (!Phys2Shapes[Index1].isValid()) {
-      ShapeT Shape1(Shape.getRow(1), Shape.getCol(1), MRI);
-      Phys2Shapes[Index1] = std::move(Shape1);
-    }
-  }
-}
-
-static bool isAMXRegClass(MachineRegisterInfo *MRI, Register Reg) {
-  return getAMXRegNum(MRI, Reg) > 0;
-}
-
 bool X86TileConfig::runOnMachineFunction(MachineFunction &MF) {
   X86MachineFunctionInfo *X86FI = MF.getInfo<X86MachineFunctionInfo>();
   // Early exit in the common case of non-AMX code.
@@ -138,7 +81,7 @@ bool X86TileConfig::runOnMachineFunction(MachineFunction &MF) {
     return false;
 
   const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
-  const TargetRegisterInfo *TRI = ST.getRegisterInfo();
+  const X86RegisterInfo *TRI = ST.getRegisterInfo();
   const TargetInstrInfo *TII = ST.getInstrInfo();
   MachineRegisterInfo &MRI = MF.getRegInfo();
   LiveIntervals &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS();
@@ -176,24 +119,29 @@
   assert(ConstMI && "Cannot find an insertion point");
 
   unsigned AMXRegNum = TRI->getRegClass(X86::TILERegClassID)->getNumRegs();
-  SmallVector<ShapeT, 8> Phys2Shapes(AMXRegNum, ShapeT());
+  SmallVector<Register, 8> Phys2Virt(AMXRegNum, 0);
   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
     Register VirtReg = Register::index2VirtReg(I);
     if (MRI.reg_nodbg_empty(VirtReg))
       continue;
-    if (!isAMXRegClass(&MRI, VirtReg))
+    if (!TRI->isTileRegisterClass(MRI.getRegClass(VirtReg)))
+      continue;
+    MCRegister PhysReg = VRM.getPhys(VirtReg);
+    if (!PhysReg)
       continue;
-    collectVirtRegShapes(&MRI, VRM, VirtReg, Phys2Shapes);
+    unsigned Index = PhysReg - X86::TMM0;
+    if (!Phys2Virt[Index])
+      Phys2Virt[Index] = VirtReg;
   }
 
   // Fill in the shape of each tile physical register.
   for (unsigned I = 0; I < AMXRegNum; ++I) {
-    ShapeT Shape = Phys2Shapes[I];
-    if (!Shape.isValid())
+    if (!Phys2Virt[I])
       continue;
     DebugLoc DL;
     bool IsRow = true;
     MachineInstr *NewMI = nullptr;
+    ShapeT Shape = VRM.getShape(Phys2Virt[I]);
     for (auto &R : {Shape.getRow()->getReg(), Shape.getCol()->getReg()}) {
       // Here is the data format for the tile config.
       // 0      palette
@@ -222,14 +170,7 @@ bool X86TileConfig::runOnMachineFunction(MachineFunction &MF) {
                    "Cannot initialize with different shapes");
             continue;
           }
-          if (DefMI.getOperand(1).isImm()) {
-            Imm = DefMI.getOperand(1).getImm();
-          } else {
-            assert(DefMI.getOpcode() == X86::MOV32r0 &&
-                   "The opcode is assumed to be MOV32r0 if the operand is not " 
-                   "immediate.");
-            Imm = 0;
-          }
+          Imm = DefMI.getOperand(1).getImm();
           NewMI = addFrameReference(
                       BuildMI(MF.front(), ++ConstMI->getIterator(), DL,
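The hunks above all converge on one simplified invariant: with the two-tile (t2rpntlvwz*) intrinsics and TILEPAIR registers removed, every x86_amx value is again defined directly by a single-tile AMX intrinsic whose operand 0 is the row and operand 1 is the column. A minimal C++ sketch of that assumption follows; the helper name is illustrative only and is not part of the patch or the tree.

    // Sketch: restates the shape lookup that createTileStore and
    // combineCastStore now perform inline after this change.
    #include "llvm/IR/IntrinsicInst.h"
    #include <utility>

    using namespace llvm;

    static std::pair<Value *, Value *> getTileShapeFromDef(Value *Tile) {
      // The defining instruction is expected to be the AMX intrinsic itself,
      // never an extractvalue of a multi-tile aggregate.
      auto *II = cast<IntrinsicInst>(Tile);
      return {II->getOperand(0), II->getOperand(1)};
    }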
