aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target')
-rw-r--r--llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp7
-rw-r--r--llvm/lib/Target/AMDGPU/AMDGPURegBankLegalize.cpp21
-rw-r--r--llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp22
-rw-r--r--llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.h2
-rw-r--r--llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp13
-rw-r--r--llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h5
-rw-r--r--llvm/lib/Target/Hexagon/HexagonPatternsHVX.td3
-rw-r--r--llvm/lib/Target/LoongArch/LoongArchFloat32InstrInfo.td1
-rw-r--r--llvm/lib/Target/LoongArch/LoongArchFloat64InstrInfo.td1
-rw-r--r--llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp10
-rw-r--r--llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td21
-rw-r--r--llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td33
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp129
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h1
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXInstrInfo.td1
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXIntrinsics.td105
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXSubtarget.h21
-rw-r--r--llvm/lib/Target/PowerPC/PPCTargetMachine.cpp5
-rw-r--r--llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp82
-rw-r--r--llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp10
-rw-r--r--llvm/lib/Target/RISCV/RISCVFeatures.td5
-rw-r--r--llvm/lib/Target/RISCV/RISCVInstrInfo.cpp19
-rw-r--r--llvm/lib/Target/RISCV/RISCVInstrInfo.td4
-rw-r--r--llvm/lib/Target/RISCV/RISCVInstrInfoD.td1
-rw-r--r--llvm/lib/Target/RISCV/RISCVInstrInfoF.td1
-rw-r--r--llvm/lib/Target/RISCV/RISCVInstrInfoSFB.td4
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp8
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp9
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp61
-rw-r--r--llvm/lib/Target/X86/AsmParser/X86Operand.h31
-rw-r--r--llvm/lib/Target/X86/Disassembler/X86Disassembler.cpp5
-rw-r--r--llvm/lib/Target/X86/Disassembler/X86DisassemblerDecoder.h7
-rw-r--r--llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.cpp19
-rw-r--r--llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.h1
-rw-r--r--llvm/lib/Target/X86/X86.td6
-rw-r--r--llvm/lib/Target/X86/X86ExpandPseudo.cpp190
-rw-r--r--llvm/lib/Target/X86/X86FastPreTileConfig.cpp40
-rw-r--r--llvm/lib/Target/X86/X86FastTileConfig.cpp25
-rw-r--r--llvm/lib/Target/X86/X86ISelDAGToDAG.cpp78
-rw-r--r--llvm/lib/Target/X86/X86ISelLowering.cpp256
-rw-r--r--llvm/lib/Target/X86/X86InstrAMX.td208
-rw-r--r--llvm/lib/Target/X86/X86InstrInfo.cpp13
-rw-r--r--llvm/lib/Target/X86/X86InstrOperands.td7
-rw-r--r--llvm/lib/Target/X86/X86InstrPredicates.td1
-rw-r--r--llvm/lib/Target/X86/X86LowerAMXType.cpp203
-rw-r--r--llvm/lib/Target/X86/X86PreTileConfig.cpp26
-rw-r--r--llvm/lib/Target/X86/X86RegisterInfo.cpp70
-rw-r--r--llvm/lib/Target/X86/X86RegisterInfo.td9
-rw-r--r--llvm/lib/Target/X86/X86TileConfig.cpp83
49 files changed, 463 insertions, 1420 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index fede586..47c1ac4 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -1032,6 +1032,13 @@ AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
}
break;
}
+ case Intrinsic::experimental_vector_extract_last_active:
+ if (ST->isSVEorStreamingSVEAvailable()) {
+ auto [LegalCost, _] = getTypeLegalizationCost(ICA.getArgTypes()[0]);
+ // This should turn into chained clastb instructions.
+ return LegalCost;
+ }
+ break;
default:
break;
}
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalize.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalize.cpp
index e187959..907f830 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalize.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalize.cpp
@@ -24,6 +24,7 @@
#include "llvm/CodeGen/GlobalISel/CSEInfo.h"
#include "llvm/CodeGen/GlobalISel/CSEMIRBuilder.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
+#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
#include "llvm/CodeGen/GlobalISel/Utils.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineUniformityAnalysis.h"
@@ -34,9 +35,17 @@
using namespace llvm;
using namespace AMDGPU;
+using namespace llvm::MIPatternMatch;
namespace {
+// AMDGPU-specific pattern matchers
+template <typename SrcTy>
+inline UnaryOp_match<SrcTy, AMDGPU::G_AMDGPU_READANYLANE>
+m_GAMDGPUReadAnyLane(const SrcTy &Src) {
+ return UnaryOp_match<SrcTy, AMDGPU::G_AMDGPU_READANYLANE>(Src);
+}
+
class AMDGPURegBankLegalize : public MachineFunctionPass {
public:
static char ID;
@@ -160,10 +169,18 @@ AMDGPURegBankLegalizeCombiner::tryMatchRALFromUnmerge(Register Src) {
Register AMDGPURegBankLegalizeCombiner::getReadAnyLaneSrc(Register Src) {
// Src = G_AMDGPU_READANYLANE RALSrc
- auto [RAL, RALSrc] = tryMatch(Src, AMDGPU::G_AMDGPU_READANYLANE);
- if (RAL)
+ Register RALSrc;
+ if (mi_match(Src, MRI, m_GAMDGPUReadAnyLane(m_Reg(RALSrc))))
return RALSrc;
+ // TruncSrc = G_AMDGPU_READANYLANE RALSrc
+ // AextSrc = G_TRUNC TruncSrc
+ // Src = G_ANYEXT AextSrc
+ if (mi_match(Src, MRI,
+ m_GAnyExt(m_GTrunc(m_GAMDGPUReadAnyLane(m_Reg(RALSrc)))))) {
+ return RALSrc;
+ }
+
// LoVgpr, HiVgpr = G_UNMERGE_VALUES UnmergeSrc
// LoSgpr = G_AMDGPU_READANYLANE LoVgpr
// HiSgpr = G_AMDGPU_READANYLANE HiVgpr
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp
index b84c30e..dc8fa7f 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp
@@ -626,6 +626,23 @@ void RegBankLegalizeHelper::lowerSplitTo32(MachineInstr &MI) {
MI.eraseFromParent();
}
+void RegBankLegalizeHelper::lowerSplitTo16(MachineInstr &MI) {
+ Register Dst = MI.getOperand(0).getReg();
+ assert(MRI.getType(Dst) == V2S16);
+ auto [Op1Lo32, Op1Hi32] = unpackAExt(MI.getOperand(1).getReg());
+ auto [Op2Lo32, Op2Hi32] = unpackAExt(MI.getOperand(2).getReg());
+ unsigned Opc = MI.getOpcode();
+ auto Flags = MI.getFlags();
+ auto Op1Lo = B.buildTrunc(SgprRB_S16, Op1Lo32);
+ auto Op1Hi = B.buildTrunc(SgprRB_S16, Op1Hi32);
+ auto Op2Lo = B.buildTrunc(SgprRB_S16, Op2Lo32);
+ auto Op2Hi = B.buildTrunc(SgprRB_S16, Op2Hi32);
+ auto Lo = B.buildInstr(Opc, {SgprRB_S16}, {Op1Lo, Op2Lo}, Flags);
+ auto Hi = B.buildInstr(Opc, {SgprRB_S16}, {Op1Hi, Op2Hi}, Flags);
+ B.buildMergeLikeInstr(Dst, {Lo, Hi});
+ MI.eraseFromParent();
+}
+
void RegBankLegalizeHelper::lowerSplitTo32Select(MachineInstr &MI) {
Register Dst = MI.getOperand(0).getReg();
LLT DstTy = MRI.getType(Dst);
@@ -698,6 +715,8 @@ void RegBankLegalizeHelper::lower(MachineInstr &MI,
return lowerUnpackBitShift(MI);
case UnpackMinMax:
return lowerUnpackMinMax(MI);
+ case ScalarizeToS16:
+ return lowerSplitTo16(MI);
case Ext32To64: {
const RegisterBank *RB = MRI.getRegBank(MI.getOperand(0).getReg());
MachineInstrBuilder Hi;
@@ -849,6 +868,7 @@ LLT RegBankLegalizeHelper::getTyFromID(RegBankLLTMappingApplyID ID) {
return LLT::scalar(32);
case Sgpr64:
case Vgpr64:
+ case UniInVgprS64:
return LLT::scalar(64);
case Sgpr128:
case Vgpr128:
@@ -972,6 +992,7 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
case UniInVcc:
case UniInVgprS16:
case UniInVgprS32:
+ case UniInVgprS64:
case UniInVgprV2S16:
case UniInVgprV4S32:
case UniInVgprB32:
@@ -1104,6 +1125,7 @@ void RegBankLegalizeHelper::applyMappingDst(
break;
}
case UniInVgprS32:
+ case UniInVgprS64:
case UniInVgprV2S16:
case UniInVgprV4S32: {
assert(Ty == getTyFromID(MethodIDs[OpIdx]));
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.h b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.h
index ad3ff1d..e7598f8 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.h
@@ -72,6 +72,7 @@ class RegBankLegalizeHelper {
static constexpr LLT P6 = LLT::pointer(6, 32);
MachineRegisterInfo::VRegAttrs SgprRB_S32 = {SgprRB, S32};
+ MachineRegisterInfo::VRegAttrs SgprRB_S16 = {SgprRB, S16};
MachineRegisterInfo::VRegAttrs VgprRB_S32 = {VgprRB, S32};
MachineRegisterInfo::VRegAttrs VccRB_S1 = {VccRB, S1};
@@ -121,6 +122,7 @@ private:
void lowerV_BFE(MachineInstr &MI);
void lowerS_BFE(MachineInstr &MI);
void lowerSplitTo32(MachineInstr &MI);
+ void lowerSplitTo16(MachineInstr &MI);
void lowerSplitTo32Select(MachineInstr &MI);
void lowerSplitTo32SExtInReg(MachineInstr &MI);
void lowerUnpackMinMax(MachineInstr &MI);
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp
index 01abd35..b22e9bd 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp
@@ -918,9 +918,20 @@ RegBankLegalizeRules::RegBankLegalizeRules(const GCNSubtarget &_ST,
bool hasSALUFloat = ST->hasSALUFloatInsts();
addRulesForGOpcs({G_FADD}, Standard)
+ .Uni(S16, {{UniInVgprS16}, {Vgpr16, Vgpr16}}, !hasSALUFloat)
+ .Uni(S16, {{Sgpr16}, {Sgpr16, Sgpr16}}, hasSALUFloat)
+ .Div(S16, {{Vgpr16}, {Vgpr16, Vgpr16}})
.Uni(S32, {{Sgpr32}, {Sgpr32, Sgpr32}}, hasSALUFloat)
.Uni(S32, {{UniInVgprS32}, {Vgpr32, Vgpr32}}, !hasSALUFloat)
- .Div(S32, {{Vgpr32}, {Vgpr32, Vgpr32}});
+ .Div(S32, {{Vgpr32}, {Vgpr32, Vgpr32}})
+ .Uni(S64, {{UniInVgprS64}, {Vgpr64, Vgpr64}})
+ .Div(S64, {{Vgpr64}, {Vgpr64, Vgpr64}})
+ .Uni(V2S16, {{UniInVgprV2S16}, {VgprV2S16, VgprV2S16}}, !hasSALUFloat)
+ .Uni(V2S16, {{SgprV2S16}, {SgprV2S16, SgprV2S16}, ScalarizeToS16},
+ hasSALUFloat)
+ .Div(V2S16, {{VgprV2S16}, {VgprV2S16, VgprV2S16}})
+ .Any({{UniV2S32}, {{UniInVgprV2S32}, {VgprV2S32, VgprV2S32}}})
+ .Any({{DivV2S32}, {{VgprV2S32}, {VgprV2S32, VgprV2S32}}});
addRulesForGOpcs({G_FPTOUI})
.Any({{UniS32, S32}, {{Sgpr32}, {Sgpr32}}}, hasSALUFloat)
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h
index 030bd75..e6df5d8 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h
@@ -92,8 +92,10 @@ enum UniformityLLTOpPredicateID {
V4S32,
UniV2S16,
+ UniV2S32,
DivV2S16,
+ DivV2S32,
// B types
B32,
@@ -178,7 +180,9 @@ enum RegBankLLTMappingApplyID {
UniInVcc,
UniInVgprS16,
UniInVgprS32,
+ UniInVgprS64,
UniInVgprV2S16,
+ UniInVgprV2S32,
UniInVgprV4S32,
UniInVgprB32,
UniInVgprB64,
@@ -217,6 +221,7 @@ enum LoweringMethodID {
V_BFE,
VgprToVccCopy,
SplitTo32,
+ ScalarizeToS16,
SplitTo32Select,
SplitTo32SExtInReg,
Ext32To64,
diff --git a/llvm/lib/Target/Hexagon/HexagonPatternsHVX.td b/llvm/lib/Target/Hexagon/HexagonPatternsHVX.td
index 1637b91..d19920c 100644
--- a/llvm/lib/Target/Hexagon/HexagonPatternsHVX.td
+++ b/llvm/lib/Target/Hexagon/HexagonPatternsHVX.td
@@ -612,6 +612,9 @@ let Predicates = [UseHVX] in {
(V6_vandvrt HvxVR:$Vs, (ToI32 0x01010101))>;
def: Pat<(VecQ32 (trunc HVI32:$Vs)),
(V6_vandvrt HvxVR:$Vs, (ToI32 0x01010101))>;
+ def: Pat<(VecQ16 (trunc HWI32:$Vss)),
+ (Combineq(VecQ32(V6_vandvrt (HiVec $Vss), (ToI32 0x01010101))),
+ (VecQ32 (V6_vandvrt (LoVec $Vss), (ToI32 0x01010101))))>;
}
let Predicates = [UseHVX] in {
diff --git a/llvm/lib/Target/LoongArch/LoongArchFloat32InstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchFloat32InstrInfo.td
index 690dd73..e86b21c 100644
--- a/llvm/lib/Target/LoongArch/LoongArchFloat32InstrInfo.td
+++ b/llvm/lib/Target/LoongArch/LoongArchFloat32InstrInfo.td
@@ -365,6 +365,7 @@ def : Pat<(f32 (uint_to_fp (i64 (sexti32 (i64 GPR:$src))))),
// FP Rounding
let Predicates = [HasBasicF, IsLA64] in {
def : PatFpr<frint, FRINT_S, FPR32>;
+def : PatFpr<flog2, FLOGB_S, FPR32>;
} // Predicates = [HasBasicF, IsLA64]
let Predicates = [HasBasicF, IsLA32] in {
diff --git a/llvm/lib/Target/LoongArch/LoongArchFloat64InstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchFloat64InstrInfo.td
index daefbaa..2e88254 100644
--- a/llvm/lib/Target/LoongArch/LoongArchFloat64InstrInfo.td
+++ b/llvm/lib/Target/LoongArch/LoongArchFloat64InstrInfo.td
@@ -348,6 +348,7 @@ def : Pat<(bitconvert FPR64:$src), (MOVFR2GR_D FPR64:$src)>;
// FP Rounding
let Predicates = [HasBasicD, IsLA64] in {
def : PatFpr<frint, FRINT_D, FPR64>;
+def : PatFpr<flog2, FLOGB_D, FPR64>;
} // Predicates = [HasBasicD, IsLA64]
/// Pseudo-instructions needed for the soft-float ABI with LA32D
diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
index 80c96c6..a6de839 100644
--- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
+++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
@@ -244,8 +244,10 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FP_TO_BF16, MVT::f32,
Subtarget.isSoftFPABI() ? LibCall : Custom);
- if (Subtarget.is64Bit())
+ if (Subtarget.is64Bit()) {
setOperationAction(ISD::FRINT, MVT::f32, Legal);
+ setOperationAction(ISD::FLOG2, MVT::f32, Legal);
+ }
if (!Subtarget.hasBasicD()) {
setOperationAction(ISD::FP_TO_UINT, MVT::i32, Custom);
@@ -291,8 +293,10 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FP_TO_BF16, MVT::f64,
Subtarget.isSoftFPABI() ? LibCall : Custom);
- if (Subtarget.is64Bit())
+ if (Subtarget.is64Bit()) {
setOperationAction(ISD::FRINT, MVT::f64, Legal);
+ setOperationAction(ISD::FLOG2, MVT::f64, Legal);
+ }
}
// Set operations for 'LSX' feature.
@@ -362,6 +366,7 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FMA, VT, Legal);
setOperationAction(ISD::FSQRT, VT, Legal);
setOperationAction(ISD::FNEG, VT, Legal);
+ setOperationAction(ISD::FLOG2, VT, Legal);
setCondCodeAction({ISD::SETGE, ISD::SETGT, ISD::SETOGE, ISD::SETOGT,
ISD::SETUGE, ISD::SETUGT},
VT, Expand);
@@ -443,6 +448,7 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FMA, VT, Legal);
setOperationAction(ISD::FSQRT, VT, Legal);
setOperationAction(ISD::FNEG, VT, Legal);
+ setOperationAction(ISD::FLOG2, VT, Legal);
setCondCodeAction({ISD::SETGE, ISD::SETGT, ISD::SETOGE, ISD::SETOGT,
ISD::SETUGE, ISD::SETUGT},
VT, Expand);
diff --git a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
index 613dea6..ca4ee5f 100644
--- a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
+++ b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
@@ -1593,6 +1593,9 @@ def : Pat<(fma_nsz (fneg v4f64:$xj), v4f64:$xk, v4f64:$xa),
// XVFSQRT_{S/D}
defm : PatXrF<fsqrt, "XVFSQRT">;
+// XVFLOGB_{S/D}
+defm : PatXrF<flog2, "XVFLOGB">;
+
// XVRECIP_{S/D}
def : Pat<(fdiv vsplatf32_fpimm_eq_1, v8f32:$xj),
(XVFRECIP_S v8f32:$xj)>;
@@ -2024,6 +2027,24 @@ def : Pat<(v4i32(fp_to_uint v4f64:$vj)),
(XVFTINTRZ_LU_D v4f64:$vj)),
sub_128)>;
+// XVAVG_{B/H/W/D/BU/HU/WU/DU}, XVAVGR_{B/H/W/D/BU/HU/WU/DU}
+defm : VAvgPat<sra, "XVAVG_B", v32i8>;
+defm : VAvgPat<sra, "XVAVG_H", v16i16>;
+defm : VAvgPat<sra, "XVAVG_W", v8i32>;
+defm : VAvgPat<sra, "XVAVG_D", v4i64>;
+defm : VAvgPat<srl, "XVAVG_BU", v32i8>;
+defm : VAvgPat<srl, "XVAVG_HU", v16i16>;
+defm : VAvgPat<srl, "XVAVG_WU", v8i32>;
+defm : VAvgPat<srl, "XVAVG_DU", v4i64>;
+defm : VAvgrPat<sra, "XVAVGR_B", v32i8>;
+defm : VAvgrPat<sra, "XVAVGR_H", v16i16>;
+defm : VAvgrPat<sra, "XVAVGR_W", v8i32>;
+defm : VAvgrPat<sra, "XVAVGR_D", v4i64>;
+defm : VAvgrPat<srl, "XVAVGR_BU", v32i8>;
+defm : VAvgrPat<srl, "XVAVGR_HU", v16i16>;
+defm : VAvgrPat<srl, "XVAVGR_WU", v8i32>;
+defm : VAvgrPat<srl, "XVAVGR_DU", v4i64>;
+
// abs
def : Pat<(abs v32i8:$xj), (XVSIGNCOV_B v32i8:$xj, v32i8:$xj)>;
def : Pat<(abs v16i16:$xj), (XVSIGNCOV_H v16i16:$xj, v16i16:$xj)>;
diff --git a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
index 4619c6b..92402ba 100644
--- a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
+++ b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
@@ -1518,6 +1518,18 @@ multiclass InsertExtractPatV2<ValueType vecty, ValueType elemty> {
}
}
+multiclass VAvgPat<SDPatternOperator OpNode, string Inst, ValueType vt> {
+ def : Pat<(OpNode (vt (add vt:$vj, vt:$vk)), (vt (vsplat_imm_eq_1))),
+ (!cast<LAInst>(Inst) vt:$vj, vt:$vk)>;
+}
+
+multiclass VAvgrPat<SDPatternOperator OpNode, string Inst, ValueType vt> {
+ def : Pat<(OpNode (vt (add (vt (add vt:$vj, vt:$vk)),
+ (vt (vsplat_imm_eq_1)))),
+ (vt (vsplat_imm_eq_1))),
+ (!cast<LAInst>(Inst) vt:$vj, vt:$vk)>;
+}
+
let Predicates = [HasExtLSX] in {
// VADD_{B/H/W/D}
@@ -1783,6 +1795,9 @@ def : Pat<(fma_nsz (fneg v2f64:$vj), v2f64:$vk, v2f64:$va),
// VFSQRT_{S/D}
defm : PatVrF<fsqrt, "VFSQRT">;
+// VFLOGB_{S/D}
+defm : PatVrF<flog2, "VFLOGB">;
+
// VFRECIP_{S/D}
def : Pat<(fdiv vsplatf32_fpimm_eq_1, v4f32:$vj),
(VFRECIP_S v4f32:$vj)>;
@@ -2154,6 +2169,24 @@ def : Pat<(f32 f32imm_vldi:$in),
def : Pat<(f64 f64imm_vldi:$in),
(f64 (EXTRACT_SUBREG (VLDI (to_f64imm_vldi f64imm_vldi:$in)), sub_64))>;
+// VAVG_{B/H/W/D/BU/HU/WU/DU}, VAVGR_{B/H/W/D/BU/HU/WU/DU}
+defm : VAvgPat<sra, "VAVG_B", v16i8>;
+defm : VAvgPat<sra, "VAVG_H", v8i16>;
+defm : VAvgPat<sra, "VAVG_W", v4i32>;
+defm : VAvgPat<sra, "VAVG_D", v2i64>;
+defm : VAvgPat<srl, "VAVG_BU", v16i8>;
+defm : VAvgPat<srl, "VAVG_HU", v8i16>;
+defm : VAvgPat<srl, "VAVG_WU", v4i32>;
+defm : VAvgPat<srl, "VAVG_DU", v2i64>;
+defm : VAvgrPat<sra, "VAVGR_B", v16i8>;
+defm : VAvgrPat<sra, "VAVGR_H", v8i16>;
+defm : VAvgrPat<sra, "VAVGR_W", v4i32>;
+defm : VAvgrPat<sra, "VAVGR_D", v2i64>;
+defm : VAvgrPat<srl, "VAVGR_BU", v16i8>;
+defm : VAvgrPat<srl, "VAVGR_HU", v8i16>;
+defm : VAvgrPat<srl, "VAVGR_WU", v4i32>;
+defm : VAvgrPat<srl, "VAVGR_DU", v2i64>;
+
// abs
def : Pat<(abs v16i8:$vj), (VSIGNCOV_B v16i8:$vj, v16i8:$vj)>;
def : Pat<(abs v8i16:$vj), (VSIGNCOV_H v8i16:$vj, v8i16:$vj)>;
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 7e7ee75..c667a09 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -1871,17 +1871,6 @@ bool NVPTXScopes::empty() const { return Scopes.size() == 0; }
(is_ch ? (CP_ASYNC_BULK_TENSOR_OPCODE(RED, dim, mode, is_s32, _CH)) \
: (CP_ASYNC_BULK_TENSOR_OPCODE(RED, dim, mode, is_s32, )))
-#define GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(dim, mode, is_mc, is_ch, is_s32) \
- [&]() -> auto { \
- if (is_mc && is_ch) \
- return CP_ASYNC_BULK_TENSOR_OPCODE(G2S, dim, mode, is_s32, _MC_CH); \
- if (is_ch) \
- return CP_ASYNC_BULK_TENSOR_OPCODE(G2S, dim, mode, is_s32, _CH); \
- if (is_mc) \
- return CP_ASYNC_BULK_TENSOR_OPCODE(G2S, dim, mode, is_s32, _MC); \
- return CP_ASYNC_BULK_TENSOR_OPCODE(G2S, dim, mode, is_s32, ); \
- }()
-
static unsigned GetCpAsyncBulkTensorS2GReductionOpcode(size_t Dim,
bool IsShared32,
bool IsCacheHint,
@@ -1925,112 +1914,6 @@ static unsigned GetCpAsyncBulkTensorS2GReductionOpcode(size_t Dim,
}
}
-static unsigned GetCpAsyncBulkTensorG2SOpcode(size_t Dim, bool IsShared32,
- bool IsMultiCast,
- bool IsCacheHint, bool IsIm2Col) {
- if (IsIm2Col) {
- switch (Dim) {
- case 3:
- return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(3D, IM2COL, IsMultiCast,
- IsCacheHint, IsShared32);
- case 4:
- return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(4D, IM2COL, IsMultiCast,
- IsCacheHint, IsShared32);
- case 5:
- return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(5D, IM2COL, IsMultiCast,
- IsCacheHint, IsShared32);
- default:
- llvm_unreachable("Invalid Dimension in im2col mode for "
- "GetCpAsyncBulkTensorG2SOpcode.");
- }
- } else {
- switch (Dim) {
- case 1:
- return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(1D, TILE, IsMultiCast,
- IsCacheHint, IsShared32);
- case 2:
- return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(2D, TILE, IsMultiCast,
- IsCacheHint, IsShared32);
- case 3:
- return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(3D, TILE, IsMultiCast,
- IsCacheHint, IsShared32);
- case 4:
- return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(4D, TILE, IsMultiCast,
- IsCacheHint, IsShared32);
- case 5:
- return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(5D, TILE, IsMultiCast,
- IsCacheHint, IsShared32);
- default:
- llvm_unreachable(
- "Invalid Dimension in tile mode for GetCpAsyncBulkTensorG2SOpcode.");
- }
- }
-}
-
-static size_t GetDimsFromIntrinsic(unsigned IID) {
- switch (IID) {
- case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d:
- case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d:
- return 3;
- case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d:
- case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d:
- return 4;
- case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d:
- case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d:
- return 5;
- default:
- llvm_unreachable("Invalid im2col intrinsic in GetDimsFromIntrinsic.");
- }
-}
-
-void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorG2SCommon(SDNode *N,
- bool IsIm2Col) {
- // We have {Chain, Intrinsic-ID} followed by the actual intrisic args:
- // {dst, mbar, src, dims{d0...dN}, im2col_offsets{dims-2}
- // multicast, cache_hint,
- // multicast_flag, cache_hint_flag, cta_group_flag}
- // NumOperands = {Chain, IID} + {Actual intrinsic args}
- // = {2} + {8 + dims + im2col_offsets}
- size_t NumOps = N->getNumOperands();
- size_t NumDims = IsIm2Col ? GetDimsFromIntrinsic(N->getConstantOperandVal(1))
- : (NumOps - 10);
- // Offsets is always 'NumDims - 2' and only for im2col mode
- size_t NumOffsets = IsIm2Col ? (NumDims - 2) : 0;
- bool IsCacheHint = N->getConstantOperandVal(NumOps - 2) == 1;
- bool IsMultiCast = N->getConstantOperandVal(NumOps - 3) == 1;
- size_t NumBaseArgs = NumDims + NumOffsets + 3; // for {dst, mbar, src}
- size_t MultiCastIdx = NumBaseArgs + 2; // for Chain and IID
-
- unsigned CTAGroupVal = N->getConstantOperandVal(NumOps - 1);
- if ((CTAGroupVal > 0) && !Subtarget->hasCpAsyncBulkTensorCTAGroupSupport())
- report_fatal_error(
- formatv("CpAsyncBulkTensorG2S cta_group::1/2 is not supported on sm_{}",
- Subtarget->getSmVersion()));
-
- SDLoc DL(N);
- SmallVector<SDValue, 8> Ops(N->ops().slice(2, NumBaseArgs));
-
- // Push MultiCast operand, if available
- if (IsMultiCast)
- Ops.push_back(N->getOperand(MultiCastIdx));
-
- // Push CacheHint operand, if available
- if (IsCacheHint)
- Ops.push_back(N->getOperand(MultiCastIdx + 1));
-
- // Flag for CTA Group
- Ops.push_back(getI32Imm(CTAGroupVal, DL));
-
- // Finally, the chain operand
- Ops.push_back(N->getOperand(0));
-
- bool IsShared32 =
- CurDAG->getDataLayout().getPointerSizeInBits(ADDRESS_SPACE_SHARED) == 32;
- unsigned Opcode = GetCpAsyncBulkTensorG2SOpcode(
- NumDims, IsShared32, IsMultiCast, IsCacheHint, IsIm2Col);
- ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops));
-}
-
void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorReduceCommon(SDNode *N,
unsigned RedOp,
bool IsIm2Col) {
@@ -2175,18 +2058,6 @@ bool NVPTXDAGToDAGISel::tryIntrinsicVoid(SDNode *N) {
switch (IID) {
default:
return false;
- case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d:
- case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d:
- case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d:
- case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d:
- case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d:
- SelectCpAsyncBulkTensorG2SCommon(N);
- return true;
- case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d:
- case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d:
- case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d:
- SelectCpAsyncBulkTensorG2SCommon(N, /*IsIm2Col=*/true);
- return true;
case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d:
case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d:
case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d:
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index c912e70..1cb579b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -86,7 +86,6 @@ private:
bool tryEXTRACT_VECTOR_ELEMENT(SDNode *N);
void SelectV2I64toI128(SDNode *N);
void SelectI128toV2I64(SDNode *N);
- void SelectCpAsyncBulkTensorG2SCommon(SDNode *N, bool IsIm2Col = false);
void SelectCpAsyncBulkTensorReduceCommon(SDNode *N, unsigned RedOp,
bool IsIm2Col = false);
void SelectTcgen05Ld(SDNode *N, bool hasOffset = false);
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index dfde0cc..b260221 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -139,7 +139,6 @@ def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">;
def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;
def hasTcgen05Instructions : Predicate<"Subtarget->hasTcgen05Instructions()">;
def hasTcgen05MMAScaleInputDImm : Predicate<"Subtarget->hasTcgen05MMAScaleInputDImm()">;
-def hasTMACTAGroupSupport : Predicate<"Subtarget->hasCpAsyncBulkTensorCTAGroupSupport()">;
def hasF32x2Instructions : Predicate<"Subtarget->hasF32x2Instructions()">;
class hasPTX<int version>: Predicate<"Subtarget->getPTXVersion() >= " # version>;
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index c923f0e..e8758aa 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -599,75 +599,15 @@ class TMA_IM2COL_UTIL<int dim, string mode> {
string base_str = !interleave(!foreach(i, !range(offsets), "$im2col" # i), ", ");
}
-// From Global to Shared memory (G2S)
-class G2S_STRINGS<int dim, string mode, bit mc, bit ch, bit is_shared32 = 0> {
- string prefix = "cp.async.bulk.tensor";
- string dir = "shared::cluster.global";
- string completion = "mbarrier::complete_tx::bytes";
- string inst_name = prefix
- # "." # dim # "d"
- # "." # dir
- # "." # mode
- # "." # completion
- # !if(mc, ".multicast::cluster", "")
- # !if(ch, ".L2::cache_hint", "");
- string intr_name = "CP_ASYNC_BULK_TENSOR_G2S_"
- # dim # "D"
- # !if(is_shared32, "_SHARED32", "")
- # !if(!eq(mode, "tile"), "_TILE", "_IM2COL");
-}
-
def CTAGroupFlags : Operand<i32> {
let PrintMethod = "printCTAGroup";
}
-multiclass CP_ASYNC_BULK_TENSOR_G2S_INTR<int dim, bit is_shared32, string mode> {
- defvar dims_dag = TMA_DIMS_UTIL<dim>.ins_dag;
- defvar dims_str = TMA_DIMS_UTIL<dim>.base_str;
- defvar asm_str_default = "$cg [$dst], [$tmap, {{" # dims_str # "}}], [$mbar]";
- defvar rc = !if(is_shared32, B32, B64);
-
- defvar num_im2col = !if(!ge(dim, 3), !add(dim, -2), 0);
- defvar im2col_dag = !if(!eq(mode, "im2col"),
- !dag(ins, !listsplat(B16, num_im2col), !foreach(i, !range(num_im2col), "im2col" # i)),
- (ins));
- defvar im2col_str = !interleave(!foreach(i, !range(num_im2col), "$im2col" # i), ", ");
- defvar im2col_asm_str = ", {{" # im2col_str # "}}";
-
- defvar asm_str = !if(!eq(mode, "im2col"),
- !strconcat(asm_str_default, im2col_asm_str), asm_str_default);
+def tma_cta_group_imm0 : TImmLeaf<i32, [{return Imm == 0;}]>;
+def tma_cta_group_imm_any : TImmLeaf<i32, [{return Imm >= 0;}]>;
- def "" : NVPTXInst<(outs),
- !con((ins rc:$dst, rc:$mbar, B64:$tmap), dims_dag, im2col_dag, (ins CTAGroupFlags:$cg)),
- !strconcat(G2S_STRINGS<dim, mode, 0, 0>.inst_name, asm_str, ";")>,
- Requires<[hasPTX<80>, hasSM<90>]>;
- def _MC : NVPTXInst<(outs),
- !con((ins rc:$dst, rc:$mbar, B64:$tmap), dims_dag, im2col_dag,
- (ins B16:$mc, CTAGroupFlags:$cg)),
- !strconcat(G2S_STRINGS<dim, mode, 1, 0>.inst_name, asm_str, ", $mc;")>,
- Requires<[hasPTX<80>, hasSM<90>]>;
- def _CH : NVPTXInst<(outs),
- !con((ins rc:$dst, rc:$mbar, B64:$tmap), dims_dag, im2col_dag,
- (ins B64:$ch, CTAGroupFlags:$cg)),
- !strconcat(G2S_STRINGS<dim, mode, 0, 1>.inst_name, asm_str, ", $ch;")>,
- Requires<[hasPTX<80>, hasSM<90>]>;
- def _MC_CH : NVPTXInst<(outs),
- !con((ins rc:$dst, rc:$mbar, B64:$tmap), dims_dag, im2col_dag,
- (ins B16:$mc, B64:$ch, CTAGroupFlags:$cg)),
- !strconcat(G2S_STRINGS<dim, mode, 1, 1>.inst_name, asm_str, ", $mc, $ch;")>,
- Requires<[hasPTX<80>, hasSM<90>]>;
-}
-
-foreach dim = [1, 2, 3, 4, 5] in {
- foreach shared32 = [true, false] in {
- foreach mode = !if(!ge(dim, 3), ["tile", "im2col"], ["tile"]) in {
- defm G2S_STRINGS<dim, mode, 0, 0, shared32>.intr_name :
- CP_ASYNC_BULK_TENSOR_G2S_INTR<dim, shared32, mode>;
- }
- }
-}
-
-multiclass TMA_TENSOR_G2S_INTR<int dim, string mode, list<Predicate> pred = []> {
+multiclass TMA_TENSOR_G2S_INTR<int dim, string mode, list<Predicate> pred,
+ TImmLeaf cta_group_type = tma_cta_group_imm_any> {
defvar dims_dag = TMA_DIMS_UTIL<dim>.ins_dag;
defvar dims_str = TMA_DIMS_UTIL<dim>.base_str;
defvar asm_str_base = "$cg [$dst], [$tmap, {{" # dims_str # "}}], [$mbar]";
@@ -697,10 +637,10 @@ multiclass TMA_TENSOR_G2S_INTR<int dim, string mode, list<Predicate> pred = []>
!setdagop(dims_dag, intr),
!setdagop(im2col_dag, intr),
(intr B16:$mc, B64:$ch));
- defvar intr_dag_no_hints = !con(intr_dag_base, (intr 0, 0, timm:$cg));
- defvar intr_dag_with_mc = !con(intr_dag_base, (intr -1, 0, timm:$cg));
- defvar intr_dag_with_ch = !con(intr_dag_base, (intr 0, -1, timm:$cg));
- defvar intr_dag_with_mc_ch = !con(intr_dag_base, (intr -1, -1, timm:$cg));
+ defvar intr_dag_no_hints = !con(intr_dag_base, (intr 0, 0, cta_group_type:$cg));
+ defvar intr_dag_with_mc = !con(intr_dag_base, (intr -1, 0, cta_group_type:$cg));
+ defvar intr_dag_with_ch = !con(intr_dag_base, (intr 0, -1, cta_group_type:$cg));
+ defvar intr_dag_with_mc_ch = !con(intr_dag_base, (intr -1, -1, cta_group_type:$cg));
def "" : NVPTXInst<(outs), ins_dag,
inst_name # asm_str # ";",
@@ -719,14 +659,30 @@ multiclass TMA_TENSOR_G2S_INTR<int dim, string mode, list<Predicate> pred = []>
[intr_dag_with_mc_ch]>,
Requires<pred>;
}
+
+foreach dim = 1...5 in {
+ defm TMA_G2S_TILE_CG0_ # dim # "D"
+ : TMA_TENSOR_G2S_INTR<dim, "tile", [hasPTX<80>, hasSM<90>],
+ tma_cta_group_imm0>;
+ defm TMA_G2S_TILE_ # dim # "D"
+ : TMA_TENSOR_G2S_INTR<dim, "tile",
+ [callSubtarget<"hasTMABlackwellSupport">]>;
+}
foreach dim = 3...5 in {
+ defm TMA_G2S_IM2COL_CG0_ # dim # "D"
+ : TMA_TENSOR_G2S_INTR<dim, "im2col", [hasPTX<80>, hasSM<90>],
+ tma_cta_group_imm0>;
+ defm TMA_G2S_IM2COL_ # dim # "D"
+ : TMA_TENSOR_G2S_INTR<dim, "im2col",
+ [callSubtarget<"hasTMABlackwellSupport">]>;
foreach mode = ["im2col_w", "im2col_w_128"] in {
defm TMA_G2S_ # !toupper(mode) # "_" # dim # "D"
- : TMA_TENSOR_G2S_INTR<dim, mode, [hasTMACTAGroupSupport]>;
+ : TMA_TENSOR_G2S_INTR<dim, mode,
+ [callSubtarget<"hasTMABlackwellSupport">]>;
}
}
defm TMA_G2S_TILE_GATHER4_2D : TMA_TENSOR_G2S_INTR<5, "tile_gather4",
- [hasTMACTAGroupSupport]>;
+ [callSubtarget<"hasTMABlackwellSupport">]>;
multiclass TMA_TENSOR_G2S_CTA_INTR<int dim, string mode, list<Predicate> pred = []> {
defvar dims_dag = TMA_DIMS_UTIL<dim>.ins_dag;
@@ -784,7 +740,8 @@ foreach dim = 3...5 in {
: TMA_TENSOR_G2S_CTA_INTR<dim, "im2col_w", [hasPTX<86>, hasSM<100>]>;
defm TMA_G2S_CTA_IM2COL_W_128_ # dim # "D"
- : TMA_TENSOR_G2S_CTA_INTR<dim, "im2col_w_128", [hasTMACTAGroupSupport]>;
+ : TMA_TENSOR_G2S_CTA_INTR<dim, "im2col_w_128",
+ [callSubtarget<"hasTMABlackwellSupport">]>;
}
defm TMA_G2S_CTA_TILE_GATHER4_2D : TMA_TENSOR_G2S_CTA_INTR<5, "tile_gather4",
[hasPTX<86>, hasSM<100>]>;
@@ -835,7 +792,7 @@ foreach dim = 1...5 in {
}
}
defm TMA_S2G_TILE_SCATTER4_2D : TMA_TENSOR_S2G_INTR<5, "tile_scatter4",
- [hasTMACTAGroupSupport]>;
+ [callSubtarget<"hasTMABlackwellSupport">]>;
def TMAReductionFlags : Operand<i32> {
let PrintMethod = "printTmaReductionMode";
@@ -930,11 +887,11 @@ foreach dim = 3...5 in {
foreach mode = ["im2col_w", "im2col_w_128"] in {
defvar suffix = !toupper(mode) # "_" # dim # "D";
defm TMA_TENSOR_PF_ # suffix : TMA_TENSOR_PREFETCH_INTR<dim, mode,
- [hasTMACTAGroupSupport]>;
+ [callSubtarget<"hasTMABlackwellSupport">]>;
}
}
defm TMA_TENSOR_PF_TILE_GATHER4_2D : TMA_TENSOR_PREFETCH_INTR<5, "tile_gather4",
- [hasTMACTAGroupSupport]>;
+ [callSubtarget<"hasTMABlackwellSupport">]>;
//Prefetchu and Prefetch
diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
index 194dbdc..021b1f6 100644
--- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
+++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
@@ -166,18 +166,15 @@ public:
// f32x2 instructions in Blackwell family
bool hasF32x2Instructions() const;
- // TMA G2S copy with cta_group::1/2 support
- bool hasCpAsyncBulkTensorCTAGroupSupport() const {
- // TODO: Update/tidy-up after the family-conditional support arrives
- switch (FullSmVersion) {
- case 1003:
- case 1013:
- return PTXVersion >= 86;
- case 1033:
- return PTXVersion >= 88;
- default:
- return false;
- }
+ // Checks support for following in TMA:
+ // - cta_group::1/2 support
+ // - im2col_w/w_128 mode support
+ // - tile_gather4 mode support
+ // - tile_scatter4 mode support
+ bool hasTMABlackwellSupport() const {
+ return hasPTXWithFamilySMs(90, {100, 110}) ||
+ hasPTXWithFamilySMs(88, {100, 101}) ||
+ hasPTXWithAccelSMs(86, {100, 101});
}
// Prior to CUDA 12.3 ptxas did not recognize that the trap instruction
diff --git a/llvm/lib/Target/PowerPC/PPCTargetMachine.cpp b/llvm/lib/Target/PowerPC/PPCTargetMachine.cpp
index 000d296..4ff489d 100644
--- a/llvm/lib/Target/PowerPC/PPCTargetMachine.cpp
+++ b/llvm/lib/Target/PowerPC/PPCTargetMachine.cpp
@@ -296,8 +296,9 @@ PPCTargetMachine::PPCTargetMachine(const Target &T, const Triple &TT,
std::optional<Reloc::Model> RM,
std::optional<CodeModel::Model> CM,
CodeGenOptLevel OL, bool JIT)
- : CodeGenTargetMachineImpl(T, TT.computeDataLayout(), TT, CPU,
- computeFSAdditions(FS, OL, TT), Options,
+ : CodeGenTargetMachineImpl(T,
+ TT.computeDataLayout(Options.MCOptions.ABIName),
+ TT, CPU, computeFSAdditions(FS, OL, TT), Options,
getEffectiveRelocModel(TT, RM),
getEffectivePPCCodeModel(TT, CM, JIT), OL),
TLOF(createTLOF(getTargetTriple())),
diff --git a/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp b/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
index 8198173..282cf5d 100644
--- a/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
+++ b/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
@@ -92,6 +92,10 @@ private:
void emitFence(AtomicOrdering FenceOrdering, SyncScope::ID FenceSSID,
MachineIRBuilder &MIB) const;
bool selectUnmergeValues(MachineInstr &MI, MachineIRBuilder &MIB) const;
+ void addVectorLoadStoreOperands(MachineInstr &I,
+ SmallVectorImpl<SrcOp> &SrcOps,
+ unsigned &CurOp, bool IsMasked,
+ bool IsStrided) const;
bool selectIntrinsicWithSideEffects(MachineInstr &I,
MachineIRBuilder &MIB) const;
@@ -716,6 +720,26 @@ static unsigned selectRegImmLoadStoreOp(unsigned GenericOpc, unsigned OpSize) {
return GenericOpc;
}
+void RISCVInstructionSelector::addVectorLoadStoreOperands(
+ MachineInstr &I, SmallVectorImpl<SrcOp> &SrcOps, unsigned &CurOp,
+ bool IsMasked, bool IsStrided) const {
+ // Base Pointer
+ auto PtrReg = I.getOperand(CurOp++).getReg();
+ SrcOps.push_back(PtrReg);
+
+ // Stride
+ if (IsStrided) {
+ auto StrideReg = I.getOperand(CurOp++).getReg();
+ SrcOps.push_back(StrideReg);
+ }
+
+ // Mask
+ if (IsMasked) {
+ auto MaskReg = I.getOperand(CurOp++).getReg();
+ SrcOps.push_back(MaskReg);
+ }
+}
+
bool RISCVInstructionSelector::selectIntrinsicWithSideEffects(
MachineInstr &I, MachineIRBuilder &MIB) const {
// Find the intrinsic ID.
@@ -752,21 +776,7 @@ bool RISCVInstructionSelector::selectIntrinsicWithSideEffects(
SrcOps.push_back(Register(RISCV::NoRegister));
}
- // Base Pointer
- auto PtrReg = I.getOperand(CurOp++).getReg();
- SrcOps.push_back(PtrReg);
-
- // Stride
- if (IsStrided) {
- auto StrideReg = I.getOperand(CurOp++).getReg();
- SrcOps.push_back(StrideReg);
- }
-
- // Mask
- if (IsMasked) {
- auto MaskReg = I.getOperand(CurOp++).getReg();
- SrcOps.push_back(MaskReg);
- }
+ addVectorLoadStoreOperands(I, SrcOps, CurOp, IsMasked, IsStrided);
RISCVVType::VLMUL LMUL = RISCVTargetLowering::getLMUL(getMVTForLLT(VT));
const RISCV::VLEPseudo *P =
@@ -795,6 +805,48 @@ bool RISCVInstructionSelector::selectIntrinsicWithSideEffects(
I.eraseFromParent();
return constrainSelectedInstRegOperands(*PseudoMI, TII, TRI, RBI);
}
+ case Intrinsic::riscv_vsm:
+ case Intrinsic::riscv_vse:
+ case Intrinsic::riscv_vse_mask:
+ case Intrinsic::riscv_vsse:
+ case Intrinsic::riscv_vsse_mask: {
+ bool IsMasked = IntrinID == Intrinsic::riscv_vse_mask ||
+ IntrinID == Intrinsic::riscv_vsse_mask;
+ bool IsStrided = IntrinID == Intrinsic::riscv_vsse ||
+ IntrinID == Intrinsic::riscv_vsse_mask;
+ LLT VT = MRI->getType(I.getOperand(1).getReg());
+ unsigned Log2SEW = Log2_32(VT.getScalarSizeInBits());
+
+ // Sources
+ unsigned CurOp = 1;
+ SmallVector<SrcOp, 4> SrcOps; // Source registers.
+
+ // Store value
+ auto PassthruReg = I.getOperand(CurOp++).getReg();
+ SrcOps.push_back(PassthruReg);
+
+ addVectorLoadStoreOperands(I, SrcOps, CurOp, IsMasked, IsStrided);
+
+ RISCVVType::VLMUL LMUL = RISCVTargetLowering::getLMUL(getMVTForLLT(VT));
+ const RISCV::VSEPseudo *P = RISCV::getVSEPseudo(
+ IsMasked, IsStrided, Log2SEW, static_cast<unsigned>(LMUL));
+
+ auto PseudoMI = MIB.buildInstr(P->Pseudo, {}, SrcOps);
+
+ // Select VL
+ auto VLOpFn = renderVLOp(I.getOperand(CurOp++));
+ for (auto &RenderFn : *VLOpFn)
+ RenderFn(PseudoMI);
+
+ // SEW
+ PseudoMI.addImm(Log2SEW);
+
+ // Memref
+ PseudoMI.cloneMemRefs(I);
+
+ I.eraseFromParent();
+ return constrainSelectedInstRegOperands(*PseudoMI, TII, TRI, RBI);
+ }
}
}
diff --git a/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp b/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp
index 4105618..526675a 100644
--- a/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp
+++ b/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp
@@ -127,6 +127,10 @@ bool RISCVExpandPseudo::expandMI(MachineBasicBlock &MBB,
case RISCV::PseudoCCAND:
case RISCV::PseudoCCOR:
case RISCV::PseudoCCXOR:
+ case RISCV::PseudoCCMAX:
+ case RISCV::PseudoCCMAXU:
+ case RISCV::PseudoCCMIN:
+ case RISCV::PseudoCCMINU:
case RISCV::PseudoCCADDW:
case RISCV::PseudoCCSUBW:
case RISCV::PseudoCCSLL:
@@ -217,6 +221,7 @@ bool RISCVExpandPseudo::expandCCOp(MachineBasicBlock &MBB,
.addImm(0);
} else {
unsigned NewOpc;
+ // clang-format off
switch (MI.getOpcode()) {
default:
llvm_unreachable("Unexpected opcode!");
@@ -228,6 +233,10 @@ bool RISCVExpandPseudo::expandCCOp(MachineBasicBlock &MBB,
case RISCV::PseudoCCAND: NewOpc = RISCV::AND; break;
case RISCV::PseudoCCOR: NewOpc = RISCV::OR; break;
case RISCV::PseudoCCXOR: NewOpc = RISCV::XOR; break;
+ case RISCV::PseudoCCMAX: NewOpc = RISCV::MAX; break;
+ case RISCV::PseudoCCMIN: NewOpc = RISCV::MIN; break;
+ case RISCV::PseudoCCMAXU: NewOpc = RISCV::MAXU; break;
+ case RISCV::PseudoCCMINU: NewOpc = RISCV::MINU; break;
case RISCV::PseudoCCADDI: NewOpc = RISCV::ADDI; break;
case RISCV::PseudoCCSLLI: NewOpc = RISCV::SLLI; break;
case RISCV::PseudoCCSRLI: NewOpc = RISCV::SRLI; break;
@@ -250,6 +259,7 @@ bool RISCVExpandPseudo::expandCCOp(MachineBasicBlock &MBB,
case RISCV::PseudoCCNDS_BFOS: NewOpc = RISCV::NDS_BFOS; break;
case RISCV::PseudoCCNDS_BFOZ: NewOpc = RISCV::NDS_BFOZ; break;
}
+ // clang-format on
if (NewOpc == RISCV::NDS_BFOZ || NewOpc == RISCV::NDS_BFOS) {
BuildMI(TrueBB, DL, TII->get(NewOpc), DestReg)
diff --git a/llvm/lib/Target/RISCV/RISCVFeatures.td b/llvm/lib/Target/RISCV/RISCVFeatures.td
index b4556f6..cfee6ab 100644
--- a/llvm/lib/Target/RISCV/RISCVFeatures.td
+++ b/llvm/lib/Target/RISCV/RISCVFeatures.td
@@ -1851,6 +1851,11 @@ def TuneShortForwardBranchOpt
def HasShortForwardBranchOpt : Predicate<"Subtarget->hasShortForwardBranchOpt()">;
def NoShortForwardBranchOpt : Predicate<"!Subtarget->hasShortForwardBranchOpt()">;
+def TuneShortForwardBranchIMinMax
+ : SubtargetFeature<"short-forward-branch-i-minmax", "HasShortForwardBranchIMinMax",
+ "true", "Enable short forward branch optimization for min,max instructions in Zbb",
+ [TuneShortForwardBranchOpt]>;
+
// Some subtargets require a S2V transfer buffer to move scalars into vectors.
// FIXME: Forming .vx/.vf/.wx/.wf can reduce register pressure.
def TuneNoSinkSplatOperands
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index 912b82d..c9df787 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -869,7 +869,7 @@ std::optional<unsigned> getFoldedOpcode(MachineFunction &MF, MachineInstr &MI,
}
}
-// This is the version used during inline spilling
+// This is the version used during InlineSpiller::spillAroundUses
MachineInstr *RISCVInstrInfo::foldMemoryOperandImpl(
MachineFunction &MF, MachineInstr &MI, ArrayRef<unsigned> Ops,
MachineBasicBlock::iterator InsertPt, int FrameIndex, LiveIntervals *LIS,
@@ -1699,6 +1699,10 @@ unsigned getPredicatedOpcode(unsigned Opcode) {
case RISCV::AND: return RISCV::PseudoCCAND;
case RISCV::OR: return RISCV::PseudoCCOR;
case RISCV::XOR: return RISCV::PseudoCCXOR;
+ case RISCV::MAX: return RISCV::PseudoCCMAX;
+ case RISCV::MAXU: return RISCV::PseudoCCMAXU;
+ case RISCV::MIN: return RISCV::PseudoCCMIN;
+ case RISCV::MINU: return RISCV::PseudoCCMINU;
case RISCV::ADDI: return RISCV::PseudoCCADDI;
case RISCV::SLLI: return RISCV::PseudoCCSLLI;
@@ -1735,7 +1739,8 @@ unsigned getPredicatedOpcode(unsigned Opcode) {
/// return the defining instruction.
static MachineInstr *canFoldAsPredicatedOp(Register Reg,
const MachineRegisterInfo &MRI,
- const TargetInstrInfo *TII) {
+ const TargetInstrInfo *TII,
+ const RISCVSubtarget &STI) {
if (!Reg.isVirtual())
return nullptr;
if (!MRI.hasOneNonDBGUse(Reg))
@@ -1743,6 +1748,12 @@ static MachineInstr *canFoldAsPredicatedOp(Register Reg,
MachineInstr *MI = MRI.getVRegDef(Reg);
if (!MI)
return nullptr;
+
+ if (!STI.hasShortForwardBranchIMinMax() &&
+ (MI->getOpcode() == RISCV::MAX || MI->getOpcode() == RISCV::MIN ||
+ MI->getOpcode() == RISCV::MINU || MI->getOpcode() == RISCV::MAXU))
+ return nullptr;
+
// Check if MI can be predicated and folded into the CCMOV.
if (getPredicatedOpcode(MI->getOpcode()) == RISCV::INSTRUCTION_LIST_END)
return nullptr;
@@ -1806,10 +1817,10 @@ RISCVInstrInfo::optimizeSelect(MachineInstr &MI,
MachineRegisterInfo &MRI = MI.getParent()->getParent()->getRegInfo();
MachineInstr *DefMI =
- canFoldAsPredicatedOp(MI.getOperand(5).getReg(), MRI, this);
+ canFoldAsPredicatedOp(MI.getOperand(5).getReg(), MRI, this, STI);
bool Invert = !DefMI;
if (!DefMI)
- DefMI = canFoldAsPredicatedOp(MI.getOperand(4).getReg(), MRI, this);
+ DefMI = canFoldAsPredicatedOp(MI.getOperand(4).getReg(), MRI, this, STI);
if (!DefMI)
return nullptr;
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.td b/llvm/lib/Target/RISCV/RISCVInstrInfo.td
index 7c89686..9cb53fb 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.td
@@ -768,7 +768,7 @@ def BGE : BranchCC_rri<0b101, "bge">;
def BLTU : BranchCC_rri<0b110, "bltu">;
def BGEU : BranchCC_rri<0b111, "bgeu">;
-let IsSignExtendingOpW = 1 in {
+let IsSignExtendingOpW = 1, canFoldAsLoad = 1 in {
def LB : Load_ri<0b000, "lb">, Sched<[WriteLDB, ReadMemBase]>;
def LH : Load_ri<0b001, "lh">, Sched<[WriteLDH, ReadMemBase]>;
def LW : Load_ri<0b010, "lw">, Sched<[WriteLDW, ReadMemBase]>;
@@ -889,8 +889,10 @@ def CSRRCI : CSR_ii<0b111, "csrrci">;
/// RV64I instructions
let Predicates = [IsRV64] in {
+let canFoldAsLoad = 1 in {
def LWU : Load_ri<0b110, "lwu">, Sched<[WriteLDW, ReadMemBase]>;
def LD : Load_ri<0b011, "ld">, Sched<[WriteLDD, ReadMemBase]>;
+}
def SD : Store_rri<0b011, "sd">, Sched<[WriteSTD, ReadStoreData, ReadMemBase]>;
let IsSignExtendingOpW = 1 in {
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoD.td b/llvm/lib/Target/RISCV/RISCVInstrInfoD.td
index afac37d..4ffe3e6 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoD.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoD.td
@@ -71,6 +71,7 @@ defvar DExtsRV64 = [DExt, ZdinxExt];
//===----------------------------------------------------------------------===//
let Predicates = [HasStdExtD] in {
+let canFoldAsLoad = 1 in
def FLD : FPLoad_r<0b011, "fld", FPR64, WriteFLD64>;
// Operands for stores are in the order srcreg, base, offset rather than
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoF.td b/llvm/lib/Target/RISCV/RISCVInstrInfoF.td
index 6571d99..b30f8ec 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoF.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoF.td
@@ -330,6 +330,7 @@ class PseudoFROUND<DAGOperand Ty, ValueType vt, ValueType intvt = XLenVT>
//===----------------------------------------------------------------------===//
let Predicates = [HasStdExtF] in {
+let canFoldAsLoad = 1 in
def FLW : FPLoad_r<0b010, "flw", FPR32, WriteFLD32>;
// Operands for stores are in the order srcreg, base, offset rather than
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoSFB.td b/llvm/lib/Target/RISCV/RISCVInstrInfoSFB.td
index 0114fbd..5a67a5a 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoSFB.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoSFB.td
@@ -106,6 +106,10 @@ def PseudoCCSRA : SFBALU_rr;
def PseudoCCAND : SFBALU_rr;
def PseudoCCOR : SFBALU_rr;
def PseudoCCXOR : SFBALU_rr;
+def PseudoCCMAX : SFBALU_rr;
+def PseudoCCMIN : SFBALU_rr;
+def PseudoCCMAXU : SFBALU_rr;
+def PseudoCCMINU : SFBALU_rr;
def PseudoCCADDI : SFBALU_ri;
def PseudoCCANDI : SFBALU_ri;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 3fea21e..3f0424f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -3151,6 +3151,14 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
return selectInsertElt(ResVReg, ResType, I);
case Intrinsic::spv_gep:
return selectGEP(ResVReg, ResType, I);
+ case Intrinsic::spv_bitcast: {
+ Register OpReg = I.getOperand(2).getReg();
+ SPIRVType *OpType =
+ OpReg.isValid() ? GR.getSPIRVTypeForVReg(OpReg) : nullptr;
+ if (!GR.isBitcastCompatible(ResType, OpType))
+ report_fatal_error("incompatible result and operand types in a bitcast");
+ return selectOpWithSrcs(ResVReg, ResType, I, {OpReg}, SPIRV::OpBitcast);
+ }
case Intrinsic::spv_unref_global:
case Intrinsic::spv_init_global: {
MachineInstr *MI = MRI->getVRegDef(I.getOperand(1).getReg());
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
index 6e444c9..65dffc7 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
@@ -73,16 +73,23 @@ class SPIRVLegalizePointerCast : public FunctionPass {
// Returns the loaded value.
Value *loadVectorFromVector(IRBuilder<> &B, FixedVectorType *SourceType,
FixedVectorType *TargetType, Value *Source) {
- assert(TargetType->getNumElements() <= SourceType->getNumElements());
LoadInst *NewLoad = B.CreateLoad(SourceType, Source);
buildAssignType(B, SourceType, NewLoad);
Value *AssignValue = NewLoad;
if (TargetType->getElementType() != SourceType->getElementType()) {
+ const DataLayout &DL = B.GetInsertBlock()->getModule()->getDataLayout();
+ [[maybe_unused]] TypeSize TargetTypeSize =
+ DL.getTypeSizeInBits(TargetType);
+ [[maybe_unused]] TypeSize SourceTypeSize =
+ DL.getTypeSizeInBits(SourceType);
+ assert(TargetTypeSize == SourceTypeSize);
AssignValue = B.CreateIntrinsic(Intrinsic::spv_bitcast,
{TargetType, SourceType}, {NewLoad});
buildAssignType(B, TargetType, AssignValue);
+ return AssignValue;
}
+ assert(TargetType->getNumElements() < SourceType->getNumElements());
SmallVector<int> Mask(/* Size= */ TargetType->getNumElements());
for (unsigned I = 0; I < TargetType->getNumElements(); ++I)
Mask[I] = I;
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index db6f2d6..d538009 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -192,31 +192,43 @@ static void buildOpBitcast(SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,
.addUse(OpReg);
}
-// We do instruction selections early instead of calling MIB.buildBitcast()
-// generating the general op code G_BITCAST. When MachineVerifier validates
-// G_BITCAST we see a check of a kind: if Source Type is equal to Destination
-// Type then report error "bitcast must change the type". This doesn't take into
-// account the notion of a typed pointer that is important for SPIR-V where a
-// user may and should use bitcast between pointers with different pointee types
-// (https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpBitcast).
-// It's important for correct lowering in SPIR-V, because interpretation of the
-// data type is not left to instructions that utilize the pointer, but encoded
-// by the pointer declaration, and the SPIRV target can and must handle the
-// declaration and use of pointers that specify the type of data they point to.
-// It's not feasible to improve validation of G_BITCAST using just information
-// provided by low level types of source and destination. Therefore we don't
-// produce G_BITCAST as the general op code with semantics different from
-// OpBitcast, but rather lower to OpBitcast immediately. As for now, the only
-// difference would be that CombinerHelper couldn't transform known patterns
-// around G_BUILD_VECTOR. See discussion
-// in https://github.com/llvm/llvm-project/pull/110270 for even more context.
-static void selectOpBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
- MachineIRBuilder MIB) {
+// We lower G_BITCAST to OpBitcast here to avoid a MachineVerifier error.
+// The verifier checks if the source and destination LLTs of a G_BITCAST are
+// different, but this check is too strict for SPIR-V's typed pointers, which
+// may have the same LLT but different SPIRVType (e.g. pointers to different
+// pointee types). By lowering to OpBitcast here, we bypass the verifier's
+// check. See discussion in https://github.com/llvm/llvm-project/pull/110270
+// for more context.
+//
+// We also handle the llvm.spv.bitcast intrinsic here. If the source and
+// destination SPIR-V types are the same, we lower it to a COPY to enable
+// further optimizations like copy propagation.
+static void lowerBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
+ MachineIRBuilder MIB) {
SmallVector<MachineInstr *, 16> ToErase;
for (MachineBasicBlock &MBB : MF) {
for (MachineInstr &MI : MBB) {
+ if (isSpvIntrinsic(MI, Intrinsic::spv_bitcast)) {
+ Register DstReg = MI.getOperand(0).getReg();
+ Register SrcReg = MI.getOperand(2).getReg();
+ SPIRVType *DstType = GR->getSPIRVTypeForVReg(DstReg);
+ assert(
+ DstType &&
+ "Expected destination SPIR-V type to have been assigned already.");
+ SPIRVType *SrcType = GR->getSPIRVTypeForVReg(SrcReg);
+ assert(SrcType &&
+ "Expected source SPIR-V type to have been assigned already.");
+ if (DstType == SrcType) {
+ MIB.setInsertPt(*MI.getParent(), MI);
+ MIB.buildCopy(DstReg, SrcReg);
+ ToErase.push_back(&MI);
+ continue;
+ }
+ }
+
if (MI.getOpcode() != TargetOpcode::G_BITCAST)
continue;
+
MIB.setInsertPt(*MI.getParent(), MI);
buildOpBitcast(GR, MIB, MI.getOperand(0).getReg(),
MI.getOperand(1).getReg());
@@ -237,16 +249,11 @@ static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
SmallVector<MachineInstr *, 10> ToErase;
for (MachineBasicBlock &MBB : MF) {
for (MachineInstr &MI : MBB) {
- if (!isSpvIntrinsic(MI, Intrinsic::spv_bitcast) &&
- !isSpvIntrinsic(MI, Intrinsic::spv_ptrcast))
+ if (!isSpvIntrinsic(MI, Intrinsic::spv_ptrcast))
continue;
assert(MI.getOperand(2).isReg());
MIB.setInsertPt(*MI.getParent(), MI);
ToErase.push_back(&MI);
- if (isSpvIntrinsic(MI, Intrinsic::spv_bitcast)) {
- MIB.buildBitcast(MI.getOperand(0).getReg(), MI.getOperand(2).getReg());
- continue;
- }
Register Def = MI.getOperand(0).getReg();
Register Source = MI.getOperand(2).getReg();
Type *ElemTy = getMDOperandAsType(MI.getOperand(3).getMetadata(), 0);
@@ -1089,7 +1096,7 @@ bool SPIRVPreLegalizer::runOnMachineFunction(MachineFunction &MF) {
removeImplicitFallthroughs(MF, MIB);
insertSpirvDecorations(MF, GR, MIB);
insertInlineAsm(MF, GR, ST, MIB);
- selectOpBitcasts(MF, GR, MIB);
+ lowerBitcasts(MF, GR, MIB);
return true;
}
diff --git a/llvm/lib/Target/X86/AsmParser/X86Operand.h b/llvm/lib/Target/X86/AsmParser/X86Operand.h
index 89ac53e..a922725 100644
--- a/llvm/lib/Target/X86/AsmParser/X86Operand.h
+++ b/llvm/lib/Target/X86/AsmParser/X86Operand.h
@@ -620,37 +620,6 @@ struct X86Operand final : public MCParsedAsmOperand {
Inst.addOperand(MCOperand::createReg(Reg));
}
- bool isTILEPair() const {
- return Kind == Register &&
- X86MCRegisterClasses[X86::TILERegClassID].contains(getReg());
- }
-
- void addTILEPairOperands(MCInst &Inst, unsigned N) const {
- assert(N == 1 && "Invalid number of operands!");
- MCRegister Reg = getReg();
- switch (Reg.id()) {
- default:
- llvm_unreachable("Invalid tile register!");
- case X86::TMM0:
- case X86::TMM1:
- Reg = X86::TMM0_TMM1;
- break;
- case X86::TMM2:
- case X86::TMM3:
- Reg = X86::TMM2_TMM3;
- break;
- case X86::TMM4:
- case X86::TMM5:
- Reg = X86::TMM4_TMM5;
- break;
- case X86::TMM6:
- case X86::TMM7:
- Reg = X86::TMM6_TMM7;
- break;
- }
- Inst.addOperand(MCOperand::createReg(Reg));
- }
-
void addMemOperands(MCInst &Inst, unsigned N) const {
assert((N == 5) && "Invalid number of operands!");
if (getMemBaseReg())
diff --git a/llvm/lib/Target/X86/Disassembler/X86Disassembler.cpp b/llvm/lib/Target/X86/Disassembler/X86Disassembler.cpp
index 4927b45..7d2b5eb 100644
--- a/llvm/lib/Target/X86/Disassembler/X86Disassembler.cpp
+++ b/llvm/lib/Target/X86/Disassembler/X86Disassembler.cpp
@@ -810,10 +810,6 @@ static int readModRM(struct InternalInstruction *insn) {
if (index > 7) \
*valid = 0; \
return prefix##_TMM0 + index; \
- case TYPE_TMM_PAIR: \
- if (index > 7) \
- *valid = 0; \
- return prefix##_TMM0_TMM1 + (index / 2); \
case TYPE_VK: \
index &= 0xf; \
if (index > 7) \
@@ -2323,7 +2319,6 @@ static bool translateRM(MCInst &mcInst, const OperandSpecifier &operand,
case TYPE_YMM:
case TYPE_ZMM:
case TYPE_TMM:
- case TYPE_TMM_PAIR:
case TYPE_VK_PAIR:
case TYPE_VK:
case TYPE_DEBUGREG:
diff --git a/llvm/lib/Target/X86/Disassembler/X86DisassemblerDecoder.h b/llvm/lib/Target/X86/Disassembler/X86DisassemblerDecoder.h
index dc9af2c..b0aa70b 100644
--- a/llvm/lib/Target/X86/Disassembler/X86DisassemblerDecoder.h
+++ b/llvm/lib/Target/X86/Disassembler/X86DisassemblerDecoder.h
@@ -535,12 +535,6 @@ namespace X86Disassembler {
ENTRY(TMM6) \
ENTRY(TMM7)
-#define REGS_TMM_PAIRS \
- ENTRY(TMM0_TMM1) \
- ENTRY(TMM2_TMM3) \
- ENTRY(TMM4_TMM5) \
- ENTRY(TMM6_TMM7)
-
#define ALL_EA_BASES \
EA_BASES_16BIT \
EA_BASES_32BIT \
@@ -565,7 +559,6 @@ namespace X86Disassembler {
REGS_DEBUG \
REGS_CONTROL \
REGS_TMM \
- REGS_TMM_PAIRS \
ENTRY(RIP)
/// All possible values of the base field for effective-address
diff --git a/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.cpp b/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.cpp
index 1c5f166..759d95e 100644
--- a/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.cpp
+++ b/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.cpp
@@ -467,22 +467,3 @@ void X86InstPrinterCommon::printVKPair(const MCInst *MI, unsigned OpNo,
}
llvm_unreachable("Unknown mask pair register name");
}
-
-void X86InstPrinterCommon::printTILEPair(const MCInst *MI, unsigned OpNo,
- raw_ostream &OS) {
- switch (MI->getOperand(OpNo).getReg()) {
- case X86::TMM0_TMM1:
- printRegName(OS, X86::TMM0);
- return;
- case X86::TMM2_TMM3:
- printRegName(OS, X86::TMM2);
- return;
- case X86::TMM4_TMM5:
- printRegName(OS, X86::TMM4);
- return;
- case X86::TMM6_TMM7:
- printRegName(OS, X86::TMM6);
- return;
- }
- llvm_unreachable("Unknown mask pair register name");
-}
diff --git a/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.h b/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.h
index 2c9467c..cb55f2f 100644
--- a/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.h
+++ b/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.h
@@ -40,7 +40,6 @@ protected:
const MCSubtargetInfo &STI);
void printOptionalSegReg(const MCInst *MI, unsigned OpNo, raw_ostream &O);
void printVKPair(const MCInst *MI, unsigned OpNo, raw_ostream &OS);
- void printTILEPair(const MCInst *MI, unsigned OpNo, raw_ostream &OS);
};
} // end namespace llvm
diff --git a/llvm/lib/Target/X86/X86.td b/llvm/lib/Target/X86/X86.td
index a1fd366..9e291a6 100644
--- a/llvm/lib/Target/X86/X86.td
+++ b/llvm/lib/Target/X86/X86.td
@@ -274,9 +274,6 @@ def FeatureAMXFP8 : SubtargetFeature<"amx-fp8", "HasAMXFP8", "true",
def FeatureAMXMOVRS : SubtargetFeature<"amx-movrs", "HasAMXMOVRS", "true",
"Support AMX-MOVRS instructions",
[FeatureAMXTILE]>;
-def FeatureAMXTRANSPOSE : SubtargetFeature<"amx-transpose", "HasAMXTRANSPOSE", "true",
- "Support AMX amx-transpose instructions",
- [FeatureAMXTILE]>;
def FeatureAMXAVX512 : SubtargetFeature<"amx-avx512",
"HasAMXAVX512", "true",
"Support AMX-AVX512 instructions",
@@ -1177,8 +1174,7 @@ def ProcessorFeatures {
FeatureAMXMOVRS,
FeatureAMXAVX512,
FeatureAMXFP8,
- FeatureAMXTF32,
- FeatureAMXTRANSPOSE];
+ FeatureAMXTF32];
list<SubtargetFeature> DMRFeatures =
!listconcat(GNRDFeatures, DMRAdditionalFeatures);
diff --git a/llvm/lib/Target/X86/X86ExpandPseudo.cpp b/llvm/lib/Target/X86/X86ExpandPseudo.cpp
index 4a9b824..e3c44c0 100644
--- a/llvm/lib/Target/X86/X86ExpandPseudo.cpp
+++ b/llvm/lib/Target/X86/X86ExpandPseudo.cpp
@@ -649,149 +649,6 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB,
MI.setDesc(TII->get(Opc));
return true;
}
- // TILEPAIRLOAD is just for TILEPair spill, we don't have corresponding
- // AMX instruction to support it. So, split it to 2 load instructions:
- // "TILEPAIRLOAD TMM0:TMM1, Base, Scale, Index, Offset, Segment" -->
- // "TILELOAD TMM0, Base, Scale, Index, Offset, Segment" +
- // "TILELOAD TMM1, Base, Scale, Index, Offset + TMM_SIZE, Segment"
- case X86::PTILEPAIRLOAD: {
- int64_t Disp = MBBI->getOperand(1 + X86::AddrDisp).getImm();
- Register TReg = MBBI->getOperand(0).getReg();
- bool DstIsDead = MBBI->getOperand(0).isDead();
- Register TReg0 = TRI->getSubReg(TReg, X86::sub_t0);
- Register TReg1 = TRI->getSubReg(TReg, X86::sub_t1);
- unsigned TmmSize = TRI->getRegSizeInBits(X86::TILERegClass) / 8;
-
- MachineInstrBuilder MIBLo =
- BuildMI(MBB, MBBI, DL, TII->get(X86::TILELOADD))
- .addReg(TReg0, RegState::Define | getDeadRegState(DstIsDead));
- MachineInstrBuilder MIBHi =
- BuildMI(MBB, MBBI, DL, TII->get(X86::TILELOADD))
- .addReg(TReg1, RegState::Define | getDeadRegState(DstIsDead));
-
- for (int i = 0; i < X86::AddrNumOperands; ++i) {
- MIBLo.add(MBBI->getOperand(1 + i));
- if (i == X86::AddrDisp)
- MIBHi.addImm(Disp + TmmSize);
- else
- MIBHi.add(MBBI->getOperand(1 + i));
- }
-
- // Make sure the first stride reg used in first tileload is alive.
- MachineOperand &Stride =
- MIBLo.getInstr()->getOperand(1 + X86::AddrIndexReg);
- Stride.setIsKill(false);
-
- // Split the memory operand, adjusting the offset and size for the halves.
- MachineMemOperand *OldMMO = MBBI->memoperands().front();
- MachineFunction *MF = MBB.getParent();
- MachineMemOperand *MMOLo = MF->getMachineMemOperand(OldMMO, 0, TmmSize);
- MachineMemOperand *MMOHi =
- MF->getMachineMemOperand(OldMMO, TmmSize, TmmSize);
-
- MIBLo.setMemRefs(MMOLo);
- MIBHi.setMemRefs(MMOHi);
-
- // Delete the pseudo.
- MBB.erase(MBBI);
- return true;
- }
- // Similar with TILEPAIRLOAD, TILEPAIRSTORE is just for TILEPair spill, no
- // corresponding AMX instruction to support it. So, split it too:
- // "TILEPAIRSTORE Base, Scale, Index, Offset, Segment, TMM0:TMM1" -->
- // "TILESTORE Base, Scale, Index, Offset, Segment, TMM0" +
- // "TILESTORE Base, Scale, Index, Offset + TMM_SIZE, Segment, TMM1"
- case X86::PTILEPAIRSTORE: {
- int64_t Disp = MBBI->getOperand(X86::AddrDisp).getImm();
- Register TReg = MBBI->getOperand(X86::AddrNumOperands).getReg();
- bool SrcIsKill = MBBI->getOperand(X86::AddrNumOperands).isKill();
- Register TReg0 = TRI->getSubReg(TReg, X86::sub_t0);
- Register TReg1 = TRI->getSubReg(TReg, X86::sub_t1);
- unsigned TmmSize = TRI->getRegSizeInBits(X86::TILERegClass) / 8;
-
- MachineInstrBuilder MIBLo =
- BuildMI(MBB, MBBI, DL, TII->get(X86::TILESTORED));
- MachineInstrBuilder MIBHi =
- BuildMI(MBB, MBBI, DL, TII->get(X86::TILESTORED));
-
- for (int i = 0; i < X86::AddrNumOperands; ++i) {
- MIBLo.add(MBBI->getOperand(i));
- if (i == X86::AddrDisp)
- MIBHi.addImm(Disp + TmmSize);
- else
- MIBHi.add(MBBI->getOperand(i));
- }
- MIBLo.addReg(TReg0, getKillRegState(SrcIsKill));
- MIBHi.addReg(TReg1, getKillRegState(SrcIsKill));
-
- // Make sure the first stride reg used in first tilestore is alive.
- MachineOperand &Stride = MIBLo.getInstr()->getOperand(X86::AddrIndexReg);
- Stride.setIsKill(false);
-
- // Split the memory operand, adjusting the offset and size for the halves.
- MachineMemOperand *OldMMO = MBBI->memoperands().front();
- MachineFunction *MF = MBB.getParent();
- MachineMemOperand *MMOLo = MF->getMachineMemOperand(OldMMO, 0, TmmSize);
- MachineMemOperand *MMOHi =
- MF->getMachineMemOperand(OldMMO, TmmSize, TmmSize);
-
- MIBLo.setMemRefs(MMOLo);
- MIBHi.setMemRefs(MMOHi);
-
- // Delete the pseudo.
- MBB.erase(MBBI);
- return true;
- }
- case X86::PT2RPNTLVWZ0V:
- case X86::PT2RPNTLVWZ0T1V:
- case X86::PT2RPNTLVWZ1V:
- case X86::PT2RPNTLVWZ1T1V:
- case X86::PT2RPNTLVWZ0RSV:
- case X86::PT2RPNTLVWZ0RST1V:
- case X86::PT2RPNTLVWZ1RSV:
- case X86::PT2RPNTLVWZ1RST1V: {
- for (unsigned i = 3; i > 0; --i)
- MI.removeOperand(i);
- unsigned Opc;
- switch (Opcode) {
- case X86::PT2RPNTLVWZ0V:
- Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0);
- break;
- case X86::PT2RPNTLVWZ0T1V:
- Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0T1);
- break;
- case X86::PT2RPNTLVWZ1V:
- Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1);
- break;
- case X86::PT2RPNTLVWZ1T1V:
- Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1T1);
- break;
- case X86::PT2RPNTLVWZ0RSV:
- Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0RS);
- break;
- case X86::PT2RPNTLVWZ0RST1V:
- Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0RST1);
- break;
- case X86::PT2RPNTLVWZ1RSV:
- Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1RS);
- break;
- case X86::PT2RPNTLVWZ1RST1V:
- Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1RST1);
- break;
- default:
- llvm_unreachable("Impossible Opcode!");
- }
- MI.setDesc(TII->get(Opc));
- return true;
- }
- case X86::PTTRANSPOSEDV:
- case X86::PTCONJTFP16V: {
- for (int i = 2; i > 0; --i)
- MI.removeOperand(i);
- MI.setDesc(TII->get(Opcode == X86::PTTRANSPOSEDV ? X86::TTRANSPOSED
- : X86::TCONJTFP16));
- return true;
- }
case X86::PTCMMIMFP16PSV:
case X86::PTCMMRLFP16PSV:
case X86::PTDPBSSDV:
@@ -800,13 +657,7 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB,
case X86::PTDPBUUDV:
case X86::PTDPBF16PSV:
case X86::PTDPFP16PSV:
- case X86::PTTDPBF16PSV:
- case X86::PTTDPFP16PSV:
- case X86::PTTCMMIMFP16PSV:
- case X86::PTTCMMRLFP16PSV:
- case X86::PTCONJTCMMIMFP16PSV:
case X86::PTMMULTF32PSV:
- case X86::PTTMMULTF32PSV:
case X86::PTDPBF8PSV:
case X86::PTDPBHF8PSV:
case X86::PTDPHBF8PSV:
@@ -816,6 +667,7 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB,
MI.removeOperand(i);
unsigned Opc;
switch (Opcode) {
+ // clang-format off
case X86::PTCMMIMFP16PSV: Opc = X86::TCMMIMFP16PS; break;
case X86::PTCMMRLFP16PSV: Opc = X86::TCMMRLFP16PS; break;
case X86::PTDPBSSDV: Opc = X86::TDPBSSD; break;
@@ -824,40 +676,12 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB,
case X86::PTDPBUUDV: Opc = X86::TDPBUUD; break;
case X86::PTDPBF16PSV: Opc = X86::TDPBF16PS; break;
case X86::PTDPFP16PSV: Opc = X86::TDPFP16PS; break;
- case X86::PTTDPBF16PSV:
- Opc = X86::TTDPBF16PS;
- break;
- case X86::PTTDPFP16PSV:
- Opc = X86::TTDPFP16PS;
- break;
- case X86::PTTCMMIMFP16PSV:
- Opc = X86::TTCMMIMFP16PS;
- break;
- case X86::PTTCMMRLFP16PSV:
- Opc = X86::TTCMMRLFP16PS;
- break;
- case X86::PTCONJTCMMIMFP16PSV:
- Opc = X86::TCONJTCMMIMFP16PS;
- break;
- case X86::PTMMULTF32PSV:
- Opc = X86::TMMULTF32PS;
- break;
- case X86::PTTMMULTF32PSV:
- Opc = X86::TTMMULTF32PS;
- break;
- case X86::PTDPBF8PSV:
- Opc = X86::TDPBF8PS;
- break;
- case X86::PTDPBHF8PSV:
- Opc = X86::TDPBHF8PS;
- break;
- case X86::PTDPHBF8PSV:
- Opc = X86::TDPHBF8PS;
- break;
- case X86::PTDPHF8PSV:
- Opc = X86::TDPHF8PS;
- break;
-
+ case X86::PTMMULTF32PSV: Opc = X86::TMMULTF32PS; break;
+ case X86::PTDPBF8PSV: Opc = X86::TDPBF8PS; break;
+ case X86::PTDPBHF8PSV: Opc = X86::TDPBHF8PS; break;
+ case X86::PTDPHBF8PSV: Opc = X86::TDPHBF8PS; break;
+ case X86::PTDPHF8PSV: Opc = X86::TDPHF8PS; break;
+ // clang-format on
default:
llvm_unreachable("Unexpected Opcode");
}
diff --git a/llvm/lib/Target/X86/X86FastPreTileConfig.cpp b/llvm/lib/Target/X86/X86FastPreTileConfig.cpp
index 787b71d..06f729a 100644
--- a/llvm/lib/Target/X86/X86FastPreTileConfig.cpp
+++ b/llvm/lib/Target/X86/X86FastPreTileConfig.cpp
@@ -267,24 +267,16 @@ void X86FastPreTileConfig::reload(MachineBasicBlock::iterator UseMI,
<< printReg(TileReg, TRI) << '\n');
}
-static unsigned getTileDefNum(MachineRegisterInfo *MRI, Register Reg) {
- if (Reg.isVirtual()) {
- unsigned RegClassID = MRI->getRegClass(Reg)->getID();
- if (RegClassID == X86::TILERegClassID)
- return 1;
- if (RegClassID == X86::TILEPAIRRegClassID)
- return 2;
- } else {
- if (Reg >= X86::TMM0 && Reg <= X86::TMM7)
- return 1;
- if (Reg >= X86::TMM0_TMM1 && Reg <= X86::TMM6_TMM7)
- return 2;
+static bool isTileRegister(MachineRegisterInfo *MRI, Register Reg) {
+ if (Reg.isVirtual() &&
+ (MRI->getRegClass(Reg)->getID() == X86::TILERegClassID)) {
+ return true;
}
- return 0;
-}
-static bool isTileRegister(MachineRegisterInfo *MRI, Register VirtReg) {
- return getTileDefNum(MRI, VirtReg) > 0;
+ if (Reg >= X86::TMM0 && Reg <= X86::TMM7)
+ return true;
+
+ return false;
}
static bool isTileDef(MachineRegisterInfo *MRI, MachineInstr &MI) {
@@ -296,7 +288,7 @@ static bool isTileDef(MachineRegisterInfo *MRI, MachineInstr &MI) {
if (!MO.isReg())
return false;
- return getTileDefNum(MRI, MO.getReg()) > 0;
+ return isTileRegister(MRI, MO.getReg());
}
static ShapeT getShape(MachineRegisterInfo *MRI, Register TileReg) {
@@ -636,19 +628,7 @@ bool X86FastPreTileConfig::configBasicBlock(MachineBasicBlock &MBB) {
else if (dominates(MBB, LastShapeMI, ColMI))
LastShapeMI = ColMI;
}
- unsigned TileDefNum = getTileDefNum(MRI, MI.getOperand(0).getReg());
- if (TileDefNum > 1) {
- for (unsigned I = 1; I < TileDefNum; I++) {
- MachineOperand *ColxMO = &MI.getOperand(2 + I);
- MachineInstr *ColxMI = MRI->getVRegDef(ColxMO->getReg());
- if (ColxMI->getParent() == &MBB) {
- if (!LastShapeMI)
- LastShapeMI = ColxMI;
- else if (dominates(MBB, LastShapeMI, ColxMI))
- LastShapeMI = ColxMI;
- }
- }
- }
+
// If there is user live out of the tilecfg, spill it and reload in
// before the user.
Register TileReg = MI.getOperand(0).getReg();
diff --git a/llvm/lib/Target/X86/X86FastTileConfig.cpp b/llvm/lib/Target/X86/X86FastTileConfig.cpp
index 11d331b..d86ae36 100644
--- a/llvm/lib/Target/X86/X86FastTileConfig.cpp
+++ b/llvm/lib/Target/X86/X86FastTileConfig.cpp
@@ -77,14 +77,14 @@ INITIALIZE_PASS_BEGIN(X86FastTileConfig, DEBUG_TYPE,
INITIALIZE_PASS_END(X86FastTileConfig, DEBUG_TYPE,
"Fast Tile Register Configure", false, false)
-static unsigned getNumDefTiles(MachineRegisterInfo *MRI, MachineInstr &MI) {
+static bool isTileDef(MachineRegisterInfo *MRI, MachineInstr &MI) {
// There is no phi instruction after register allocation.
assert(MI.isPHI() == false);
// The instruction must have 3 operands: tile def, row, col.
// It should be AMX pseudo instruction that have shape operand.
if (MI.isDebugInstr() || MI.isCopy() || MI.getNumOperands() < 3 ||
!MI.isPseudo())
- return 0;
+ return false;
MachineOperand &MO = MI.getOperand(0);
if (MO.isReg()) {
@@ -93,24 +93,18 @@ static unsigned getNumDefTiles(MachineRegisterInfo *MRI, MachineInstr &MI) {
// register is not rewritten yet.
if (Reg.isVirtual()) {
if (MRI->getRegClass(Reg)->getID() == X86::TILERegClassID)
- return 1;
- if (MRI->getRegClass(Reg)->getID() == X86::TILEPAIRRegClassID)
- return 2;
+ return true;
}
if (Reg >= X86::TMM0 && Reg <= X86::TMM7)
- return 1;
- if (Reg >= X86::TMM0_TMM1 && Reg <= X86::TMM6_TMM7)
- return 2;
+ return true;
}
- return 0;
+ return false;
}
static unsigned getTMMIndex(Register Reg) {
if (Reg >= X86::TMM0 && Reg <= X86::TMM7)
return Reg - X86::TMM0;
- if (Reg >= X86::TMM0_TMM1 && Reg <= X86::TMM6_TMM7)
- return (Reg - X86::TMM0_TMM1) * 2;
llvm_unreachable("Invalid Tmm Reg!");
}
@@ -120,17 +114,14 @@ bool X86FastTileConfig::configBasicBlock(MachineBasicBlock &MBB) {
bool Change = false;
SmallVector<std::pair<unsigned, ShapeT>, 6> ShapeInfos;
for (MachineInstr &MI : reverse(MBB)) {
- unsigned DefNum = getNumDefTiles(MRI, MI);
- if (DefNum == 0 && MI.getOpcode() != X86::PLDTILECFGV)
+ if (!isTileDef(MRI, MI) && MI.getOpcode() != X86::PLDTILECFGV)
continue;
// AMX instructions that define tile register.
if (MI.getOpcode() != X86::PLDTILECFGV) {
MachineOperand &Row = MI.getOperand(1);
unsigned TMMIdx = getTMMIndex(MI.getOperand(0).getReg());
- for (unsigned I = 0; I < DefNum; I++) {
- MachineOperand &Col = MI.getOperand(2 + I);
- ShapeInfos.push_back({TMMIdx + I, ShapeT(&Row, &Col)});
- }
+ MachineOperand &Col = MI.getOperand(2);
+ ShapeInfos.push_back({TMMIdx, ShapeT(&Row, &Col)});
} else { // PLDTILECFGV
// Rewrite the shape information to memory. Stack slot should have
// been initialized to zero in pre config.
diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
index 4393f6e..d4418c8 100644
--- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -337,23 +337,8 @@ namespace {
// lowering but before ISEL.
bool isAMXSDNode(SDNode *N) const {
// Check if N is AMX SDNode:
- // 1. check specific opcode since these carry MVT::Untyped instead of
- // x86amx_type;
- // 2. check result type;
- // 3. check operand type;
- switch (N->getOpcode()) {
- default:
- break;
- case X86::PT2RPNTLVWZ0V:
- case X86::PT2RPNTLVWZ0T1V:
- case X86::PT2RPNTLVWZ1V:
- case X86::PT2RPNTLVWZ1T1V:
- case X86::PT2RPNTLVWZ0RSV:
- case X86::PT2RPNTLVWZ0RST1V:
- case X86::PT2RPNTLVWZ1RSV:
- case X86::PT2RPNTLVWZ1RST1V:
- return true;
- }
+ // 1. check result type;
+ // 2. check operand type;
for (unsigned Idx = 0, E = N->getNumValues(); Idx != E; ++Idx) {
if (N->getValueType(Idx) == MVT::x86amx)
return true;
@@ -5398,65 +5383,6 @@ void X86DAGToDAGISel::Select(SDNode *Node) {
ReplaceNode(Node, CNode);
return;
}
- case Intrinsic::x86_t2rpntlvwz0rs:
- case Intrinsic::x86_t2rpntlvwz0rst1:
- case Intrinsic::x86_t2rpntlvwz1rs:
- case Intrinsic::x86_t2rpntlvwz1rst1:
- if (!Subtarget->hasAMXMOVRS())
- break;
- [[fallthrough]];
- case Intrinsic::x86_t2rpntlvwz0:
- case Intrinsic::x86_t2rpntlvwz0t1:
- case Intrinsic::x86_t2rpntlvwz1:
- case Intrinsic::x86_t2rpntlvwz1t1: {
- if (!Subtarget->hasAMXTRANSPOSE())
- break;
- auto *MFI =
- CurDAG->getMachineFunction().getInfo<X86MachineFunctionInfo>();
- MFI->setAMXProgModel(AMXProgModelEnum::DirectReg);
- unsigned Opc;
- switch (IntNo) {
- default:
- llvm_unreachable("Unexpected intrinsic!");
- case Intrinsic::x86_t2rpntlvwz0:
- Opc = X86::PT2RPNTLVWZ0;
- break;
- case Intrinsic::x86_t2rpntlvwz0t1:
- Opc = X86::PT2RPNTLVWZ0T1;
- break;
- case Intrinsic::x86_t2rpntlvwz1:
- Opc = X86::PT2RPNTLVWZ1;
- break;
- case Intrinsic::x86_t2rpntlvwz1t1:
- Opc = X86::PT2RPNTLVWZ1T1;
- break;
- case Intrinsic::x86_t2rpntlvwz0rs:
- Opc = X86::PT2RPNTLVWZ0RS;
- break;
- case Intrinsic::x86_t2rpntlvwz0rst1:
- Opc = X86::PT2RPNTLVWZ0RST1;
- break;
- case Intrinsic::x86_t2rpntlvwz1rs:
- Opc = X86::PT2RPNTLVWZ1RS;
- break;
- case Intrinsic::x86_t2rpntlvwz1rst1:
- Opc = X86::PT2RPNTLVWZ1RST1;
- break;
- }
- // FIXME: Match displacement and scale.
- unsigned TIndex = Node->getConstantOperandVal(2);
- SDValue TReg = getI8Imm(TIndex, dl);
- SDValue Base = Node->getOperand(3);
- SDValue Scale = getI8Imm(1, dl);
- SDValue Index = Node->getOperand(4);
- SDValue Disp = CurDAG->getTargetConstant(0, dl, MVT::i32);
- SDValue Segment = CurDAG->getRegister(0, MVT::i16);
- SDValue Chain = Node->getOperand(0);
- SDValue Ops[] = {TReg, Base, Scale, Index, Disp, Segment, Chain};
- MachineSDNode *CNode = CurDAG->getMachineNode(Opc, dl, MVT::Other, Ops);
- ReplaceNode(Node, CNode);
- return;
- }
}
break;
}
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 49beada..007074c 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -27946,67 +27946,6 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget,
return DAG.getNode(ISD::MERGE_VALUES, dl, Op->getVTList(), SetCC,
Operation.getValue(1));
}
- case Intrinsic::x86_t2rpntlvwz0rs_internal:
- case Intrinsic::x86_t2rpntlvwz0rst1_internal:
- case Intrinsic::x86_t2rpntlvwz1rs_internal:
- case Intrinsic::x86_t2rpntlvwz1rst1_internal:
- case Intrinsic::x86_t2rpntlvwz0_internal:
- case Intrinsic::x86_t2rpntlvwz0t1_internal:
- case Intrinsic::x86_t2rpntlvwz1_internal:
- case Intrinsic::x86_t2rpntlvwz1t1_internal: {
- auto *X86MFI = DAG.getMachineFunction().getInfo<X86MachineFunctionInfo>();
- X86MFI->setAMXProgModel(AMXProgModelEnum::ManagedRA);
- unsigned IntNo = Op.getConstantOperandVal(1);
- unsigned Opc = 0;
- switch (IntNo) {
- default:
- llvm_unreachable("Unexpected intrinsic!");
- case Intrinsic::x86_t2rpntlvwz0_internal:
- Opc = X86::PT2RPNTLVWZ0V;
- break;
- case Intrinsic::x86_t2rpntlvwz0t1_internal:
- Opc = X86::PT2RPNTLVWZ0T1V;
- break;
- case Intrinsic::x86_t2rpntlvwz1_internal:
- Opc = X86::PT2RPNTLVWZ1V;
- break;
- case Intrinsic::x86_t2rpntlvwz1t1_internal:
- Opc = X86::PT2RPNTLVWZ1T1V;
- break;
- case Intrinsic::x86_t2rpntlvwz0rs_internal:
- Opc = X86::PT2RPNTLVWZ0RSV;
- break;
- case Intrinsic::x86_t2rpntlvwz0rst1_internal:
- Opc = X86::PT2RPNTLVWZ0RST1V;
- break;
- case Intrinsic::x86_t2rpntlvwz1rs_internal:
- Opc = X86::PT2RPNTLVWZ1RSV;
- break;
- case Intrinsic::x86_t2rpntlvwz1rst1_internal:
- Opc = X86::PT2RPNTLVWZ1RST1V;
- break;
- }
-
- SDLoc DL(Op);
- SDVTList VTs = DAG.getVTList(MVT::Untyped, MVT::Other);
-
- SDValue Ops[] = {Op.getOperand(2), // Row
- Op.getOperand(3), // Col0
- Op.getOperand(4), // Col1
- Op.getOperand(5), // Base
- DAG.getTargetConstant(1, DL, MVT::i8), // Scale
- Op.getOperand(6), // Index
- DAG.getTargetConstant(0, DL, MVT::i32), // Disp
- DAG.getRegister(0, MVT::i16), // Segment
- Op.getOperand(0)}; // Chain
-
- MachineSDNode *Res = DAG.getMachineNode(Opc, DL, VTs, Ops);
- SDValue Res0 = DAG.getTargetExtractSubreg(X86::sub_t0, DL, MVT::x86amx,
- SDValue(Res, 0));
- SDValue Res1 = DAG.getTargetExtractSubreg(X86::sub_t1, DL, MVT::x86amx,
- SDValue(Res, 0));
- return DAG.getMergeValues({Res0, Res1, SDValue(Res, 1)}, DL);
- }
case Intrinsic::x86_atomic_bts_rm:
case Intrinsic::x86_atomic_btc_rm:
case Intrinsic::x86_atomic_btr_rm: {
@@ -37745,10 +37684,6 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
assert (Imm < 8 && "Illegal tmm index");
return X86::TMM0 + Imm;
};
- auto TMMImmToTMMPair = [](unsigned Imm) {
- assert(Imm < 8 && "Illegal tmm pair index.");
- return X86::TMM0_TMM1 + Imm / 2;
- };
switch (MI.getOpcode()) {
default:
llvm_unreachable("Unexpected instr type to insert");
@@ -38129,53 +38064,25 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
case X86::PTDPBHF8PS:
case X86::PTDPHBF8PS:
case X86::PTDPHF8PS:
- case X86::PTTDPBF16PS:
- case X86::PTTDPFP16PS:
- case X86::PTTCMMIMFP16PS:
- case X86::PTTCMMRLFP16PS:
- case X86::PTCONJTCMMIMFP16PS:
- case X86::PTMMULTF32PS:
- case X86::PTTMMULTF32PS: {
+ case X86::PTMMULTF32PS: {
unsigned Opc;
switch (MI.getOpcode()) {
default: llvm_unreachable("illegal opcode!");
+ // clang-format off
case X86::PTDPBSSD: Opc = X86::TDPBSSD; break;
case X86::PTDPBSUD: Opc = X86::TDPBSUD; break;
case X86::PTDPBUSD: Opc = X86::TDPBUSD; break;
case X86::PTDPBUUD: Opc = X86::TDPBUUD; break;
case X86::PTDPBF16PS: Opc = X86::TDPBF16PS; break;
case X86::PTDPFP16PS: Opc = X86::TDPFP16PS; break;
- case X86::PTCMMIMFP16PS:
- Opc = X86::TCMMIMFP16PS;
- break;
- case X86::PTCMMRLFP16PS:
- Opc = X86::TCMMRLFP16PS;
- break;
+ case X86::PTCMMIMFP16PS: Opc = X86::TCMMIMFP16PS; break;
+ case X86::PTCMMRLFP16PS: Opc = X86::TCMMRLFP16PS; break;
case X86::PTDPBF8PS: Opc = X86::TDPBF8PS; break;
case X86::PTDPBHF8PS: Opc = X86::TDPBHF8PS; break;
case X86::PTDPHBF8PS: Opc = X86::TDPHBF8PS; break;
case X86::PTDPHF8PS: Opc = X86::TDPHF8PS; break;
- case X86::PTTDPBF16PS:
- Opc = X86::TTDPBF16PS;
- break;
- case X86::PTTDPFP16PS:
- Opc = X86::TTDPFP16PS;
- break;
- case X86::PTTCMMIMFP16PS:
- Opc = X86::TTCMMIMFP16PS;
- break;
- case X86::PTTCMMRLFP16PS:
- Opc = X86::TTCMMRLFP16PS;
- break;
- case X86::PTCONJTCMMIMFP16PS:
- Opc = X86::TCONJTCMMIMFP16PS;
- break;
- case X86::PTMMULTF32PS:
- Opc = X86::TMMULTF32PS;
- break;
- case X86::PTTMMULTF32PS:
- Opc = X86::TTMMULTF32PS;
- break;
+ case X86::PTMMULTF32PS: Opc = X86::TMMULTF32PS; break;
+ // clang-format on
}
MachineInstrBuilder MIB = BuildMI(*BB, MI, MIMD, TII->get(Opc));
@@ -38246,70 +38153,6 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
MI.eraseFromParent(); // The pseudo is gone now.
return BB;
}
- case X86::PT2RPNTLVWZ0:
- case X86::PT2RPNTLVWZ0T1:
- case X86::PT2RPNTLVWZ1:
- case X86::PT2RPNTLVWZ1T1:
- case X86::PT2RPNTLVWZ0RS:
- case X86::PT2RPNTLVWZ0RST1:
- case X86::PT2RPNTLVWZ1RS:
- case X86::PT2RPNTLVWZ1RST1: {
- const DebugLoc &DL = MI.getDebugLoc();
- unsigned Opc;
-#define GET_EGPR_IF_ENABLED(OPC) (Subtarget.hasEGPR() ? OPC##_EVEX : OPC)
- switch (MI.getOpcode()) {
- default:
- llvm_unreachable("Unexpected instruction!");
- case X86::PT2RPNTLVWZ0:
- Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0);
- break;
- case X86::PT2RPNTLVWZ0T1:
- Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0T1);
- break;
- case X86::PT2RPNTLVWZ1:
- Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1);
- break;
- case X86::PT2RPNTLVWZ1T1:
- Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1T1);
- break;
- case X86::PT2RPNTLVWZ0RS:
- Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0RS);
- break;
- case X86::PT2RPNTLVWZ0RST1:
- Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0RST1);
- break;
- case X86::PT2RPNTLVWZ1RS:
- Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1RS);
- break;
- case X86::PT2RPNTLVWZ1RST1:
- Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1RST1);
- break;
- }
-#undef GET_EGPR_IF_ENABLED
- MachineInstrBuilder MIB = BuildMI(*BB, MI, DL, TII->get(Opc));
- MIB.addReg(TMMImmToTMMPair(MI.getOperand(0).getImm()), RegState::Define);
-
- MIB.add(MI.getOperand(1)); // base
- MIB.add(MI.getOperand(2)); // scale
- MIB.add(MI.getOperand(3)); // index
- MIB.add(MI.getOperand(4)); // displacement
- MIB.add(MI.getOperand(5)); // segment
- MI.eraseFromParent(); // The pseudo is gone now.
- return BB;
- }
- case X86::PTTRANSPOSED:
- case X86::PTCONJTFP16: {
- const DebugLoc &DL = MI.getDebugLoc();
- unsigned Opc = MI.getOpcode() == X86::PTTRANSPOSED ? X86::TTRANSPOSED
- : X86::TCONJTFP16;
-
- MachineInstrBuilder MIB = BuildMI(*BB, MI, DL, TII->get(Opc));
- MIB.addReg(TMMImmToTMMReg(MI.getOperand(0).getImm()), RegState::Define);
- MIB.addReg(TMMImmToTMMReg(MI.getOperand(1).getImm()), RegState::Undef);
-
- MI.eraseFromParent(); // The pseudo is gone now.
- return BB;
- }
case X86::PTCVTROWPS2BF16Hrri:
case X86::PTCVTROWPS2BF16Lrri:
case X86::PTCVTROWPS2PHHrri:
@@ -53502,7 +53345,8 @@ static SDValue combineMaskedStore(SDNode *N, SelectionDAG &DAG,
}
// Look for a RMW operation that only touches one bit of a larger than legal
-// type and fold it to a BTC/BTR/BTS pattern acting on a single i32 sub value.
+// type and fold it to a BTC/BTR/BTS or bit insertion pattern acting on a single
+// i32 sub value.
static SDValue narrowBitOpRMW(StoreSDNode *St, const SDLoc &DL,
SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
@@ -53528,14 +53372,20 @@ static SDValue narrowBitOpRMW(StoreSDNode *St, const SDLoc &DL,
// BTR: X & ~(1 << ShAmt)
// BTS: X | (1 << ShAmt)
// BTC: X ^ (1 << ShAmt)
- SDValue ShAmt;
+ //
+ // BitInsert: (X & ~(1 << ShAmt)) | (InsertBit << ShAmt)
+ SDValue InsertBit, ShAmt;
if (!StoredVal.hasOneUse() ||
!(sd_match(StoredVal, m_And(m_Specific(LoadVal),
m_Not(m_Shl(m_One(), m_Value(ShAmt))))) ||
sd_match(StoredVal,
m_Or(m_Specific(LoadVal), m_Shl(m_One(), m_Value(ShAmt)))) ||
sd_match(StoredVal,
- m_Xor(m_Specific(LoadVal), m_Shl(m_One(), m_Value(ShAmt))))))
+ m_Xor(m_Specific(LoadVal), m_Shl(m_One(), m_Value(ShAmt)))) ||
+ sd_match(StoredVal,
+ m_Or(m_And(m_Specific(LoadVal),
+ m_Not(m_Shl(m_One(), m_Value(ShAmt)))),
+ m_Shl(m_Value(InsertBit), m_Deferred(ShAmt))))))
return SDValue();
// Ensure the shift amount is in bounds.
@@ -53543,6 +53393,13 @@ static SDValue narrowBitOpRMW(StoreSDNode *St, const SDLoc &DL,
if (KnownAmt.getMaxValue().uge(VT.getSizeInBits()))
return SDValue();
+ // If we're inserting a bit then it must be the LSB.
+ if (InsertBit) {
+ KnownBits KnownInsert = DAG.computeKnownBits(InsertBit);
+ if (KnownInsert.countMinLeadingZeros() < (VT.getSizeInBits() - 1))
+ return SDValue();
+ }
+
// Split the shift into an alignment shift that moves the active i32 block to
// the bottom bits for truncation and a modulo shift that can act on the i32.
EVT AmtVT = ShAmt.getValueType();
@@ -53550,6 +53407,7 @@ static SDValue narrowBitOpRMW(StoreSDNode *St, const SDLoc &DL,
DAG.getSignedConstant(-32LL, DL, AmtVT));
SDValue ModuloAmt =
DAG.getNode(ISD::AND, DL, AmtVT, ShAmt, DAG.getConstant(31, DL, AmtVT));
+ ModuloAmt = DAG.getZExtOrTrunc(ModuloAmt, DL, MVT::i8);
// Compute the byte offset for the i32 block that is changed by the RMW.
// combineTruncate will adjust the load for us in a similar way.
@@ -53564,13 +53422,23 @@ static SDValue narrowBitOpRMW(StoreSDNode *St, const SDLoc &DL,
SDValue X = DAG.getNode(ISD::SRL, DL, VT, LoadVal, AlignAmt);
X = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, X);
- SDValue Mask =
- DAG.getNode(ISD::SHL, DL, MVT::i32, DAG.getConstant(1, DL, MVT::i32),
- DAG.getZExtOrTrunc(ModuloAmt, DL, MVT::i8));
- if (StoredVal.getOpcode() == ISD::AND)
- Mask = DAG.getNOT(DL, Mask, MVT::i32);
+ SDValue Mask = DAG.getNode(ISD::SHL, DL, MVT::i32,
+ DAG.getConstant(1, DL, MVT::i32), ModuloAmt);
+
+ SDValue Res;
+ if (InsertBit) {
+ SDValue BitMask =
+ DAG.getNode(ISD::SHL, DL, MVT::i32,
+ DAG.getZExtOrTrunc(InsertBit, DL, MVT::i32), ModuloAmt);
+ Res =
+ DAG.getNode(ISD::AND, DL, MVT::i32, X, DAG.getNOT(DL, Mask, MVT::i32));
+ Res = DAG.getNode(ISD::OR, DL, MVT::i32, Res, BitMask);
+ } else {
+ if (StoredVal.getOpcode() == ISD::AND)
+ Mask = DAG.getNOT(DL, Mask, MVT::i32);
+ Res = DAG.getNode(StoredVal.getOpcode(), DL, MVT::i32, X, Mask);
+ }
- SDValue Res = DAG.getNode(StoredVal.getOpcode(), DL, MVT::i32, X, Mask);
return DAG.getStore(St->getChain(), DL, Res, NewPtr, St->getPointerInfo(),
Align(), St->getMemOperand()->getFlags());
}
@@ -54591,6 +54459,7 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
static SDValue detectPMADDUBSW(SDValue In, EVT VT, SelectionDAG &DAG,
const X86Subtarget &Subtarget,
const SDLoc &DL) {
+ using namespace SDPatternMatch;
if (!VT.isVector() || !Subtarget.hasSSSE3())
return SDValue();
@@ -54600,42 +54469,19 @@ static SDValue detectPMADDUBSW(SDValue In, EVT VT, SelectionDAG &DAG,
return SDValue();
SDValue SSatVal = detectSSatPattern(In, VT);
- if (!SSatVal || SSatVal.getOpcode() != ISD::ADD)
+ if (!SSatVal)
return SDValue();
- // Ok this is a signed saturation of an ADD. See if this ADD is adding pairs
- // of multiplies from even/odd elements.
- SDValue N0 = SSatVal.getOperand(0);
- SDValue N1 = SSatVal.getOperand(1);
-
- if (N0.getOpcode() != ISD::MUL || N1.getOpcode() != ISD::MUL)
- return SDValue();
-
- SDValue N00 = N0.getOperand(0);
- SDValue N01 = N0.getOperand(1);
- SDValue N10 = N1.getOperand(0);
- SDValue N11 = N1.getOperand(1);
-
+ // See if this is a signed saturation of an ADD, adding pairs of multiplies
+ // from even/odd elements, from zero_extend/sign_extend operands.
+ //
// TODO: Handle constant vectors and use knownbits/computenumsignbits?
- // Canonicalize zero_extend to LHS.
- if (N01.getOpcode() == ISD::ZERO_EXTEND)
- std::swap(N00, N01);
- if (N11.getOpcode() == ISD::ZERO_EXTEND)
- std::swap(N10, N11);
-
- // Ensure we have a zero_extend and a sign_extend.
- if (N00.getOpcode() != ISD::ZERO_EXTEND ||
- N01.getOpcode() != ISD::SIGN_EXTEND ||
- N10.getOpcode() != ISD::ZERO_EXTEND ||
- N11.getOpcode() != ISD::SIGN_EXTEND)
+ SDValue N00, N01, N10, N11;
+ if (!sd_match(SSatVal,
+ m_Add(m_Mul(m_ZExt(m_Value(N00)), m_SExt(m_Value(N01))),
+ m_Mul(m_ZExt(m_Value(N10)), m_SExt(m_Value(N11))))))
return SDValue();
- // Peek through the extends.
- N00 = N00.getOperand(0);
- N01 = N01.getOperand(0);
- N10 = N10.getOperand(0);
- N11 = N11.getOperand(0);
-
// Ensure the extend is from vXi8.
if (N00.getValueType().getVectorElementType() != MVT::i8 ||
N01.getValueType().getVectorElementType() != MVT::i8 ||
@@ -54768,9 +54614,11 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
KnownBits KnownAmt = DAG.computeKnownBits(ShAmt);
// Check the shift amount is byte aligned.
// Check the truncation doesn't use any shifted in (zero) top bits.
+ // Check the shift amount doesn't depend on the original load.
if (KnownAmt.countMinTrailingZeros() >= 3 &&
KnownAmt.getMaxValue().ule(SrcVT.getSizeInBits() -
- VT.getSizeInBits())) {
+ VT.getSizeInBits()) &&
+ !Ld->isPredecessorOf(ShAmt.getNode())) {
EVT PtrVT = Ld->getBasePtr().getValueType();
SDValue PtrBitOfs = DAG.getZExtOrTrunc(ShAmt, DL, PtrVT);
SDValue PtrByteOfs =
diff --git a/llvm/lib/Target/X86/X86InstrAMX.td b/llvm/lib/Target/X86/X86InstrAMX.td
index 69a5115..522782a 100644
--- a/llvm/lib/Target/X86/X86InstrAMX.td
+++ b/llvm/lib/Target/X86/X86InstrAMX.td
@@ -338,188 +338,6 @@ let Predicates = [HasAMXFP8, In64BitMode] in {
}
}
-let Predicates = [HasAMXTILE, In64BitMode], isPseudo = true, SchedRW = [WriteSystem] in {
- let mayStore = 1 in
- def PTILEPAIRSTORE : PseudoI<(outs), (ins opaquemem:$src1, TILEPair:$src2), []>;
- let mayLoad = 1 in
- def PTILEPAIRLOAD : PseudoI<(outs TILEPair:$dst), (ins opaquemem:$src), []>;
-}
-
-multiclass T2RPNTLVW_Base<bits<8> op1, bits<8> op2, string rs, string suffix> {
- def Z0#rs#suffix : I<op1, MRMSrcMemFSIB, (outs TILEPair:$dst), (ins sibmem:$src),
- "t2rpntlvwz0" #!tolower(rs)# "\t{$src, $dst|$dst, $src}", []>, PS;
- def Z0#rs#T1#suffix : I<op2, MRMSrcMemFSIB, (outs TILEPair:$dst), (ins sibmem:$src),
- "t2rpntlvwz0" #!tolower(rs)# "t1\t{$src, $dst|$dst, $src}", []>, PS;
- def Z1#rs#suffix : I<op1, MRMSrcMemFSIB, (outs TILEPair:$dst), (ins sibmem:$src),
- "t2rpntlvwz1" #!tolower(rs)# "\t{$src, $dst|$dst, $src}", []>, PD;
- def Z1#rs#T1#suffix : I<op2, MRMSrcMemFSIB, (outs TILEPair:$dst), (ins sibmem:$src),
- "t2rpntlvwz1" #!tolower(rs)# "t1\t{$src, $dst|$dst, $src}", []>, PD;
-}
-
-let Predicates = [HasAMXTRANSPOSE, In64BitMode], SchedRW = [WriteSystem] in
- defm T2RPNTLVW : T2RPNTLVW_Base<0x6e, 0x6f, "", "">, T8, VEX;
-
-let Predicates = [HasAMXTRANSPOSE, HasEGPR, In64BitMode], SchedRW = [WriteSystem] in
- defm T2RPNTLVW : T2RPNTLVW_Base<0x6e, 0x6f, "", "_EVEX">, T8, EVEX, NoCD8;
-
-let Predicates = [HasAMXMOVRS, HasAMXTRANSPOSE, In64BitMode], SchedRW = [WriteSystem] in
- defm T2RPNTLVW : T2RPNTLVW_Base<0xf8, 0xf9, "RS", "">, T_MAP5, VEX;
-
-let Predicates = [HasAMXMOVRS, HasAMXTRANSPOSE, HasEGPR, In64BitMode], SchedRW = [WriteSystem] in
- defm T2RPNTLVW : T2RPNTLVW_Base<0xf8, 0xf9, "RS", "_EVEX">, T_MAP5, EVEX, NoCD8;
-
-let Predicates = [HasAMXTRANSPOSE, In64BitMode] in {
- let SchedRW = [WriteSystem] in {
- def TTRANSPOSED : I<0x5f, MRMSrcReg, (outs TILE:$dst), (ins TILE:$src),
- "ttransposed\t{$src, $dst|$dst, $src}", []>, VEX, T8, XS;
- let isPseudo = true in {
- def PT2RPNTLVWZ0V : PseudoI<(outs TILEPair:$dst),
- (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4),
- []>;
- def PT2RPNTLVWZ0T1V : PseudoI<(outs TILEPair:$dst),
- (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4),
- []>;
- def PT2RPNTLVWZ1V : PseudoI<(outs TILEPair:$dst),
- (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4),
- []>;
- def PT2RPNTLVWZ1T1V : PseudoI<(outs TILEPair:$dst),
- (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4),
- []>;
- }
-
- def PTTRANSPOSEDV : PseudoI<(outs TILE:$dst),
- (ins GR16:$src1, GR16:$src2, TILE:$src),
- [(set TILE: $dst,
- (int_x86_ttransposed_internal GR16:$src1, GR16:$src2,
- TILE:$src))]>;
-
- let usesCustomInserter = 1 in {
- def PT2RPNTLVWZ0 : PseudoI<(outs), (ins u8imm:$dst,
- sibmem:$src1), []>;
- def PT2RPNTLVWZ0T1 : PseudoI<(outs), (ins u8imm:$dst,
- sibmem:$src1), []>;
- def PT2RPNTLVWZ1 : PseudoI<(outs), (ins u8imm:$dst,
- sibmem:$src1), []>;
- def PT2RPNTLVWZ1T1 : PseudoI<(outs), (ins u8imm:$dst,
- sibmem:$src1), []>;
- def PTTRANSPOSED : PseudoI<(outs), (ins u8imm:$dst, u8imm:$src),
- [(int_x86_ttransposed timm:$dst, timm:$src)]>;
- }
- }
-} // HasAMXTILE, HasAMXTRANSPOSE
-
-let Predicates = [HasAMXBF16, HasAMXTRANSPOSE, In64BitMode], SchedRW = [WriteSystem] in {
- let Constraints = "$src1 = $dst" in
- def TTDPBF16PS : I<0x6c, MRMSrcReg4VOp3, (outs TILE:$dst),
- (ins TILE:$src1, TILE:$src2, TILE:$src3),
- "ttdpbf16ps\t{$src3, $src2, $dst|$dst, $src2, $src3}",
- []>, VEX, VVVV, T8,XS;
- let Constraints = "$src4 = $dst" in
- def PTTDPBF16PSV : PseudoI<(outs TILE:$dst), (ins GR16:$src1,
- GR16:$src2, GR16:$src3, TILE:$src4,
- TILE:$src5, TILE:$src6),
- [(set TILE: $dst,
- (int_x86_ttdpbf16ps_internal GR16:$src1, GR16:$src2,
- GR16:$src3, TILE:$src4, TILE:$src5, TILE:$src6))]>;
- let usesCustomInserter = 1 in
- def PTTDPBF16PS : PseudoI<(outs), (ins u8imm:$src1, u8imm:$src2, u8imm:$src3),
- [(int_x86_ttdpbf16ps timm:$src1, timm:$src2, timm:$src3)]>;
-}
-
-let Predicates = [HasAMXFP16, HasAMXTRANSPOSE, In64BitMode], SchedRW = [WriteSystem] in {
- let Constraints = "$src1 = $dst" in
- def TTDPFP16PS : I<0x6c, MRMSrcReg4VOp3, (outs TILE:$dst),
- (ins TILE:$src1, TILE:$src2, TILE:$src3),
- "ttdpfp16ps\t{$src3, $src2, $dst|$dst, $src2, $src3}",
- []>, VEX, VVVV, T8,XD;
- let Constraints = "$src4 = $dst" in
- def PTTDPFP16PSV : PseudoI<(outs TILE:$dst), (ins GR16:$src1,
- GR16:$src2, GR16:$src3, TILE:$src4,
- TILE:$src5, TILE:$src6),
- [(set TILE: $dst,
- (int_x86_ttdpfp16ps_internal GR16:$src1, GR16:$src2,
- GR16:$src3, TILE:$src4, TILE:$src5, TILE:$src6))]>;
- let usesCustomInserter = 1 in
- def PTTDPFP16PS : PseudoI<(outs), (ins u8imm:$src1, u8imm:$src2, u8imm:$src3),
- [(int_x86_ttdpfp16ps timm:$src1, timm:$src2, timm:$src3)]>;
-}
-
-let Predicates = [HasAMXCOMPLEX, HasAMXTRANSPOSE, In64BitMode], SchedRW = [WriteSystem] in {
- let Constraints = "$src1 = $dst" in {
- def TTCMMIMFP16PS : I<0x6b, MRMSrcReg4VOp3, (outs TILE:$dst),
- (ins TILE:$src1, TILE:$src2, TILE:$src3),
- "ttcmmimfp16ps\t{$src3, $src2, $src1|$src1, $src2, $src3}",
- []>, VEX, VVVV, T8,XD;
- def TTCMMRLFP16PS: I<0x6b, MRMSrcReg4VOp3, (outs TILE:$dst),
- (ins TILE:$src1, TILE:$src2, TILE:$src3),
- "ttcmmrlfp16ps\t{$src3, $src2, $src1|$src1, $src2, $src3}",
- []>, VEX, VVVV, T8,XS;
- def TCONJTCMMIMFP16PS : I<0x6b, MRMSrcReg4VOp3, (outs TILE:$dst),
- (ins TILE:$src1, TILE:$src2, TILE:$src3),
- "tconjtcmmimfp16ps\t{$src3, $src2, $src1|$src1, $src2, $src3}",
- []>, VEX, VVVV, WIG, T8,PS;
- }
- def TCONJTFP16 : I<0x6b, MRMSrcReg, (outs TILE:$dst), (ins TILE:$src),
- "tconjtfp16\t{$src, $dst|$dst, $src}", []>, VEX, T8,PD;
-
- let Constraints = "$src4 = $dst" in {
- def PTTCMMIMFP16PSV : PseudoI<(outs TILE:$dst), (ins GR16:$src1,
- GR16:$src2, GR16:$src3, TILE:$src4,
- TILE:$src5, TILE:$src6),
- [(set TILE: $dst,
- (int_x86_ttcmmimfp16ps_internal GR16:$src1, GR16:$src2,
- GR16:$src3, TILE:$src4, TILE:$src5, TILE:$src6))]>;
- def PTTCMMRLFP16PSV : PseudoI<(outs TILE:$dst), (ins GR16:$src1,
- GR16:$src2, GR16:$src3, TILE:$src4,
- TILE:$src5, TILE:$src6),
- [(set TILE: $dst,
- (int_x86_ttcmmrlfp16ps_internal GR16:$src1, GR16:$src2,
- GR16:$src3, TILE:$src4, TILE:$src5, TILE:$src6))]>;
- def PTCONJTCMMIMFP16PSV : PseudoI<(outs TILE:$dst), (ins GR16:$src1,
- GR16:$src2, GR16:$src3, TILE:$src4,
- TILE:$src5, TILE:$src6),
- [(set TILE: $dst,
- (int_x86_tconjtcmmimfp16ps_internal GR16:$src1, GR16:$src2,
- GR16:$src3, TILE:$src4, TILE:$src5, TILE:$src6))]>;
- }
- def PTCONJTFP16V : PseudoI<(outs TILE:$dst), (ins GR16:$src1, GR16:$src2, TILE:$src3),
- [(set TILE: $dst, (int_x86_tconjtfp16_internal GR16:$src1, GR16:$src2, TILE:$src3))]>;
-
- let usesCustomInserter = 1 in {
- def PTTCMMIMFP16PS : PseudoI<(outs), (ins u8imm:$src1, u8imm:$src2, u8imm:$src3),
- [(int_x86_ttcmmimfp16ps timm:$src1, timm:$src2, timm:$src3)]>;
- def PTTCMMRLFP16PS : PseudoI<(outs), (ins u8imm:$src1, u8imm:$src2, u8imm:$src3),
- [(int_x86_ttcmmrlfp16ps timm:$src1, timm:$src2, timm:$src3)]>;
- def PTCONJTCMMIMFP16PS : PseudoI<(outs), (ins u8imm:$src1, u8imm:$src2, u8imm:$src3),
- [(int_x86_tconjtcmmimfp16ps timm:$src1, timm:$src2, timm:$src3)]>;
- def PTCONJTFP16 : PseudoI<(outs), (ins u8imm:$dst, u8imm:$src),
- [(int_x86_tconjtfp16 timm:$dst, timm:$src)]>;
- }
-}
-
-let Predicates = [HasAMXMOVRS, HasAMXTRANSPOSE, In64BitMode], SchedRW = [WriteSystem] in {
- let isPseudo = true in {
- def PT2RPNTLVWZ0RSV : PseudoI<(outs TILEPair:$dst),
- (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4),
- []>;
- def PT2RPNTLVWZ0RST1V : PseudoI<(outs TILEPair:$dst),
- (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4),
- []>;
- def PT2RPNTLVWZ1RSV : PseudoI<(outs TILEPair:$dst),
- (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4),
- []>;
- def PT2RPNTLVWZ1RST1V : PseudoI<(outs TILEPair:$dst),
- (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4),
- []>;
- }
- let usesCustomInserter = 1 in {
- def PT2RPNTLVWZ0RS : PseudoI<(outs), (ins u8imm:$dst, sibmem:$src1), []>;
- def PT2RPNTLVWZ0RST1 : PseudoI<(outs), (ins u8imm:$dst, sibmem:$src1), []>;
- def PT2RPNTLVWZ1RS : PseudoI<(outs), (ins u8imm:$dst, sibmem:$src1), []>;
- def PT2RPNTLVWZ1RST1 : PseudoI<(outs), (ins u8imm:$dst, sibmem:$src1), []>;
- }
-} // HasAMXMOVRS, HasAMXTRANSPOSE
-
multiclass TILELOADDRS_Base<string suffix> {
def suffix : I<0x4a, MRMSrcMemFSIB, (outs TILE:$dst), (ins sibmem:$src1),
"tileloaddrs\t{$src1, $dst|$dst, $src1}", []>, T8, XD;
@@ -721,29 +539,3 @@ let Predicates = [HasAMXTF32, In64BitMode] in {
}
} // SchedRW = [WriteSystem]
} // HasAMXTF32
-
-let Predicates = [HasAMXTF32, HasAMXTRANSPOSE, In64BitMode] in {
- let SchedRW = [WriteSystem] in {
- let Constraints = "$src1 = $dst" in {
- def TTMMULTF32PS: I<0x48, MRMSrcReg4VOp3, (outs TILE:$dst),
- (ins TILE:$src1, TILE:$src2, TILE:$src3),
- "ttmmultf32ps\t{$src3, $src2, $dst|$dst, $src2, $src3}",
- []>, VEX, VVVV, T8, PS;
- }
- let Constraints = "$src4 = $dst" in {
- def PTTMMULTF32PSV : PseudoI<(outs TILE:$dst),
- (ins GR16:$src1, GR16:$src2, GR16:$src3,
- TILE:$src4, TILE:$src5, TILE:$src6),
- [(set TILE:$dst,
- (int_x86_ttmmultf32ps_internal GR16:$src1,
- GR16:$src2, GR16:$src3, TILE:$src4,
- TILE:$src5, TILE:$src6))]>;
- }
- let usesCustomInserter = 1 in {
- def PTTMMULTF32PS : PseudoI<(outs),
- (ins u8imm:$src1, u8imm:$src2, u8imm:$src3),
- [(int_x86_ttmmultf32ps timm:$src1, timm:$src2,
- timm:$src3)]>;
- }
- } // SchedRW = [WriteSystem]
-} // HasAMXTF32, HasAMXTRANSPOSE
diff --git a/llvm/lib/Target/X86/X86InstrInfo.cpp b/llvm/lib/Target/X86/X86InstrInfo.cpp
index 5c23f91..6b2a7a4 100644
--- a/llvm/lib/Target/X86/X86InstrInfo.cpp
+++ b/llvm/lib/Target/X86/X86InstrInfo.cpp
@@ -4544,11 +4544,6 @@ static unsigned getLoadStoreRegOpcode(Register Reg,
return Load ? GET_EGPR_IF_ENABLED(X86::TILELOADD)
: GET_EGPR_IF_ENABLED(X86::TILESTORED);
#undef GET_EGPR_IF_ENABLED
- case 2048:
- assert(X86::TILEPAIRRegClass.hasSubClassEq(RC) &&
- "Unknown 2048-byte regclass");
- assert(STI.hasAMXTILE() && "Using 2048-bit register requires AMX-TILE");
- return Load ? X86::PTILEPAIRLOAD : X86::PTILEPAIRSTORE;
}
}
@@ -4743,8 +4738,6 @@ static bool isAMXOpcode(unsigned Opc) {
case X86::TILESTORED:
case X86::TILELOADD_EVEX:
case X86::TILESTORED_EVEX:
- case X86::PTILEPAIRLOAD:
- case X86::PTILEPAIRSTORE:
return true;
}
}
@@ -4757,8 +4750,7 @@ void X86InstrInfo::loadStoreTileReg(MachineBasicBlock &MBB,
default:
llvm_unreachable("Unexpected special opcode!");
case X86::TILESTORED:
- case X86::TILESTORED_EVEX:
- case X86::PTILEPAIRSTORE: {
+ case X86::TILESTORED_EVEX: {
// tilestored %tmm, (%sp, %idx)
MachineRegisterInfo &RegInfo = MBB.getParent()->getRegInfo();
Register VirtReg = RegInfo.createVirtualRegister(&X86::GR64_NOSPRegClass);
@@ -4772,8 +4764,7 @@ void X86InstrInfo::loadStoreTileReg(MachineBasicBlock &MBB,
break;
}
case X86::TILELOADD:
- case X86::TILELOADD_EVEX:
- case X86::PTILEPAIRLOAD: {
+ case X86::TILELOADD_EVEX: {
// tileloadd (%sp, %idx), %tmm
MachineRegisterInfo &RegInfo = MBB.getParent()->getRegInfo();
Register VirtReg = RegInfo.createVirtualRegister(&X86::GR64_NOSPRegClass);
diff --git a/llvm/lib/Target/X86/X86InstrOperands.td b/llvm/lib/Target/X86/X86InstrOperands.td
index 5207eca..6ba07f7 100644
--- a/llvm/lib/Target/X86/X86InstrOperands.td
+++ b/llvm/lib/Target/X86/X86InstrOperands.td
@@ -536,10 +536,3 @@ def VK8Pair : RegisterOperand<VK8PAIR, "printVKPair"> {
def VK16Pair : RegisterOperand<VK16PAIR, "printVKPair"> {
let ParserMatchClass = VK16PairAsmOperand;
}
-
-let RenderMethod = "addTILEPairOperands" in
- def TILEPairAsmOperand : AsmOperandClass { let Name = "TILEPair"; }
-
-def TILEPair : RegisterOperand<TILEPAIR, "printTILEPair"> {
- let ParserMatchClass = TILEPairAsmOperand;
-}
diff --git a/llvm/lib/Target/X86/X86InstrPredicates.td b/llvm/lib/Target/X86/X86InstrPredicates.td
index c20bb05..98104a6f 100644
--- a/llvm/lib/Target/X86/X86InstrPredicates.td
+++ b/llvm/lib/Target/X86/X86InstrPredicates.td
@@ -183,7 +183,6 @@ def HasAMXINT8 : Predicate<"Subtarget->hasAMXINT8()">;
def HasAMXCOMPLEX : Predicate<"Subtarget->hasAMXCOMPLEX()">;
def HasAMXFP8 : Predicate<"Subtarget->hasAMXFP8()">;
def HasAMXMOVRS : Predicate<"Subtarget->hasAMXMOVRS()">;
-def HasAMXTRANSPOSE : Predicate<"Subtarget->hasAMXTRANSPOSE()">;
def HasAMXAVX512 : Predicate<"Subtarget->hasAMXAVX512()">;
def HasAMXTF32 : Predicate<"Subtarget->hasAMXTF32()">;
def HasUINTR : Predicate<"Subtarget->hasUINTR()">;
diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp
index 8ffd454..2fc5d38 100644
--- a/llvm/lib/Target/X86/X86LowerAMXType.cpp
+++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp
@@ -74,22 +74,6 @@ static bool isAMXCast(Instruction *II) {
match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
}
-// Some instructions may return more than one tiles.
-// e.g: call { x86_amx, x86_amx } @llvm.x86.t2rpntlvwz0.internal
-static unsigned getNumDefTiles(IntrinsicInst *II) {
- Type *Ty = II->getType();
- if (Ty->isX86_AMXTy())
- return 1;
-
- unsigned Num = 0;
- for (unsigned i = 0; i < Ty->getNumContainedTypes(); i++) {
- Type *STy = Ty->getContainedType(i);
- if (STy->isX86_AMXTy())
- Num++;
- }
- return Num;
-}
-
static bool isAMXIntrinsic(Value *I) {
auto *II = dyn_cast<IntrinsicInst>(I);
if (!II)
@@ -98,7 +82,7 @@ static bool isAMXIntrinsic(Value *I) {
return false;
// Check if return type or parameter is x86_amx. If it is x86_amx
// the intrinsic must be x86 amx intrinsics.
- if (getNumDefTiles(II) > 0)
+ if (II->getType()->isX86_AMXTy())
return true;
for (Value *V : II->args()) {
if (V->getType()->isX86_AMXTy())
@@ -137,27 +121,7 @@ static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
llvm_unreachable("No terminator in the entry block!");
}
-class ShapeCalculator {
-private:
- const TargetMachine *TM = nullptr;
-
- // In AMX intrinsics we let Shape = {Row, Col}, but the
- // RealCol = Col / ElementSize. We may use the RealCol
- // as a new Row for other new created AMX intrinsics.
- std::map<Value *, Value *> Col2Row, Row2Col;
-
-public:
- ShapeCalculator(const TargetMachine *TargetM) : TM(TargetM) {}
- std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo);
- std::pair<Value *, Value *> getShape(PHINode *Phi);
- Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity);
- Value *getColFromRow(Instruction *II, Value *V, unsigned Granularity);
-};
-
-Value *ShapeCalculator::getRowFromCol(Instruction *II, Value *V,
- unsigned Granularity) {
- if (auto It = Col2Row.find(V); It != Col2Row.end())
- return It->second;
+static Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity) {
IRBuilder<> Builder(II);
Value *RealRow = nullptr;
if (isa<ConstantInt>(V))
@@ -186,47 +150,16 @@ Value *ShapeCalculator::getRowFromCol(Instruction *II, Value *V,
getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
RealRow = NewBuilder.CreateUDiv(V, NewBuilder.getInt16(Granularity));
}
- Col2Row[V] = RealRow;
return RealRow;
}
-Value *ShapeCalculator::getColFromRow(Instruction *II, Value *V,
- unsigned Granularity) {
- if (auto It = Row2Col.find(V); It != Row2Col.end())
- return It->second;
- IRBuilder<> Builder(II);
- Value *RealCol = nullptr;
- if (isa<ConstantInt>(V))
- RealCol =
- Builder.getInt16((cast<ConstantInt>(V)->getSExtValue()) * Granularity);
- else if (isa<Instruction>(V)) {
- Builder.SetInsertPoint(cast<Instruction>(V));
- RealCol = Builder.CreateNUWMul(V, Builder.getInt16(Granularity));
- cast<Instruction>(RealCol)->moveAfter(cast<Instruction>(V));
- } else {
- // When it is not a const value and it is a function argument, we create
- // Row at the entry bb.
- IRBuilder<> NewBuilder(
- getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
- RealCol = NewBuilder.CreateNUWMul(V, NewBuilder.getInt16(Granularity));
- }
- Row2Col[V] = RealCol;
- return RealCol;
-}
-
// TODO: Refine the row and col-in-bytes of tile to row and col of matrix.
-std::pair<Value *, Value *> ShapeCalculator::getShape(IntrinsicInst *II,
- unsigned OpNo) {
- (void)TM;
+std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
IRBuilder<> Builder(II);
Value *Row = nullptr, *Col = nullptr;
switch (II->getIntrinsicID()) {
default:
llvm_unreachable("Expect amx intrinsics");
- case Intrinsic::x86_t2rpntlvwz0_internal:
- case Intrinsic::x86_t2rpntlvwz0t1_internal:
- case Intrinsic::x86_t2rpntlvwz1_internal:
- case Intrinsic::x86_t2rpntlvwz1t1_internal:
case Intrinsic::x86_tileloadd64_internal:
case Intrinsic::x86_tileloaddt164_internal:
case Intrinsic::x86_tilestored64_internal:
@@ -271,13 +204,6 @@ std::pair<Value *, Value *> ShapeCalculator::getShape(IntrinsicInst *II,
}
break;
}
- case Intrinsic::x86_ttransposed_internal:
- case Intrinsic::x86_tconjtfp16_internal: {
- assert((OpNo == 2) && "Illegal Operand Number.");
- Row = getRowFromCol(II, II->getArgOperand(1), 4);
- Col = getColFromRow(II, II->getArgOperand(0), 4);
- break;
- }
case Intrinsic::x86_tcvtrowd2ps_internal:
case Intrinsic::x86_tcvtrowps2bf16h_internal:
case Intrinsic::x86_tcvtrowps2bf16l_internal:
@@ -289,34 +215,12 @@ std::pair<Value *, Value *> ShapeCalculator::getShape(IntrinsicInst *II,
Col = II->getArgOperand(1);
break;
}
- case Intrinsic::x86_ttdpbf16ps_internal:
- case Intrinsic::x86_ttdpfp16ps_internal:
- case Intrinsic::x86_ttcmmimfp16ps_internal:
- case Intrinsic::x86_ttcmmrlfp16ps_internal:
- case Intrinsic::x86_tconjtcmmimfp16ps_internal:
- case Intrinsic::x86_ttmmultf32ps_internal: {
- switch (OpNo) {
- case 3:
- Row = II->getArgOperand(0);
- Col = II->getArgOperand(1);
- break;
- case 4:
- Row = getRowFromCol(II, II->getArgOperand(2), 4);
- Col = getColFromRow(II, II->getArgOperand(0), 4);
- break;
- case 5:
- Row = getRowFromCol(II, II->getArgOperand(2), 4);
- Col = II->getArgOperand(1);
- break;
- }
- break;
- }
}
return std::make_pair(Row, Col);
}
-std::pair<Value *, Value *> ShapeCalculator::getShape(PHINode *Phi) {
+static std::pair<Value *, Value *> getShape(PHINode *Phi) {
Use &U = *(Phi->use_begin());
unsigned OpNo = U.getOperandNo();
User *V = U.getUser();
@@ -349,15 +253,14 @@ std::pair<Value *, Value *> ShapeCalculator::getShape(PHINode *Phi) {
namespace {
class X86LowerAMXType {
Function &Func;
- ShapeCalculator *SC;
// In AMX intrinsics we let Shape = {Row, Col}, but the
// RealCol = Col / ElementSize. We may use the RealCol
// as a new Row for other new created AMX intrinsics.
- std::map<Value *, Value *> Col2Row, Row2Col;
+ std::map<Value *, Value *> Col2Row;
public:
- X86LowerAMXType(Function &F, ShapeCalculator *ShapeC) : Func(F), SC(ShapeC) {}
+ X86LowerAMXType(Function &F) : Func(F) {}
bool visit();
void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
@@ -374,7 +277,7 @@ void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
Use &U = *(Bitcast->use_begin());
unsigned OpNo = U.getOperandNo();
auto *II = cast<IntrinsicInst>(U.getUser());
- std::tie(Row, Col) = SC->getShape(II, OpNo);
+ std::tie(Row, Col) = getShape(II, OpNo);
IRBuilder<> Builder(Bitcast);
// Use the maximun column as stride.
Value *Stride = Builder.getInt64(64);
@@ -454,7 +357,7 @@ bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
Builder.CreateStore(Src, AllocaAddr);
// TODO we can pick an constant operand for the shape.
Value *Row = nullptr, *Col = nullptr;
- std::tie(Row, Col) = SC->getShape(II, OpNo);
+ std::tie(Row, Col) = getShape(II, OpNo);
std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
Value *NewInst =
Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, Args);
@@ -594,18 +497,11 @@ static Value *getAllocaPos(BasicBlock *BB) {
static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
- auto *II = dyn_cast<IntrinsicInst>(TileDef);
- unsigned Idx = 0;
- // Extract tile from multiple tiles' def.
- if (auto *Extr = dyn_cast<ExtractValueInst>(TileDef)) {
- assert(Extr->hasIndices() && "Tile extract miss index!");
- Idx = Extr->getIndices()[0];
- II = cast<IntrinsicInst>(Extr->getOperand(0));
- }
+ auto *II = cast<IntrinsicInst>(TileDef);
assert(II && "Not tile intrinsic!");
- Value *Row = II->getOperand(Idx);
- Value *Col = II->getOperand(Idx + 1);
+ Value *Row = II->getOperand(0);
+ Value *Col = II->getOperand(1);
BasicBlock *BB = TileDef->getParent();
BasicBlock::iterator Iter = TileDef->getIterator();
@@ -624,20 +520,14 @@ static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
// Get tile shape.
IntrinsicInst *II = nullptr;
- unsigned Idx = 0;
if (IsPHI) {
Value *PhiOp = cast<PHINode>(V)->getIncomingValue(0);
II = cast<IntrinsicInst>(PhiOp);
- } else if (auto *Extr = dyn_cast<ExtractValueInst>(V)) {
- // Extract tile from multiple tiles' def.
- assert(Extr->hasIndices() && "Tile extract miss index!");
- Idx = Extr->getIndices()[0];
- II = cast<IntrinsicInst>(Extr->getOperand(0));
} else {
II = cast<IntrinsicInst>(V);
}
- Value *Row = II->getOperand(Idx);
- Value *Col = II->getOperand(Idx + 1);
+ Value *Row = II->getOperand(0);
+ Value *Col = II->getOperand(1);
Instruction *UserI = cast<Instruction>(U.getUser());
IRBuilder<> Builder(UserI);
@@ -848,12 +738,10 @@ namespace {
class X86LowerAMXCast {
Function &Func;
- ShapeCalculator *SC;
std::unique_ptr<DominatorTree> DT;
public:
- X86LowerAMXCast(Function &F, ShapeCalculator *ShapeC)
- : Func(F), SC(ShapeC), DT(nullptr) {}
+ X86LowerAMXCast(Function &F) : Func(F), DT(nullptr) {}
bool combineCastStore(IntrinsicInst *Cast, StoreInst *ST);
bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
bool combineTilezero(IntrinsicInst *Cast);
@@ -932,7 +820,7 @@ bool X86LowerAMXCast::optimizeAMXCastFromPhi(
if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())
return false;
Value *Row = nullptr, *Col = nullptr;
- std::tie(Row, Col) = SC->getShape(OldPN);
+ std::tie(Row, Col) = getShape(OldPN);
// TODO: If it is not constant the Row and Col must domoniate tilezero
// that we are going to create.
if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))
@@ -1063,19 +951,6 @@ bool X86LowerAMXCast::optimizeAMXCastFromPhi(
return true;
}
-static Value *getShapeFromAMXIntrinsic(Value *Inst, unsigned ShapeIdx,
- bool IsRow) {
- if (!isAMXIntrinsic(Inst))
- return nullptr;
-
- auto *II = cast<IntrinsicInst>(Inst);
- if (IsRow)
- return II->getOperand(0);
-
- assert(ShapeIdx < 2 && "Currently 2 shapes in 1 instruction at most!");
- return II->getOperand(ShapeIdx + 1);
-}
-
// %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42)
// store <256 x i32> %43, <256 x i32>* %p, align 64
// -->
@@ -1090,38 +965,13 @@ bool X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
if (!Tile->hasOneUse())
return false;
- // We don't fetch shape from tilestore, we only get shape from tiledef,
- // so we can set the max tile shape to tilestore for special cases.
+ auto *II = cast<IntrinsicInst>(Tile);
+ // Tile is output from AMX intrinsic. The first operand of the
+ // intrinsic is row, the second operand of the intrinsic is column.
+ Value *Row = II->getOperand(0);
+ Value *Col = II->getOperand(1);
+
IRBuilder<> Builder(ST);
- Value *Row = nullptr;
- Value *Col = nullptr;
-
- if (isAMXIntrinsic(Tile)) {
- auto *II = cast<IntrinsicInst>(Tile);
- // Tile is output from AMX intrinsic. The first operand of the
- // intrinsic is row, the second operand of the intrinsic is column.
- Row = II->getOperand(0);
- Col = II->getOperand(1);
- } else {
- // Now we supported multi-tiles value in structure, so we may get tile
- // from extracting multi-tiles structure.
- // For example:
- // %6 = call { x86_amx, x86_amx } @llvm.x86.t2rpntlvwz0.internal(i16 %1,
- // i16 %2, i16 %3, i8* %4, i64 %5)
- // %7 = extractvalue { x86_amx, x86_amx } %6, 0
- // %8 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %7)
- // store <256 x i32> %8, <256 x i32>* %0, align 1024
- //
- // TODO: Currently we only handle extractvalue case, enhance me for other
- // cases if possible.
- auto *II = cast<ExtractValueInst>(Tile);
- assert(II && "We meet unhandle source in fetching tile value!");
- unsigned ShapeIdx = II->getIndices()[0];
- Value *Tiles = II->getOperand(0);
- Row = getShapeFromAMXIntrinsic(Tiles, ShapeIdx, true);
- Col = getShapeFromAMXIntrinsic(Tiles, ShapeIdx, false);
- }
- assert(Row && Col && "Shape got failed!");
// Stride should be equal to col(measured by bytes)
Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
@@ -1146,7 +996,7 @@ bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
// shape information through def-use chain.
if (!isAMXIntrinsic(II))
return false;
- std::tie(Row, Col) = SC->getShape(II, OpNo);
+ std::tie(Row, Col) = getShape(II, OpNo);
IRBuilder<> Builder(LD);
// Stride should be equal to col(measured by bytes)
Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
@@ -1189,7 +1039,7 @@ bool X86LowerAMXCast::combineTilezero(IntrinsicInst *Cast) {
if (!isAMXIntrinsic(II))
return false;
- std::tie(Row, Col) = SC->getShape(II, OpNo);
+ std::tie(Row, Col) = getShape(II, OpNo);
IRBuilder<> Builder(Cast);
Value *NewInst =
@@ -1384,7 +1234,7 @@ bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
Builder.CreateStore(Src, AllocaAddr);
// TODO we can pick an constant operand for the shape.
Value *Row = nullptr, *Col = nullptr;
- std::tie(Row, Col) = SC->getShape(II, OpNo);
+ std::tie(Row, Col) = getShape(II, OpNo);
std::array<Value *, 4> Args = {
Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
Value *NewInst =
@@ -1445,14 +1295,13 @@ bool lowerAmxType(Function &F, const TargetMachine *TM,
return false;
bool C = false;
- ShapeCalculator SC(TM);
- X86LowerAMXCast LAC(F, &SC);
+ X86LowerAMXCast LAC(F);
C |= LAC.combineAMXcast(TLI);
// There might be remaining AMXcast after combineAMXcast and they should be
// handled elegantly.
C |= LAC.transformAllAMXCast();
- X86LowerAMXType LAT(F, &SC);
+ X86LowerAMXType LAT(F);
C |= LAT.visit();
// Prepare for fast register allocation at O0.
diff --git a/llvm/lib/Target/X86/X86PreTileConfig.cpp b/llvm/lib/Target/X86/X86PreTileConfig.cpp
index 2a1c499..8a1d00d 100644
--- a/llvm/lib/Target/X86/X86PreTileConfig.cpp
+++ b/llvm/lib/Target/X86/X86PreTileConfig.cpp
@@ -141,15 +141,10 @@ class X86PreTileConfig : public MachineFunctionPass {
if (!MO.isReg() || !MO.getReg().isVirtual())
return false;
- unsigned Shapes = 0;
- if (MRI->getRegClass(MO.getReg())->getID() == X86::TILERegClassID)
- Shapes = 1;
- if (MRI->getRegClass(MO.getReg())->getID() == X86::TILEPAIRRegClassID)
- Shapes = 2;
- if (!Shapes)
+ if (MRI->getRegClass(MO.getReg())->getID() != X86::TILERegClassID)
return false;
- collectShapeInfo(MI, Shapes);
+ collectShapeInfo(MI);
return true;
}
@@ -165,7 +160,7 @@ class X86PreTileConfig : public MachineFunctionPass {
}
/// Collect the shape def information for later use.
- void collectShapeInfo(MachineInstr &MI, unsigned Shapes);
+ void collectShapeInfo(MachineInstr &MI);
/// Try to hoist shapes definded below AMX instructions.
bool hoistShapesInBB(MachineBasicBlock *MBB, SmallVectorImpl<MIRef> &Shapes) {
@@ -231,7 +226,7 @@ INITIALIZE_PASS_DEPENDENCY(MachineLoopInfoWrapperPass)
INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig",
"Tile Register Pre-configure", false, false)
-void X86PreTileConfig::collectShapeInfo(MachineInstr &MI, unsigned Shapes) {
+void X86PreTileConfig::collectShapeInfo(MachineInstr &MI) {
auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) {
MIRef MIR(MI, MBB);
auto &Refs = ShapeBBs[MBB];
@@ -240,10 +235,8 @@ void X86PreTileConfig::collectShapeInfo(MachineInstr &MI, unsigned Shapes) {
Refs.insert(I, MIR);
};
- // All shapes have same row in multi-tile operand.
- SmallVector<Register, 8> WorkList;
- for (unsigned I = 1; I < Shapes + 2; ++I)
- WorkList.push_back(MI.getOperand(I).getReg());
+ SmallVector<Register, 8> WorkList(
+ {MI.getOperand(1).getReg(), MI.getOperand(2).getReg()});
while (!WorkList.empty()) {
Register R = WorkList.pop_back_val();
MachineInstr *DefMI = MRI->getVRegDef(R);
@@ -252,13 +245,6 @@ void X86PreTileConfig::collectShapeInfo(MachineInstr &MI, unsigned Shapes) {
if (DefMI->isMoveImmediate() || !DefVisited.insert(DefMI).second)
continue;
- // This happens when column = 0 in multi-tile operand.
- if (DefMI->getOpcode() == X86::COPY) {
- MachineInstr *MI = MRI->getVRegDef(DefMI->getOperand(1).getReg());
- if (MI && MI->isMoveImmediate())
- continue;
- }
-
if (DefMI->isPHI()) {
for (unsigned I = 1; I < DefMI->getNumOperands(); I += 2)
if (isLoopBackEdge(DefMBB, DefMI->getOperand(I + 1).getMBB()))
diff --git a/llvm/lib/Target/X86/X86RegisterInfo.cpp b/llvm/lib/Target/X86/X86RegisterInfo.cpp
index 76979e3..72f3813 100644
--- a/llvm/lib/Target/X86/X86RegisterInfo.cpp
+++ b/llvm/lib/Target/X86/X86RegisterInfo.cpp
@@ -597,10 +597,6 @@ BitVector X86RegisterInfo::getReservedRegs(const MachineFunction &MF) const {
Reserved.set(*AI);
}
- // Reserve low half pair registers in case they are used by RA aggressively.
- Reserved.set(X86::TMM0_TMM1);
- Reserved.set(X86::TMM2_TMM3);
-
assert(checkAllSuperRegsMarked(Reserved,
{X86::SIL, X86::DIL, X86::BPL, X86::SPL,
X86::SIH, X86::DIH, X86::BPH, X86::SPH}));
@@ -621,7 +617,7 @@ unsigned X86RegisterInfo::getNumSupportedRegs(const MachineFunction &MF) const {
// and try to return the minimum number of registers supported by the target.
static_assert((X86::R15WH + 1 == X86::YMM0) && (X86::YMM15 + 1 == X86::K0) &&
(X86::K6_K7 + 1 == X86::TMMCFG) &&
- (X86::TMM6_TMM7 + 1 == X86::R16) &&
+ (X86::TMM7 + 1 == X86::R16) &&
(X86::R31WH + 1 == X86::NUM_TARGET_REGS),
"Register number may be incorrect");
@@ -694,8 +690,7 @@ bool X86RegisterInfo::isFixedRegister(const MachineFunction &MF,
}
bool X86RegisterInfo::isTileRegisterClass(const TargetRegisterClass *RC) const {
- return RC->getID() == X86::TILERegClassID ||
- RC->getID() == X86::TILEPAIRRegClassID;
+ return RC->getID() == X86::TILERegClassID;
}
void X86RegisterInfo::adjustStackMapLiveOutMask(uint32_t *Mask) const {
@@ -1062,17 +1057,9 @@ static ShapeT getTileShape(Register VirtReg, VirtRegMap *VRM,
case X86::PTDPFP16PSV:
case X86::PTCMMIMFP16PSV:
case X86::PTCMMRLFP16PSV:
- case X86::PTTRANSPOSEDV:
- case X86::PTTDPBF16PSV:
- case X86::PTTDPFP16PSV:
- case X86::PTTCMMIMFP16PSV:
- case X86::PTTCMMRLFP16PSV:
- case X86::PTCONJTCMMIMFP16PSV:
- case X86::PTCONJTFP16V:
case X86::PTILELOADDRSV:
case X86::PTILELOADDRST1V:
case X86::PTMMULTF32PSV:
- case X86::PTTMMULTF32PSV:
case X86::PTDPBF8PSV:
case X86::PTDPBHF8PSV:
case X86::PTDPHBF8PSV:
@@ -1083,56 +1070,7 @@ static ShapeT getTileShape(Register VirtReg, VirtRegMap *VRM,
VRM->assignVirt2Shape(VirtReg, Shape);
return Shape;
}
- case X86::PT2RPNTLVWZ0V:
- case X86::PT2RPNTLVWZ0T1V:
- case X86::PT2RPNTLVWZ1V:
- case X86::PT2RPNTLVWZ1T1V:
- case X86::PT2RPNTLVWZ0RSV:
- case X86::PT2RPNTLVWZ0RST1V:
- case X86::PT2RPNTLVWZ1RSV:
- case X86::PT2RPNTLVWZ1RST1V: {
- MachineOperand &MO1 = MI->getOperand(1);
- MachineOperand &MO2 = MI->getOperand(2);
- MachineOperand &MO3 = MI->getOperand(3);
- ShapeT Shape({&MO1, &MO2, &MO1, &MO3}, MRI);
- VRM->assignVirt2Shape(VirtReg, Shape);
- return Shape;
- }
- }
-}
-
-static bool canHintShape(ShapeT &PhysShape, ShapeT &VirtShape) {
- unsigned PhysShapeNum = PhysShape.getShapeNum();
- unsigned VirtShapeNum = VirtShape.getShapeNum();
-
- if (PhysShapeNum < VirtShapeNum)
- return false;
-
- if (PhysShapeNum == VirtShapeNum) {
- if (PhysShapeNum == 1)
- return PhysShape == VirtShape;
-
- for (unsigned I = 0; I < PhysShapeNum; I++) {
- ShapeT PShape(PhysShape.getRow(I), PhysShape.getCol(I));
- ShapeT VShape(VirtShape.getRow(I), VirtShape.getCol(I));
- if (VShape != PShape)
- return false;
- }
- return true;
- }
-
- // Hint subreg of mult-tile reg to single tile reg.
- if (VirtShapeNum == 1) {
- for (unsigned I = 0; I < PhysShapeNum; I++) {
- ShapeT PShape(PhysShape.getRow(I), PhysShape.getCol(I));
- if (VirtShape == PShape)
- return true;
- }
}
-
- // Note: Currently we have no requirement for case of
- // (VirtShapeNum > 1 and PhysShapeNum > VirtShapeNum)
- return false;
}
bool X86RegisterInfo::getRegAllocationHints(Register VirtReg,
@@ -1153,7 +1091,7 @@ bool X86RegisterInfo::getRegAllocationHints(Register VirtReg,
if (!VRM)
return BaseImplRetVal;
- if (ID != X86::TILERegClassID && ID != X86::TILEPAIRRegClassID) {
+ if (ID != X86::TILERegClassID) {
if (DisableRegAllocNDDHints || !ST.hasNDD() ||
!TRI.isGeneralPurposeRegisterClass(&RC))
return BaseImplRetVal;
@@ -1204,7 +1142,7 @@ bool X86RegisterInfo::getRegAllocationHints(Register VirtReg,
return;
}
ShapeT PhysShape = getTileShape(VReg, const_cast<VirtRegMap *>(VRM), MRI);
- if (canHintShape(PhysShape, VirtShape))
+ if (PhysShape == VirtShape)
Hints.push_back(PhysReg);
};
diff --git a/llvm/lib/Target/X86/X86RegisterInfo.td b/llvm/lib/Target/X86/X86RegisterInfo.td
index 99b7910..692e42a 100644
--- a/llvm/lib/Target/X86/X86RegisterInfo.td
+++ b/llvm/lib/Target/X86/X86RegisterInfo.td
@@ -30,8 +30,6 @@ let Namespace = "X86" in {
def sub_ymm : SubRegIndex<256>;
def sub_mask_0 : SubRegIndex<-1>;
def sub_mask_1 : SubRegIndex<-1, -1>;
- def sub_t0 : SubRegIndex<8192>;
- def sub_t1 : SubRegIndex<8192, 8192>;
}
//===----------------------------------------------------------------------===//
@@ -432,10 +430,6 @@ def TMM4: X86Reg<"tmm4", 4>;
def TMM5: X86Reg<"tmm5", 5>;
def TMM6: X86Reg<"tmm6", 6>;
def TMM7: X86Reg<"tmm7", 7>;
-// TMM register pairs
-def TPAIRS : RegisterTuples<[sub_t0, sub_t1],
- [(add TMM0, TMM2, TMM4, TMM6),
- (add TMM1, TMM3, TMM5, TMM7)]>;
}
// Floating point stack registers. These don't map one-to-one to the FP
@@ -862,9 +856,6 @@ def VK64WM : RegisterClass<"X86", [v64i1], 64, (add VK32WM)> {let Size = 64;}
let CopyCost = -1 in // Don't allow copying of tile registers
def TILE : RegisterClass<"X86", [x86amx], 8192,
(sequence "TMM%u", 0, 7)> {let Size = 8192;}
-// Need check alignment 3rd operand size=1024*2*8
-let isAllocatable = 1 in
-def TILEPAIR : RegisterClass<"X86", [untyped], 512, (add TPAIRS)> {let Size = 16384;}
//===----------------------------------------------------------------------===//
// Register categories.
diff --git a/llvm/lib/Target/X86/X86TileConfig.cpp b/llvm/lib/Target/X86/X86TileConfig.cpp
index 17a44dd..09ef8fb 100644
--- a/llvm/lib/Target/X86/X86TileConfig.cpp
+++ b/llvm/lib/Target/X86/X86TileConfig.cpp
@@ -74,63 +74,6 @@ INITIALIZE_PASS_DEPENDENCY(VirtRegMapWrapperLegacy)
INITIALIZE_PASS_END(X86TileConfig, DEBUG_TYPE, "Tile Register Configure", false,
false)
-unsigned getAMXRegNum(MachineRegisterInfo *MRI, Register Reg) {
- if (Reg.isVirtual()) {
- unsigned RegClassID = MRI->getRegClass(Reg)->getID();
- if (RegClassID == X86::TILERegClassID)
- return 1;
- if (RegClassID == X86::TILEPAIRRegClassID)
- return 2;
- } else {
- if (Reg >= X86::TMM0 && Reg <= X86::TMM7)
- return 1;
- if (Reg >= X86::TMM0_TMM1 && Reg <= X86::TMM6_TMM7)
- return 2;
- }
- return 0;
-}
-
-static void collectVirtRegShapes(MachineRegisterInfo *MRI, VirtRegMap &VRM,
- Register VirtReg,
- SmallVector<ShapeT, 8> &Phys2Shapes) {
- unsigned Num = getAMXRegNum(MRI, VirtReg);
- MCRegister PhysReg = VRM.getPhys(VirtReg);
- if (!PhysReg)
- return;
-
- if (Num == 1) {
- unsigned Index = PhysReg - X86::TMM0;
- if (!Phys2Shapes[Index].isValid()) {
- ShapeT Shape = VRM.getShape(VirtReg);
- Phys2Shapes[Index] = std::move(Shape);
- return;
- }
- }
- // Split tile pair shape info to 2 single tile shape info. e.g:
- // Put TMM0_TMM1's Shape to TMM0's shape + TMM1's Shape in Phys2Shapes.
- if (Num == 2) {
- unsigned Index0 = (PhysReg - X86::TMM0_TMM1) * 2;
- unsigned Index1 = (PhysReg - X86::TMM0_TMM1) * 2 + 1;
-
- ShapeT Shape = VRM.getShape(VirtReg);
- assert(Shape.getShapeNum() == 2 && "Unexpected shape number!");
-
- if (!Phys2Shapes[Index0].isValid()) {
- ShapeT Shape0(Shape.getRow(0), Shape.getCol(0), MRI);
- Phys2Shapes[Index0] = std::move(Shape0);
- }
-
- if (!Phys2Shapes[Index1].isValid()) {
- ShapeT Shape1(Shape.getRow(1), Shape.getCol(1), MRI);
- Phys2Shapes[Index1] = std::move(Shape1);
- }
- }
-}
-
-static bool isAMXRegClass(MachineRegisterInfo *MRI, Register Reg) {
- return getAMXRegNum(MRI, Reg) > 0;
-}
-
bool X86TileConfig::runOnMachineFunction(MachineFunction &MF) {
X86MachineFunctionInfo *X86FI = MF.getInfo<X86MachineFunctionInfo>();
// Early exit in the common case of non-AMX code.
@@ -138,7 +81,7 @@ bool X86TileConfig::runOnMachineFunction(MachineFunction &MF) {
return false;
const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
- const TargetRegisterInfo *TRI = ST.getRegisterInfo();
+ const X86RegisterInfo *TRI = ST.getRegisterInfo();
const TargetInstrInfo *TII = ST.getInstrInfo();
MachineRegisterInfo &MRI = MF.getRegInfo();
LiveIntervals &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS();
@@ -176,24 +119,29 @@ bool X86TileConfig::runOnMachineFunction(MachineFunction &MF) {
assert(ConstMI && "Cannot find an insertion point");
unsigned AMXRegNum = TRI->getRegClass(X86::TILERegClassID)->getNumRegs();
- SmallVector<ShapeT, 8> Phys2Shapes(AMXRegNum, ShapeT());
+ SmallVector<Register, 8> Phys2Virt(AMXRegNum, 0);
for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
Register VirtReg = Register::index2VirtReg(I);
if (MRI.reg_nodbg_empty(VirtReg))
continue;
- if (!isAMXRegClass(&MRI, VirtReg))
+ if (!TRI->isTileRegisterClass(MRI.getRegClass(VirtReg)))
+ continue;
+ MCRegister PhysReg = VRM.getPhys(VirtReg);
+ if (!PhysReg)
continue;
- collectVirtRegShapes(&MRI, VRM, VirtReg, Phys2Shapes);
+ unsigned Index = PhysReg - X86::TMM0;
+ if (!Phys2Virt[Index])
+ Phys2Virt[Index] = VirtReg;
}
// Fill in the shape of each tile physical register.
for (unsigned I = 0; I < AMXRegNum; ++I) {
- ShapeT Shape = Phys2Shapes[I];
- if (!Shape.isValid())
+ if (!Phys2Virt[I])
continue;
DebugLoc DL;
bool IsRow = true;
MachineInstr *NewMI = nullptr;
+ ShapeT Shape = VRM.getShape(Phys2Virt[I]);
for (auto &R : {Shape.getRow()->getReg(), Shape.getCol()->getReg()}) {
// Here is the data format for the tile config.
// 0 palette
@@ -222,14 +170,7 @@ bool X86TileConfig::runOnMachineFunction(MachineFunction &MF) {
"Cannot initialize with different shapes");
continue;
}
- if (DefMI.getOperand(1).isImm()) {
- Imm = DefMI.getOperand(1).getImm();
- } else {
- assert(DefMI.getOpcode() == X86::MOV32r0 &&
- "The opcode is assumed to be MOV32r0 if the operand is not "
- "immediate.");
- Imm = 0;
- }
+ Imm = DefMI.getOperand(1).getImm();
NewMI = addFrameReference(
BuildMI(MF.front(), ++ConstMI->getIterator(), DL,