diff options
Diffstat (limited to 'llvm/lib/Target')
73 files changed, 1281 insertions, 867 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64FMV.td b/llvm/lib/Target/AArch64/AArch64FMV.td index fc7a94a..e0f56fd 100644 --- a/llvm/lib/Target/AArch64/AArch64FMV.td +++ b/llvm/lib/Target/AArch64/AArch64FMV.td @@ -22,64 +22,65 @@ // Something you can add to target_version or target_clones. -class FMVExtension<string n, string b, int p> { +class FMVExtension<string name, string enumeration> { // Name, as spelled in target_version or target_clones. e.g. "memtag". - string Name = n; + string Name = name; // A C++ expression giving the number of the bit in the FMV ABI. // Currently this is given as a value from the enum "CPUFeatures". - string Bit = b; + string FeatureBit = "FEAT_" # enumeration; // SubtargetFeature enabled for codegen when this FMV feature is present. - string BackendFeature = n; + string BackendFeature = name; - // The FMV priority. - int Priority = p; + // A C++ expression giving the number of the priority bit. + // Currently this is given as a value from the enum "FeatPriorities". + string PriorityBit = "PRIOR_" # enumeration; } -def : FMVExtension<"aes", "FEAT_PMULL", 150>; -def : FMVExtension<"bf16", "FEAT_BF16", 280>; -def : FMVExtension<"bti", "FEAT_BTI", 510>; -def : FMVExtension<"crc", "FEAT_CRC", 110>; -def : FMVExtension<"dit", "FEAT_DIT", 180>; -def : FMVExtension<"dotprod", "FEAT_DOTPROD", 104>; -let BackendFeature = "ccpp" in def : FMVExtension<"dpb", "FEAT_DPB", 190>; -let BackendFeature = "ccdp" in def : FMVExtension<"dpb2", "FEAT_DPB2", 200>; -def : FMVExtension<"f32mm", "FEAT_SVE_F32MM", 350>; -def : FMVExtension<"f64mm", "FEAT_SVE_F64MM", 360>; -def : FMVExtension<"fcma", "FEAT_FCMA", 220>; -def : FMVExtension<"flagm", "FEAT_FLAGM", 20>; -let BackendFeature = "altnzcv" in def : FMVExtension<"flagm2", "FEAT_FLAGM2", 30>; -def : FMVExtension<"fp", "FEAT_FP", 90>; -def : FMVExtension<"fp16", "FEAT_FP16", 170>; -def : FMVExtension<"fp16fml", "FEAT_FP16FML", 175>; -let BackendFeature = "fptoint" in def : FMVExtension<"frintts", "FEAT_FRINTTS", 250>; -def : FMVExtension<"i8mm", "FEAT_I8MM", 270>; -def : FMVExtension<"jscvt", "FEAT_JSCVT", 210>; -def : FMVExtension<"ls64", "FEAT_LS64_ACCDATA", 520>; -def : FMVExtension<"lse", "FEAT_LSE", 80>; -def : FMVExtension<"memtag", "FEAT_MEMTAG2", 440>; -def : FMVExtension<"mops", "FEAT_MOPS", 650>; -def : FMVExtension<"predres", "FEAT_PREDRES", 480>; -def : FMVExtension<"rcpc", "FEAT_RCPC", 230>; -let BackendFeature = "rcpc-immo" in def : FMVExtension<"rcpc2", "FEAT_RCPC2", 240>; -def : FMVExtension<"rcpc3", "FEAT_RCPC3", 241>; -def : FMVExtension<"rdm", "FEAT_RDM", 108>; -def : FMVExtension<"rng", "FEAT_RNG", 10>; -def : FMVExtension<"sb", "FEAT_SB", 470>; -def : FMVExtension<"sha2", "FEAT_SHA2", 130>; -def : FMVExtension<"sha3", "FEAT_SHA3", 140>; -def : FMVExtension<"simd", "FEAT_SIMD", 100>; -def : FMVExtension<"sm4", "FEAT_SM4", 106>; -def : FMVExtension<"sme", "FEAT_SME", 430>; -def : FMVExtension<"sme-f64f64", "FEAT_SME_F64", 560>; -def : FMVExtension<"sme-i16i64", "FEAT_SME_I64", 570>; -def : FMVExtension<"sme2", "FEAT_SME2", 580>; -def : FMVExtension<"ssbs", "FEAT_SSBS2", 490>; -def : FMVExtension<"sve", "FEAT_SVE", 310>; -def : FMVExtension<"sve2", "FEAT_SVE2", 370>; -def : FMVExtension<"sve2-aes", "FEAT_SVE_PMULL128", 380>; -def : FMVExtension<"sve2-bitperm", "FEAT_SVE_BITPERM", 400>; -def : FMVExtension<"sve2-sha3", "FEAT_SVE_SHA3", 410>; -def : FMVExtension<"sve2-sm4", "FEAT_SVE_SM4", 420>; -def : FMVExtension<"wfxt", "FEAT_WFXT", 550>; +def : FMVExtension<"aes", "PMULL">; +def : FMVExtension<"bf16", "BF16">; +def : FMVExtension<"bti", "BTI">; +def : FMVExtension<"crc", "CRC">; +def : FMVExtension<"dit", "DIT">; +def : FMVExtension<"dotprod", "DOTPROD">; +let BackendFeature = "ccpp" in def : FMVExtension<"dpb", "DPB">; +let BackendFeature = "ccdp" in def : FMVExtension<"dpb2", "DPB2">; +def : FMVExtension<"f32mm", "SVE_F32MM">; +def : FMVExtension<"f64mm", "SVE_F64MM">; +def : FMVExtension<"fcma", "FCMA">; +def : FMVExtension<"flagm", "FLAGM">; +let BackendFeature = "altnzcv" in def : FMVExtension<"flagm2", "FLAGM2">; +def : FMVExtension<"fp", "FP">; +def : FMVExtension<"fp16", "FP16">; +def : FMVExtension<"fp16fml", "FP16FML">; +let BackendFeature = "fptoint" in def : FMVExtension<"frintts", "FRINTTS">; +def : FMVExtension<"i8mm", "I8MM">; +def : FMVExtension<"jscvt", "JSCVT">; +def : FMVExtension<"ls64", "LS64_ACCDATA">; +def : FMVExtension<"lse", "LSE">; +def : FMVExtension<"memtag", "MEMTAG2">; +def : FMVExtension<"mops", "MOPS">; +def : FMVExtension<"predres", "PREDRES">; +def : FMVExtension<"rcpc", "RCPC">; +let BackendFeature = "rcpc-immo" in def : FMVExtension<"rcpc2", "RCPC2">; +def : FMVExtension<"rcpc3", "RCPC3">; +def : FMVExtension<"rdm", "RDM">; +def : FMVExtension<"rng", "RNG">; +def : FMVExtension<"sb", "SB">; +def : FMVExtension<"sha2", "SHA2">; +def : FMVExtension<"sha3", "SHA3">; +def : FMVExtension<"simd", "SIMD">; +def : FMVExtension<"sm4", "SM4">; +def : FMVExtension<"sme", "SME">; +def : FMVExtension<"sme-f64f64", "SME_F64">; +def : FMVExtension<"sme-i16i64", "SME_I64">; +def : FMVExtension<"sme2", "SME2">; +def : FMVExtension<"ssbs", "SSBS2">; +def : FMVExtension<"sve", "SVE">; +def : FMVExtension<"sve2", "SVE2">; +def : FMVExtension<"sve2-aes", "SVE_PMULL128">; +def : FMVExtension<"sve2-bitperm", "SVE_BITPERM">; +def : FMVExtension<"sve2-sha3", "SVE_SHA3">; +def : FMVExtension<"sve2-sm4", "SVE_SM4">; +def : FMVExtension<"wfxt", "WFXT">; diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index ef00b09..3ad2905 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -753,6 +753,14 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(Op, MVT::v8bf16, Expand); } + // For bf16, fpextend is custom lowered to be optionally expanded into shifts. + setOperationAction(ISD::FP_EXTEND, MVT::f32, Custom); + setOperationAction(ISD::FP_EXTEND, MVT::f64, Custom); + setOperationAction(ISD::FP_EXTEND, MVT::v4f32, Custom); + setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f32, Custom); + setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f64, Custom); + setOperationAction(ISD::STRICT_FP_EXTEND, MVT::v4f32, Custom); + auto LegalizeNarrowFP = [this](MVT ScalarVT) { for (auto Op : { ISD::SETCC, @@ -893,10 +901,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(Op, MVT::f16, Legal); } - // Strict conversion to a larger type is legal - for (auto VT : {MVT::f32, MVT::f64}) - setOperationAction(ISD::STRICT_FP_EXTEND, VT, Legal); - setOperationAction(ISD::PREFETCH, MVT::Other, Custom); setOperationAction(ISD::GET_ROUNDING, MVT::i32, Custom); @@ -1183,6 +1187,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setMaxDivRemBitWidthSupported(128); setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom); + if (Subtarget->hasSME()) + setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i1, Custom); if (Subtarget->isNeonAvailable()) { // FIXME: v1f64 shouldn't be legal if we can avoid it, because it leads to @@ -4496,6 +4502,54 @@ SDValue AArch64TargetLowering::LowerFP_EXTEND(SDValue Op, if (useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable())) return LowerFixedLengthFPExtendToSVE(Op, DAG); + bool IsStrict = Op->isStrictFPOpcode(); + SDValue Op0 = Op.getOperand(IsStrict ? 1 : 0); + EVT Op0VT = Op0.getValueType(); + if (VT == MVT::f64) { + // FP16->FP32 extends are legal for v32 and v4f32. + if (Op0VT == MVT::f32 || Op0VT == MVT::f16) + return Op; + // Split bf16->f64 extends into two fpextends. + if (Op0VT == MVT::bf16 && IsStrict) { + SDValue Ext1 = + DAG.getNode(ISD::STRICT_FP_EXTEND, SDLoc(Op), {MVT::f32, MVT::Other}, + {Op0, Op.getOperand(0)}); + return DAG.getNode(ISD::STRICT_FP_EXTEND, SDLoc(Op), {VT, MVT::Other}, + {Ext1, Ext1.getValue(1)}); + } + if (Op0VT == MVT::bf16) + return DAG.getNode(ISD::FP_EXTEND, SDLoc(Op), VT, + DAG.getNode(ISD::FP_EXTEND, SDLoc(Op), MVT::f32, Op0)); + return SDValue(); + } + + if (VT.getScalarType() == MVT::f32) { + // FP16->FP32 extends are legal for v32 and v4f32. + if (Op0VT.getScalarType() == MVT::f16) + return Op; + if (Op0VT.getScalarType() == MVT::bf16) { + SDLoc DL(Op); + EVT IVT = VT.changeTypeToInteger(); + if (!Op0VT.isVector()) { + Op0 = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v4bf16, Op0); + IVT = MVT::v4i32; + } + + EVT Op0IVT = Op0.getValueType().changeTypeToInteger(); + SDValue Ext = + DAG.getNode(ISD::ANY_EXTEND, DL, IVT, DAG.getBitcast(Op0IVT, Op0)); + SDValue Shift = + DAG.getNode(ISD::SHL, DL, IVT, Ext, DAG.getConstant(16, DL, IVT)); + if (!Op0VT.isVector()) + Shift = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i32, Shift, + DAG.getConstant(0, DL, MVT::i64)); + Shift = DAG.getBitcast(VT, Shift); + return IsStrict ? DAG.getMergeValues({Shift, Op.getOperand(0)}, DL) + : Shift; + } + return SDValue(); + } + assert(Op.getValueType() == MVT::f128 && "Unexpected lowering"); return SDValue(); } @@ -7343,6 +7397,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op, case ISD::STRICT_FP_ROUND: return LowerFP_ROUND(Op, DAG); case ISD::FP_EXTEND: + case ISD::STRICT_FP_EXTEND: return LowerFP_EXTEND(Op, DAG); case ISD::FRAMEADDR: return LowerFRAMEADDR(Op, DAG); @@ -27429,6 +27484,15 @@ void AArch64TargetLowering::ReplaceNodeResults( Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V)); return; } + case Intrinsic::aarch64_sme_in_streaming_mode: { + SDLoc DL(N); + SDValue Chain = DAG.getEntryNode(); + SDValue RuntimePStateSM = + getRuntimePStateSM(DAG, Chain, DL, N->getValueType(0)); + Results.push_back( + DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, RuntimePStateSM)); + return; + } case Intrinsic::experimental_vector_match: case Intrinsic::get_active_lane_mask: { if (!VT.isFixedLengthVector() || VT.getVectorElementType() != MVT::i1) diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td index 47c4c6c..f527f7e 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td +++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td @@ -1804,7 +1804,9 @@ class TMSystemException<bits<3> op1, string asm, list<dag> pattern> } class APASI : SimpleSystemI<0, (ins GPR64:$Xt), "apas", "\t$Xt">, Sched<[]> { + bits<5> Xt; let Inst{20-5} = 0b0111001110000000; + let Inst{4-0} = Xt; let DecoderNamespace = "APAS"; } @@ -2768,6 +2770,8 @@ class MulHi<bits<3> opc, string asm, SDNode OpNode> let Inst{23-21} = opc; let Inst{20-16} = Rm; let Inst{15} = 0; + let Inst{14-10} = 0b11111; + let Unpredictable{14-10} = 0b11111; let Inst{9-5} = Rn; let Inst{4-0} = Rd; @@ -4920,6 +4924,8 @@ class LoadExclusivePair<bits<2> sz, bit o2, bit L, bit o1, bit o0, bits<5> Rt; bits<5> Rt2; bits<5> Rn; + let Inst{20-16} = 0b11111; + let Unpredictable{20-16} = 0b11111; let Inst{14-10} = Rt2; let Inst{9-5} = Rn; let Inst{4-0} = Rt; @@ -4935,6 +4941,7 @@ class BaseLoadStoreExclusiveLSUI<bits<2> sz, bit L, bit o0, let Inst{31-30} = sz; let Inst{29-23} = 0b0010010; let Inst{22} = L; + let Inst{21} = 0b0; let Inst{15} = o0; } diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td index ec891ea4..c6f5cdc 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td @@ -5123,22 +5123,6 @@ let Predicates = [HasFullFP16] in { //===----------------------------------------------------------------------===// defm FCVT : FPConversion<"fcvt">; -// Helper to get bf16 into fp32. -def cvt_bf16_to_fp32 : - OutPatFrag<(ops node:$Rn), - (f32 (COPY_TO_REGCLASS - (i32 (UBFMWri - (i32 (COPY_TO_REGCLASS (INSERT_SUBREG (f32 (IMPLICIT_DEF)), - node:$Rn, hsub), GPR32)), - (i64 (i32shift_a (i64 16))), - (i64 (i32shift_b (i64 16))))), - FPR32))>; -// Pattern for bf16 -> fp32. -def : Pat<(f32 (any_fpextend (bf16 FPR16:$Rn))), - (cvt_bf16_to_fp32 FPR16:$Rn)>; -// Pattern for bf16 -> fp64. -def : Pat<(f64 (any_fpextend (bf16 FPR16:$Rn))), - (FCVTDSr (f32 (cvt_bf16_to_fp32 FPR16:$Rn)))>; //===----------------------------------------------------------------------===// // Floating point single operand instructions. @@ -8333,8 +8317,6 @@ def : Pat<(v4i32 (anyext (v4i16 V64:$Rn))), (USHLLv4i16_shift V64:$Rn, (i32 0))> def : Pat<(v2i64 (sext (v2i32 V64:$Rn))), (SSHLLv2i32_shift V64:$Rn, (i32 0))>; def : Pat<(v2i64 (zext (v2i32 V64:$Rn))), (USHLLv2i32_shift V64:$Rn, (i32 0))>; def : Pat<(v2i64 (anyext (v2i32 V64:$Rn))), (USHLLv2i32_shift V64:$Rn, (i32 0))>; -// Vector bf16 -> fp32 is implemented morally as a zext + shift. -def : Pat<(v4f32 (any_fpextend (v4bf16 V64:$Rn))), (SHLLv4i16 V64:$Rn)>; // Also match an extend from the upper half of a 128 bit source register. def : Pat<(v8i16 (anyext (v8i8 (extract_high_v16i8 (v16i8 V128:$Rn)) ))), (USHLLv16i8_shift V128:$Rn, (i32 0))>; diff --git a/llvm/lib/Target/AArch64/AArch64SystemOperands.td b/llvm/lib/Target/AArch64/AArch64SystemOperands.td index c76fc8a..355a9d2 100644 --- a/llvm/lib/Target/AArch64/AArch64SystemOperands.td +++ b/llvm/lib/Target/AArch64/AArch64SystemOperands.td @@ -630,7 +630,7 @@ def ExactFPImmValues : GenericEnum { def ExactFPImmsList : GenericTable { let FilterClass = "ExactFPImm"; - let Fields = ["Name", "Enum", "Repr"]; + let Fields = ["Enum", "Repr"]; } def lookupExactFPImmByEnum : SearchIndex { @@ -638,11 +638,6 @@ def lookupExactFPImmByEnum : SearchIndex { let Key = ["Enum"]; } -def lookupExactFPImmByRepr : SearchIndex { - let Table = ExactFPImmsList; - let Key = ["Repr"]; -} - def : ExactFPImm<"zero", "0.0", 0x0>; def : ExactFPImm<"half", "0.5", 0x1>; def : ExactFPImm<"one", "1.0", 0x2>; @@ -998,7 +993,6 @@ defm : TLBI<"VMALLWS2E1OS", 0b100, 0b1000, 0b0101, 0b010, 0>; class SysReg<string name, bits<2> op0, bits<3> op1, bits<4> crn, bits<4> crm, bits<3> op2> { string Name = name; - string AltName = name; bits<16> Encoding; let Encoding{15-14} = op0; let Encoding{13-11} = op1; @@ -1018,8 +1012,11 @@ def SysRegValues : GenericEnum { def SysRegsList : GenericTable { let FilterClass = "SysReg"; - let Fields = ["Name", "AltName", "Encoding", "Readable", "Writeable", - "Requires"]; + let Fields = ["Name", "Encoding", "Readable", "Writeable", "Requires"]; + + let PrimaryKey = ["Encoding"]; + let PrimaryKeyName = "lookupSysRegByEncoding"; + let PrimaryKeyReturnRange = true; } def lookupSysRegByName : SearchIndex { @@ -1027,11 +1024,6 @@ def lookupSysRegByName : SearchIndex { let Key = ["Name"]; } -def lookupSysRegByEncoding : SearchIndex { - let Table = SysRegsList; - let Key = ["Encoding"]; -} - class RWSysReg<string name, bits<2> op0, bits<3> op1, bits<4> crn, bits<4> crm, bits<3> op2> : SysReg<name, op0, op1, crn, crm, op2> { @@ -1317,9 +1309,7 @@ def : RWSysReg<"TTBR0_EL1", 0b11, 0b000, 0b0010, 0b0000, 0b000>; def : RWSysReg<"TTBR0_EL3", 0b11, 0b110, 0b0010, 0b0000, 0b000>; let Requires = [{ {AArch64::FeatureEL2VMSA} }] in { -def : RWSysReg<"TTBR0_EL2", 0b11, 0b100, 0b0010, 0b0000, 0b000> { - let AltName = "VSCTLR_EL2"; -} +def : RWSysReg<"TTBR0_EL2", 0b11, 0b100, 0b0010, 0b0000, 0b000>; def : RWSysReg<"VTTBR_EL2", 0b11, 0b100, 0b0010, 0b0001, 0b000>; } @@ -1706,9 +1696,7 @@ def : RWSysReg<"ICH_LR15_EL2", 0b11, 0b100, 0b1100, 0b1101, 0b111>; let Requires = [{ {AArch64::HasV8_0rOps} }] in { //Virtualization System Control Register // Op0 Op1 CRn CRm Op2 -def : RWSysReg<"VSCTLR_EL2", 0b11, 0b100, 0b0010, 0b0000, 0b000> { - let AltName = "TTBR0_EL2"; -} +def : RWSysReg<"VSCTLR_EL2", 0b11, 0b100, 0b0010, 0b0000, 0b000>; //MPU Type Register // Op0 Op1 CRn CRm Op2 @@ -2376,7 +2364,6 @@ def : RWSysReg<"ACTLRALIAS_EL1", 0b11, 0b000, 0b0001, 0b0100, 0b101>; class PHint<bits<2> op0, bits<3> op1, bits<4> crn, bits<4> crm, bits<3> op2, string name> { string Name = name; - string AltName = name; bits<16> Encoding; let Encoding{15-14} = op0; let Encoding{13-11} = op1; @@ -2394,7 +2381,7 @@ def PHintValues : GenericEnum { def PHintsList : GenericTable { let FilterClass = "PHint"; - let Fields = ["Name", "AltName", "Encoding", "Requires"]; + let Fields = ["Name", "Encoding", "Requires"]; } def lookupPHintByName : SearchIndex { diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index 5abe69e..25b6731 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -2761,6 +2761,21 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, return AdjustCost( BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I)); + static const TypeConversionCostTblEntry BF16Tbl[] = { + {ISD::FP_ROUND, MVT::bf16, MVT::f32, 1}, // bfcvt + {ISD::FP_ROUND, MVT::bf16, MVT::f64, 1}, // bfcvt + {ISD::FP_ROUND, MVT::v4bf16, MVT::v4f32, 1}, // bfcvtn + {ISD::FP_ROUND, MVT::v8bf16, MVT::v8f32, 2}, // bfcvtn+bfcvtn2 + {ISD::FP_ROUND, MVT::v2bf16, MVT::v2f64, 2}, // bfcvtn+fcvtn + {ISD::FP_ROUND, MVT::v4bf16, MVT::v4f64, 3}, // fcvtn+fcvtl2+bfcvtn + {ISD::FP_ROUND, MVT::v8bf16, MVT::v8f64, 6}, // 2 * fcvtn+fcvtn2+bfcvtn + }; + + if (ST->hasBF16()) + if (const auto *Entry = ConvertCostTableLookup( + BF16Tbl, ISD, DstTy.getSimpleVT(), SrcTy.getSimpleVT())) + return AdjustCost(Entry->Cost); + static const TypeConversionCostTblEntry ConversionTbl[] = { {ISD::TRUNCATE, MVT::v2i8, MVT::v2i64, 1}, // xtn {ISD::TRUNCATE, MVT::v2i16, MVT::v2i64, 1}, // xtn @@ -2848,6 +2863,14 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, {ISD::FP_EXTEND, MVT::v2f64, MVT::v2f16, 2}, // fcvtl+fcvtl {ISD::FP_EXTEND, MVT::v4f64, MVT::v4f16, 3}, // fcvtl+fcvtl2+fcvtl {ISD::FP_EXTEND, MVT::v8f64, MVT::v8f16, 6}, // 2 * fcvtl+fcvtl2+fcvtl + // BF16 (uses shift) + {ISD::FP_EXTEND, MVT::f32, MVT::bf16, 1}, // shl + {ISD::FP_EXTEND, MVT::f64, MVT::bf16, 2}, // shl+fcvt + {ISD::FP_EXTEND, MVT::v4f32, MVT::v4bf16, 1}, // shll + {ISD::FP_EXTEND, MVT::v8f32, MVT::v8bf16, 2}, // shll+shll2 + {ISD::FP_EXTEND, MVT::v2f64, MVT::v2bf16, 2}, // shll+fcvtl + {ISD::FP_EXTEND, MVT::v4f64, MVT::v4bf16, 3}, // shll+fcvtl+fcvtl2 + {ISD::FP_EXTEND, MVT::v8f64, MVT::v8bf16, 6}, // 2 * shll+fcvtl+fcvtl2 // FP Ext and trunc {ISD::FP_ROUND, MVT::f32, MVT::f64, 1}, // fcvt {ISD::FP_ROUND, MVT::v2f32, MVT::v2f64, 1}, // fcvtn @@ -2860,6 +2883,15 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, {ISD::FP_ROUND, MVT::v2f16, MVT::v2f64, 2}, // fcvtn+fcvtn {ISD::FP_ROUND, MVT::v4f16, MVT::v4f64, 3}, // fcvtn+fcvtn2+fcvtn {ISD::FP_ROUND, MVT::v8f16, MVT::v8f64, 6}, // 2 * fcvtn+fcvtn2+fcvtn + // BF16 (more complex, with +bf16 is handled above) + {ISD::FP_ROUND, MVT::bf16, MVT::f32, 8}, // Expansion is ~8 insns + {ISD::FP_ROUND, MVT::bf16, MVT::f64, 9}, // fcvtn + above + {ISD::FP_ROUND, MVT::v2bf16, MVT::v2f32, 8}, + {ISD::FP_ROUND, MVT::v4bf16, MVT::v4f32, 8}, + {ISD::FP_ROUND, MVT::v8bf16, MVT::v8f32, 15}, + {ISD::FP_ROUND, MVT::v2bf16, MVT::v2f64, 9}, + {ISD::FP_ROUND, MVT::v4bf16, MVT::v4f64, 10}, + {ISD::FP_ROUND, MVT::v8bf16, MVT::v8f64, 19}, // LowerVectorINT_TO_FP: {ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i32, 1}, @@ -4706,10 +4738,21 @@ InstructionCost AArch64TTIImpl::getShuffleCost( } Kind = improveShuffleKindFromMask(Kind, Mask, Tp, Index, SubTp); - // Treat extractsubvector as single op permutation. bool IsExtractSubvector = Kind == TTI::SK_ExtractSubvector; - if (IsExtractSubvector && LT.second.isFixedLengthVector()) + // A sebvector extract can be implemented with a ext (or trivial extract, if + // from lane 0). This currently only handles low or high extracts to prevent + // SLP vectorizer regressions. + if (IsExtractSubvector && LT.second.isFixedLengthVector()) { + if (LT.second.is128BitVector() && + cast<FixedVectorType>(SubTp)->getNumElements() == + LT.second.getVectorNumElements() / 2) { + if (Index == 0) + return 0; + if (Index == (int)LT.second.getVectorNumElements() / 2) + return 1; + } Kind = TTI::SK_PermuteSingleSrc; + } // Check for broadcast loads, which are supported by the LD1R instruction. // In terms of code-size, the shuffle vector is free when a load + dup get diff --git a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.cpp b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.cpp index ae84bc9..875b505 100644 --- a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.cpp +++ b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.cpp @@ -1874,26 +1874,25 @@ void AArch64InstPrinter::printBarriernXSOption(const MCInst *MI, unsigned OpNo, markup(O, Markup::Immediate) << "#" << Val; } -static bool isValidSysReg(const AArch64SysReg::SysReg *Reg, bool Read, +static bool isValidSysReg(const AArch64SysReg::SysReg &Reg, bool Read, const MCSubtargetInfo &STI) { - return (Reg && (Read ? Reg->Readable : Reg->Writeable) && - Reg->haveFeatures(STI.getFeatureBits())); + return (Read ? Reg.Readable : Reg.Writeable) && + Reg.haveFeatures(STI.getFeatureBits()); } -// Looks up a system register either by encoding or by name. Some system +// Looks up a system register either by encoding. Some system // registers share the same encoding between different architectures, -// therefore a tablegen lookup by encoding will return an entry regardless -// of the register's predication on a specific subtarget feature. To work -// around this problem we keep an alternative name for such registers and -// look them up by that name if the first lookup was unsuccessful. +// to work around this tablegen will return a range of registers with the same +// encodings. We need to check each register in the range to see if it valid. static const AArch64SysReg::SysReg *lookupSysReg(unsigned Val, bool Read, const MCSubtargetInfo &STI) { - const AArch64SysReg::SysReg *Reg = AArch64SysReg::lookupSysRegByEncoding(Val); - - if (Reg && !isValidSysReg(Reg, Read, STI)) - Reg = AArch64SysReg::lookupSysRegByName(Reg->AltName); + auto Range = AArch64SysReg::lookupSysRegByEncoding(Val); + for (auto &Reg : Range) { + if (isValidSysReg(Reg, Read, STI)) + return &Reg; + } - return Reg; + return nullptr; } void AArch64InstPrinter::printMRSSystemRegister(const MCInst *MI, unsigned OpNo, @@ -1917,7 +1916,7 @@ void AArch64InstPrinter::printMRSSystemRegister(const MCInst *MI, unsigned OpNo, const AArch64SysReg::SysReg *Reg = lookupSysReg(Val, true /*Read*/, STI); - if (isValidSysReg(Reg, true /*Read*/, STI)) + if (Reg) O << Reg->Name; else O << AArch64SysReg::genericRegisterString(Val); @@ -1944,7 +1943,7 @@ void AArch64InstPrinter::printMSRSystemRegister(const MCInst *MI, unsigned OpNo, const AArch64SysReg::SysReg *Reg = lookupSysReg(Val, false /*Read*/, STI); - if (isValidSysReg(Reg, false /*Read*/, STI)) + if (Reg) O << Reg->Name; else O << AArch64SysReg::genericRegisterString(Val); diff --git a/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h b/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h index 94bba4e..b8d3236 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h +++ b/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h @@ -564,11 +564,10 @@ LLVM_DECLARE_ENUM_AS_BITMASK(TailFoldingOpts, /* LargestValue */ (long)TailFoldingOpts::Reverse); namespace AArch64ExactFPImm { - struct ExactFPImm { - const char *Name; - int Enum; - const char *Repr; - }; +struct ExactFPImm { + int Enum; + const char *Repr; +}; #define GET_ExactFPImmValues_DECL #define GET_ExactFPImmsList_DECL #include "AArch64GenSystemOperands.inc" @@ -602,7 +601,6 @@ namespace AArch64PSBHint { namespace AArch64PHint { struct PHint { const char *Name; - const char *AltName; unsigned Encoding; FeatureBitset FeaturesRequired; @@ -720,7 +718,6 @@ AArch64StringToVectorLayout(StringRef LayoutStr) { namespace AArch64SysReg { struct SysReg { const char Name[32]; - const char AltName[32]; unsigned Encoding; bool Readable; bool Writeable; @@ -736,9 +733,6 @@ namespace AArch64SysReg { #define GET_SysRegValues_DECL #include "AArch64GenSystemOperands.inc" - const SysReg *lookupSysRegByName(StringRef); - const SysReg *lookupSysRegByEncoding(uint16_t); - uint32_t parseGenericRegister(StringRef Name); std::string genericRegisterString(uint32_t Bits); } diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp b/llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp index e844904..0f97988 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp @@ -1523,7 +1523,8 @@ Value *AMDGPUCodeGenPrepareImpl::shrinkDivRem64(IRBuilder<> &Builder, bool IsDiv = Opc == Instruction::SDiv || Opc == Instruction::UDiv; bool IsSigned = Opc == Instruction::SDiv || Opc == Instruction::SRem; - int NumDivBits = getDivNumBits(I, Num, Den, 32, IsSigned); + unsigned BitWidth = Num->getType()->getScalarSizeInBits(); + int NumDivBits = getDivNumBits(I, Num, Den, BitWidth - 32, IsSigned); if (NumDivBits == -1) return nullptr; diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp index d9eaf82..27e9018 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp @@ -1997,7 +1997,7 @@ bool AMDGPUDAGToDAGISel::SelectScratchSVAddr(SDNode *N, SDValue Addr, if (checkFlatScratchSVSSwizzleBug(VAddr, SAddr, ImmOffset)) return false; SAddr = SelectSAddrFI(CurDAG, SAddr); - Offset = CurDAG->getTargetConstant(ImmOffset, SDLoc(), MVT::i32); + Offset = CurDAG->getSignedTargetConstant(ImmOffset, SDLoc(), MVT::i32); return true; } diff --git a/llvm/lib/Target/AMDGPU/BUFInstructions.td b/llvm/lib/Target/AMDGPU/BUFInstructions.td index 88205ea..f2686bd 100644 --- a/llvm/lib/Target/AMDGPU/BUFInstructions.td +++ b/llvm/lib/Target/AMDGPU/BUFInstructions.td @@ -680,7 +680,7 @@ multiclass MUBUF_Pseudo_Stores<string opName, ValueType store_vt = i32> { class MUBUF_Pseudo_Store_Lds<string opName> : MUBUF_Pseudo<opName, (outs), - (ins SReg_128:$srsrc, SCSrc_b32:$soffset, Offset:$offset, CPol:$cpol, i1imm:$swz), + (ins SReg_128_XNULL:$srsrc, SCSrc_b32:$soffset, Offset:$offset, CPol:$cpol, i1imm:$swz), " $srsrc, $soffset$offset lds$cpol"> { let LGKM_CNT = 1; let mayLoad = 1; diff --git a/llvm/lib/Target/AMDGPU/MIMGInstructions.td b/llvm/lib/Target/AMDGPU/MIMGInstructions.td index 3c7627f..1b94d6c 100644 --- a/llvm/lib/Target/AMDGPU/MIMGInstructions.td +++ b/llvm/lib/Target/AMDGPU/MIMGInstructions.td @@ -1524,7 +1524,7 @@ class MIMG_IntersectRay_Helper<bit Is64, bit IsA16> { class MIMG_IntersectRay_gfx10<mimgopc op, string opcode, RegisterClass AddrRC> : MIMG_gfx10<op.GFX10M, (outs VReg_128:$vdata), "GFX10"> { - let InOperandList = (ins AddrRC:$vaddr0, SReg_128:$srsrc, A16:$a16); + let InOperandList = (ins AddrRC:$vaddr0, SReg_128_XNULL:$srsrc, A16:$a16); let AsmString = opcode#" $vdata, $vaddr0, $srsrc$a16"; let nsa = 0; @@ -1532,13 +1532,13 @@ class MIMG_IntersectRay_gfx10<mimgopc op, string opcode, RegisterClass AddrRC> class MIMG_IntersectRay_nsa_gfx10<mimgopc op, string opcode, int num_addrs> : MIMG_nsa_gfx10<op.GFX10M, (outs VReg_128:$vdata), num_addrs, "GFX10"> { - let InOperandList = !con(nsah.AddrIns, (ins SReg_128:$srsrc, A16:$a16)); + let InOperandList = !con(nsah.AddrIns, (ins SReg_128_XNULL:$srsrc, A16:$a16)); let AsmString = opcode#" $vdata, "#nsah.AddrAsm#", $srsrc$a16"; } class MIMG_IntersectRay_gfx11<mimgopc op, string opcode, RegisterClass AddrRC> : MIMG_gfx11<op.GFX11, (outs VReg_128:$vdata), "GFX11"> { - let InOperandList = (ins AddrRC:$vaddr0, SReg_128:$srsrc, A16:$a16); + let InOperandList = (ins AddrRC:$vaddr0, SReg_128_XNULL:$srsrc, A16:$a16); let AsmString = opcode#" $vdata, $vaddr0, $srsrc$a16"; let nsa = 0; @@ -1548,7 +1548,7 @@ class MIMG_IntersectRay_nsa_gfx11<mimgopc op, string opcode, int num_addrs, list<RegisterClass> addr_types> : MIMG_nsa_gfx11<op.GFX11, (outs VReg_128:$vdata), num_addrs, "GFX11", addr_types> { - let InOperandList = !con(nsah.AddrIns, (ins SReg_128:$srsrc, A16:$a16)); + let InOperandList = !con(nsah.AddrIns, (ins SReg_128_XNULL:$srsrc, A16:$a16)); let AsmString = opcode#" $vdata, "#nsah.AddrAsm#", $srsrc$a16"; } @@ -1556,7 +1556,7 @@ class VIMAGE_IntersectRay_gfx12<mimgopc op, string opcode, int num_addrs, list<RegisterClass> addr_types> : VIMAGE_gfx12<op.GFX12, (outs VReg_128:$vdata), num_addrs, "GFX12", addr_types> { - let InOperandList = !con(nsah.AddrIns, (ins SReg_128:$rsrc, A16:$a16)); + let InOperandList = !con(nsah.AddrIns, (ins SReg_128_XNULL:$rsrc, A16:$a16)); let AsmString = opcode#" $vdata, "#nsah.AddrAsm#", $rsrc$a16"; } diff --git a/llvm/lib/Target/AMDGPU/SIFoldOperands.cpp b/llvm/lib/Target/AMDGPU/SIFoldOperands.cpp index c2199fd..2bc1913 100644 --- a/llvm/lib/Target/AMDGPU/SIFoldOperands.cpp +++ b/llvm/lib/Target/AMDGPU/SIFoldOperands.cpp @@ -1096,21 +1096,8 @@ void SIFoldOperandsImpl::foldOperand( B.addImm(Defs[I].second); } LLVM_DEBUG(dbgs() << "Folded " << *UseMI); - return; } - if (Size != 4) - return; - - Register Reg0 = UseMI->getOperand(0).getReg(); - Register Reg1 = UseMI->getOperand(1).getReg(); - if (TRI->isAGPR(*MRI, Reg0) && TRI->isVGPR(*MRI, Reg1)) - UseMI->setDesc(TII->get(AMDGPU::V_ACCVGPR_WRITE_B32_e64)); - else if (TRI->isVGPR(*MRI, Reg0) && TRI->isAGPR(*MRI, Reg1)) - UseMI->setDesc(TII->get(AMDGPU::V_ACCVGPR_READ_B32_e64)); - else if (ST->hasGFX90AInsts() && TRI->isAGPR(*MRI, Reg0) && - TRI->isAGPR(*MRI, Reg1)) - UseMI->setDesc(TII->get(AMDGPU::V_ACCVGPR_MOV_B32)); return; } diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp index b3cfa39..0ac84f4 100644 --- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp +++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp @@ -13985,6 +13985,43 @@ SDValue SITargetLowering::tryFoldToMad64_32(SDNode *N, return Accum; } +SDValue +SITargetLowering::foldAddSub64WithZeroLowBitsTo32(SDNode *N, + DAGCombinerInfo &DCI) const { + SDValue RHS = N->getOperand(1); + auto *CRHS = dyn_cast<ConstantSDNode>(RHS); + if (!CRHS) + return SDValue(); + + // TODO: Worth using computeKnownBits? Maybe expensive since it's so + // common. + uint64_t Val = CRHS->getZExtValue(); + if (countr_zero(Val) >= 32) { + SelectionDAG &DAG = DCI.DAG; + SDLoc SL(N); + SDValue LHS = N->getOperand(0); + + // Avoid carry machinery if we know the low half of the add does not + // contribute to the final result. + // + // add i64:x, K if computeTrailingZeros(K) >= 32 + // => build_pair (add x.hi, K.hi), x.lo + + // Breaking the 64-bit add here with this strange constant is unlikely + // to interfere with addressing mode patterns. + + SDValue Hi = getHiHalf64(LHS, DAG); + SDValue ConstHi32 = DAG.getConstant(Hi_32(Val), SL, MVT::i32); + SDValue AddHi = + DAG.getNode(N->getOpcode(), SL, MVT::i32, Hi, ConstHi32, N->getFlags()); + + SDValue Lo = DAG.getNode(ISD::TRUNCATE, SL, MVT::i32, LHS); + return DAG.getNode(ISD::BUILD_PAIR, SL, MVT::i64, Lo, AddHi); + } + + return SDValue(); +} + // Collect the ultimate src of each of the mul node's operands, and confirm // each operand is 8 bytes. static std::optional<ByteProvider<SDValue>> @@ -14261,6 +14298,11 @@ SDValue SITargetLowering::performAddCombine(SDNode *N, return V; } + if (VT == MVT::i64) { + if (SDValue Folded = foldAddSub64WithZeroLowBitsTo32(N, DCI)) + return Folded; + } + if ((isMul(LHS) || isMul(RHS)) && Subtarget->hasDot7Insts() && (Subtarget->hasDot1Insts() || Subtarget->hasDot8Insts())) { SDValue TempNode(N, 0); @@ -14446,6 +14488,11 @@ SDValue SITargetLowering::performSubCombine(SDNode *N, SelectionDAG &DAG = DCI.DAG; EVT VT = N->getValueType(0); + if (VT == MVT::i64) { + if (SDValue Folded = foldAddSub64WithZeroLowBitsTo32(N, DCI)) + return Folded; + } + if (VT != MVT::i32) return SDValue(); diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.h b/llvm/lib/Target/AMDGPU/SIISelLowering.h index f4641e7..299c8f5 100644 --- a/llvm/lib/Target/AMDGPU/SIISelLowering.h +++ b/llvm/lib/Target/AMDGPU/SIISelLowering.h @@ -212,6 +212,9 @@ private: unsigned getFusedOpcode(const SelectionDAG &DAG, const SDNode *N0, const SDNode *N1) const; SDValue tryFoldToMad64_32(SDNode *N, DAGCombinerInfo &DCI) const; + SDValue foldAddSub64WithZeroLowBitsTo32(SDNode *N, + DAGCombinerInfo &DCI) const; + SDValue performAddCombine(SDNode *N, DAGCombinerInfo &DCI) const; SDValue performAddCarrySubCarryCombine(SDNode *N, DAGCombinerInfo &DCI) const; SDValue performSubCombine(SDNode *N, DAGCombinerInfo &DCI) const; diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp b/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp index 692e286..e6f333f 100644 --- a/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp +++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp @@ -6890,9 +6890,8 @@ SIInstrInfo::legalizeOperands(MachineInstr &MI, AMDGPU::getNamedOperandIdx(MI.getOpcode(), AMDGPU::OpName::srsrc); if (RsrcIdx != -1) { MachineOperand *Rsrc = &MI.getOperand(RsrcIdx); - if (Rsrc->isReg() && !RI.isSGPRClass(MRI.getRegClass(Rsrc->getReg()))) { + if (Rsrc->isReg() && !RI.isSGPRReg(MRI, Rsrc->getReg())) isRsrcLegal = false; - } } // The operands are legal. diff --git a/llvm/lib/Target/AMDGPU/SIInstructions.td b/llvm/lib/Target/AMDGPU/SIInstructions.td index e388efe..ee83dff 100644 --- a/llvm/lib/Target/AMDGPU/SIInstructions.td +++ b/llvm/lib/Target/AMDGPU/SIInstructions.td @@ -3055,7 +3055,7 @@ def : GCNPat< (V_BFREV_B32_e64 (i32 (EXTRACT_SUBREG VReg_64:$a, sub0))), sub1)>; // If fcanonicalize's operand is implicitly canonicalized, we only need a copy. -let AddedComplexity = 1000 in { +let AddedComplexity = 8 in { foreach vt = [f16, v2f16, f32, v2f32, f64] in { def : GCNPat< (fcanonicalize (vt is_canonicalized:$src)), @@ -3710,12 +3710,15 @@ def : IntMinMaxPat<V_MAXMIN_U32_e64, umin, umax_oneuse>; def : IntMinMaxPat<V_MINMAX_U32_e64, umax, umin_oneuse>; def : FPMinMaxPat<V_MINMAX_F32_e64, f32, fmaxnum_like, fminnum_like_oneuse>; def : FPMinMaxPat<V_MAXMIN_F32_e64, f32, fminnum_like, fmaxnum_like_oneuse>; -def : FPMinMaxPat<V_MINMAX_F16_e64, f16, fmaxnum_like, fminnum_like_oneuse>; -def : FPMinMaxPat<V_MAXMIN_F16_e64, f16, fminnum_like, fmaxnum_like_oneuse>; def : FPMinCanonMaxPat<V_MINMAX_F32_e64, f32, fmaxnum_like, fminnum_like_oneuse>; def : FPMinCanonMaxPat<V_MAXMIN_F32_e64, f32, fminnum_like, fmaxnum_like_oneuse>; -def : FPMinCanonMaxPat<V_MINMAX_F16_e64, f16, fmaxnum_like, fminnum_like_oneuse>; -def : FPMinCanonMaxPat<V_MAXMIN_F16_e64, f16, fminnum_like, fmaxnum_like_oneuse>; +} + +let True16Predicate = UseFakeTrue16Insts in { +def : FPMinMaxPat<V_MINMAX_F16_fake16_e64, f16, fmaxnum_like, fminnum_like_oneuse>; +def : FPMinMaxPat<V_MAXMIN_F16_fake16_e64, f16, fminnum_like, fmaxnum_like_oneuse>; +def : FPMinCanonMaxPat<V_MINMAX_F16_fake16_e64, f16, fmaxnum_like, fminnum_like_oneuse>; +def : FPMinCanonMaxPat<V_MAXMIN_F16_fake16_e64, f16, fminnum_like, fmaxnum_like_oneuse>; } let SubtargetPredicate = isGFX9Plus in { @@ -3723,6 +3726,10 @@ let True16Predicate = NotHasTrue16BitInsts in { defm : Int16Med3Pat<V_MED3_I16_e64, smin, smax, VSrc_b16>; defm : Int16Med3Pat<V_MED3_U16_e64, umin, umax, VSrc_b16>; } +let True16Predicate = UseRealTrue16Insts in { + defm : Int16Med3Pat<V_MED3_I16_t16_e64, smin, smax, VSrcT_b16>; + defm : Int16Med3Pat<V_MED3_U16_t16_e64, umin, umax, VSrcT_b16>; +} let True16Predicate = UseFakeTrue16Insts in { defm : Int16Med3Pat<V_MED3_I16_fake16_e64, smin, smax, VSrc_b16>; defm : Int16Med3Pat<V_MED3_U16_fake16_e64, umin, umax, VSrc_b16>; diff --git a/llvm/lib/Target/AMDGPU/SMInstructions.td b/llvm/lib/Target/AMDGPU/SMInstructions.td index 60e4ce9..37dcc100 100644 --- a/llvm/lib/Target/AMDGPU/SMInstructions.td +++ b/llvm/lib/Target/AMDGPU/SMInstructions.td @@ -341,10 +341,10 @@ let SubtargetPredicate = HasScalarDwordx3Loads in defm S_BUFFER_LOAD_DWORDX4 : SM_Pseudo_Loads <SReg_128_XNULL, SReg_128>; defm S_BUFFER_LOAD_DWORDX8 : SM_Pseudo_Loads <SReg_128_XNULL, SReg_256>; defm S_BUFFER_LOAD_DWORDX16 : SM_Pseudo_Loads <SReg_128_XNULL, SReg_512>; -defm S_BUFFER_LOAD_I8 : SM_Pseudo_Loads <SReg_128, SReg_32_XM0_XEXEC>; -defm S_BUFFER_LOAD_U8 : SM_Pseudo_Loads <SReg_128, SReg_32_XM0_XEXEC>; -defm S_BUFFER_LOAD_I16 : SM_Pseudo_Loads <SReg_128, SReg_32_XM0_XEXEC>; -defm S_BUFFER_LOAD_U16 : SM_Pseudo_Loads <SReg_128, SReg_32_XM0_XEXEC>; +defm S_BUFFER_LOAD_I8 : SM_Pseudo_Loads <SReg_128_XNULL, SReg_32_XM0_XEXEC>; +defm S_BUFFER_LOAD_U8 : SM_Pseudo_Loads <SReg_128_XNULL, SReg_32_XM0_XEXEC>; +defm S_BUFFER_LOAD_I16 : SM_Pseudo_Loads <SReg_128_XNULL, SReg_32_XM0_XEXEC>; +defm S_BUFFER_LOAD_U16 : SM_Pseudo_Loads <SReg_128_XNULL, SReg_32_XM0_XEXEC>; } let SubtargetPredicate = HasScalarStores in { @@ -375,7 +375,7 @@ def S_DCACHE_WB_VOL : SM_Inval_Pseudo <"s_dcache_wb_vol", int_amdgcn_s_dcache_wb defm S_ATC_PROBE : SM_Pseudo_Probe <SReg_64>; let is_buffer = 1 in { -defm S_ATC_PROBE_BUFFER : SM_Pseudo_Probe <SReg_128>; +defm S_ATC_PROBE_BUFFER : SM_Pseudo_Probe <SReg_128_XNULL>; } } // SubtargetPredicate = isGFX8Plus @@ -470,7 +470,7 @@ def S_PREFETCH_INST : SM_Prefetch_Pseudo <"s_prefetch_inst", SReg_64, 1>; def S_PREFETCH_INST_PC_REL : SM_Prefetch_Pseudo <"s_prefetch_inst_pc_rel", SReg_64, 0>; def S_PREFETCH_DATA : SM_Prefetch_Pseudo <"s_prefetch_data", SReg_64, 1>; def S_PREFETCH_DATA_PC_REL : SM_Prefetch_Pseudo <"s_prefetch_data_pc_rel", SReg_64, 0>; -def S_BUFFER_PREFETCH_DATA : SM_Prefetch_Pseudo <"s_buffer_prefetch_data", SReg_128, 1> { +def S_BUFFER_PREFETCH_DATA : SM_Prefetch_Pseudo <"s_buffer_prefetch_data", SReg_128_XNULL, 1> { let is_buffer = 1; } } // end let SubtargetPredicate = isGFX12Plus diff --git a/llvm/lib/Target/AMDGPU/VOP3Instructions.td b/llvm/lib/Target/AMDGPU/VOP3Instructions.td index cef1f20..24a2eed 100644 --- a/llvm/lib/Target/AMDGPU/VOP3Instructions.td +++ b/llvm/lib/Target/AMDGPU/VOP3Instructions.td @@ -1374,8 +1374,8 @@ class VOP3_DOT_Profile_fake16<VOPProfile P, VOP3Features Features = VOP3_REGULAR let SubtargetPredicate = isGFX11Plus in { defm V_MAXMIN_F32 : VOP3Inst<"v_maxmin_f32", VOP3_Profile<VOP_F32_F32_F32_F32>>; defm V_MINMAX_F32 : VOP3Inst<"v_minmax_f32", VOP3_Profile<VOP_F32_F32_F32_F32>>; - defm V_MAXMIN_F16 : VOP3Inst<"v_maxmin_f16", VOP3_Profile<VOP_F16_F16_F16_F16>>; - defm V_MINMAX_F16 : VOP3Inst<"v_minmax_f16", VOP3_Profile<VOP_F16_F16_F16_F16>>; + defm V_MAXMIN_F16 : VOP3Inst_t16<"v_maxmin_f16", VOP_F16_F16_F16_F16>; + defm V_MINMAX_F16 : VOP3Inst_t16<"v_minmax_f16", VOP_F16_F16_F16_F16>; defm V_MAXMIN_U32 : VOP3Inst<"v_maxmin_u32", VOP3_Profile<VOP_I32_I32_I32_I32>>; defm V_MINMAX_U32 : VOP3Inst<"v_minmax_u32", VOP3_Profile<VOP_I32_I32_I32_I32>>; defm V_MAXMIN_I32 : VOP3Inst<"v_maxmin_i32", VOP3_Profile<VOP_I32_I32_I32_I32>>; @@ -1588,8 +1588,8 @@ defm V_MED3_NUM_F32 : VOP3_Realtriple_with_name_gfx12<0x231, "V_MED3_F32", defm V_MED3_NUM_F16 : VOP3_Realtriple_t16_and_fake16_gfx12<0x232, "v_med3_num_f16", "V_MED3_F16", "v_med3_f16">; defm V_MINMAX_NUM_F32 : VOP3_Realtriple_with_name_gfx12<0x268, "V_MINMAX_F32", "v_minmax_num_f32">; defm V_MAXMIN_NUM_F32 : VOP3_Realtriple_with_name_gfx12<0x269, "V_MAXMIN_F32", "v_maxmin_num_f32">; -defm V_MINMAX_NUM_F16 : VOP3_Realtriple_with_name_gfx12<0x26a, "V_MINMAX_F16", "v_minmax_num_f16">; -defm V_MAXMIN_NUM_F16 : VOP3_Realtriple_with_name_gfx12<0x26b, "V_MAXMIN_F16", "v_maxmin_num_f16">; +defm V_MINMAX_NUM_F16 : VOP3_Realtriple_t16_and_fake16_gfx12<0x26a, "v_minmax_num_f16", "V_MINMAX_F16", "v_minmax_f16">; +defm V_MAXMIN_NUM_F16 : VOP3_Realtriple_t16_and_fake16_gfx12<0x26b, "v_maxmin_num_f16", "V_MAXMIN_F16", "v_maxmin_f16">; defm V_MINIMUMMAXIMUM_F32 : VOP3Only_Realtriple_gfx12<0x26c>; defm V_MAXIMUMMINIMUM_F32 : VOP3Only_Realtriple_gfx12<0x26d>; defm V_MINIMUMMAXIMUM_F16 : VOP3Only_Realtriple_t16_gfx12<0x26e>; @@ -1730,8 +1730,8 @@ defm V_PERMLANE16_B32 : VOP3_Real_Base_gfx11_gfx12<0x25b>; defm V_PERMLANEX16_B32 : VOP3_Real_Base_gfx11_gfx12<0x25c>; defm V_MAXMIN_F32 : VOP3_Realtriple_gfx11<0x25e>; defm V_MINMAX_F32 : VOP3_Realtriple_gfx11<0x25f>; -defm V_MAXMIN_F16 : VOP3_Realtriple_gfx11<0x260>; -defm V_MINMAX_F16 : VOP3_Realtriple_gfx11<0x261>; +defm V_MAXMIN_F16 : VOP3_Realtriple_t16_and_fake16_gfx11<0x260, "v_maxmin_f16">; +defm V_MINMAX_F16 : VOP3_Realtriple_t16_and_fake16_gfx11<0x261, "v_minmax_f16">; defm V_MAXMIN_U32 : VOP3_Realtriple_gfx11_gfx12<0x262>; defm V_MINMAX_U32 : VOP3_Realtriple_gfx11_gfx12<0x263>; defm V_MAXMIN_I32 : VOP3_Realtriple_gfx11_gfx12<0x264>; diff --git a/llvm/lib/Target/AMDGPU/VOPInstructions.td b/llvm/lib/Target/AMDGPU/VOPInstructions.td index d236907..930ed9a 100644 --- a/llvm/lib/Target/AMDGPU/VOPInstructions.td +++ b/llvm/lib/Target/AMDGPU/VOPInstructions.td @@ -1909,8 +1909,8 @@ multiclass VOP3_Realtriple_t16_gfx11<bits<10> op, string asmName, string opName multiclass VOP3_Realtriple_t16_and_fake16_gfx11<bits<10> op, string asmName, string opName = NAME, string pseudo_mnemonic = "", bit isSingle = 0> { - defm _t16: VOP3_Realtriple_t16_gfx11<op, opName#"_t16", asmName, pseudo_mnemonic, isSingle>; - defm _fake16: VOP3_Realtriple_t16_gfx11<op, opName#"_fake16", asmName, pseudo_mnemonic, isSingle>; + defm _t16: VOP3_Realtriple_t16_gfx11<op, asmName, opName#"_t16", pseudo_mnemonic, isSingle>; + defm _fake16: VOP3_Realtriple_t16_gfx11<op, asmName, opName#"_fake16", pseudo_mnemonic, isSingle>; } multiclass VOP3Only_Realtriple_t16_gfx11<bits<10> op, string asmName, diff --git a/llvm/lib/Target/ARM/ARMInstrInfo.td b/llvm/lib/Target/ARM/ARMInstrInfo.td index c67177c..009b60c 100644 --- a/llvm/lib/Target/ARM/ARMInstrInfo.td +++ b/llvm/lib/Target/ARM/ARMInstrInfo.td @@ -3320,7 +3320,7 @@ def STRH_preidx: ARMPseudoInst<(outs GPR:$Rn_wb), } - +let mayStore = 1, hasSideEffects = 0 in { def STRH_PRE : AI3ldstidx<0b1011, 0, 1, (outs GPR:$Rn_wb), (ins GPR:$Rt, addrmode3_pre:$addr), IndexModePre, StMiscFrm, IIC_iStore_bh_ru, @@ -3352,6 +3352,7 @@ def STRH_POST : AI3ldstidx<0b1011, 0, 0, (outs GPR:$Rn_wb), let Inst{3-0} = offset{3-0}; // imm3_0/Rm let DecoderMethod = "DecodeAddrMode3Instruction"; } +} // mayStore = 1, hasSideEffects = 0 let mayStore = 1, hasSideEffects = 0, hasExtraSrcRegAllocReq = 1 in { def STRD_PRE : AI3ldstidx<0b1111, 0, 1, (outs GPR:$Rn_wb), diff --git a/llvm/lib/Target/ARM/ARMSystemRegister.td b/llvm/lib/Target/ARM/ARMSystemRegister.td index c03db15..3afc410 100644 --- a/llvm/lib/Target/ARM/ARMSystemRegister.td +++ b/llvm/lib/Target/ARM/ARMSystemRegister.td @@ -19,17 +19,13 @@ class MClassSysReg<bits<1> UniqMask1, bits<1> UniqMask2, bits<1> UniqMask3, bits<12> Enc12, - string name> : SearchableTable { - let SearchableFields = ["Name", "M1Encoding12", "M2M3Encoding8", "Encoding"]; + string name> { string Name; bits<13> M1Encoding12; bits<10> M2M3Encoding8; bits<12> Encoding; let Name = name; - let EnumValueField = "M1Encoding12"; - let EnumValueField = "M2M3Encoding8"; - let EnumValueField = "Encoding"; let M1Encoding12{12} = UniqMask1; let M1Encoding12{11-00} = Enc12; @@ -41,6 +37,27 @@ class MClassSysReg<bits<1> UniqMask1, code Requires = [{ {} }]; } +def MClassSysRegsList : GenericTable { + let FilterClass = "MClassSysReg"; + let Fields = ["Name", "M1Encoding12", "M2M3Encoding8", "Encoding", + "Requires"]; +} + +def lookupMClassSysRegByName : SearchIndex { + let Table = MClassSysRegsList; + let Key = ["Name"]; +} + +def lookupMClassSysRegByM1Encoding12 : SearchIndex { + let Table = MClassSysRegsList; + let Key = ["M1Encoding12"]; +} + +def lookupMClassSysRegByM2M3Encoding8 : SearchIndex { + let Table = MClassSysRegsList; + let Key = ["M2M3Encoding8"]; +} + // [|i|e|x]apsr_nzcvq has alias [|i|e|x]apsr. // Mask1 Mask2 Mask3 Enc12, Name let Requires = [{ {ARM::FeatureDSP} }] in { @@ -127,15 +144,29 @@ def : MClassSysReg<0, 0, 1, 0x8a7, "pac_key_u_3_ns">; // Banked Registers // -class BankedReg<string name, bits<8> enc> - : SearchableTable { +class BankedReg<string name, bits<8> enc> { string Name; bits<8> Encoding; let Name = name; let Encoding = enc; - let SearchableFields = ["Name", "Encoding"]; } +def BankedRegsList : GenericTable { + let FilterClass = "BankedReg"; + let Fields = ["Name", "Encoding"]; +} + +def lookupBankedRegByName : SearchIndex { + let Table = BankedRegsList; + let Key = ["Name"]; +} + +def lookupBankedRegByEncoding : SearchIndex { + let Table = BankedRegsList; + let Key = ["Encoding"]; +} + + // The values here come from B9.2.3 of the ARM ARM, where bits 4-0 are SysM // and bit 5 is R. def : BankedReg<"r8_usr", 0x00>; diff --git a/llvm/lib/Target/ARM/Utils/ARMBaseInfo.cpp b/llvm/lib/Target/ARM/Utils/ARMBaseInfo.cpp index 494c67d..e76a70b 100644 --- a/llvm/lib/Target/ARM/Utils/ARMBaseInfo.cpp +++ b/llvm/lib/Target/ARM/Utils/ARMBaseInfo.cpp @@ -62,13 +62,13 @@ const MClassSysReg *lookupMClassSysRegBy8bitSYSmValue(unsigned SYSm) { return ARMSysReg::lookupMClassSysRegByM2M3Encoding8((1<<8)|(SYSm & 0xFF)); } -#define GET_MCLASSSYSREG_IMPL +#define GET_MClassSysRegsList_IMPL #include "ARMGenSystemRegister.inc" } // end namespace ARMSysReg namespace ARMBankedReg { -#define GET_BANKEDREG_IMPL +#define GET_BankedRegsList_IMPL #include "ARMGenSystemRegister.inc" } // end namespce ARMSysReg } // end namespace llvm diff --git a/llvm/lib/Target/ARM/Utils/ARMBaseInfo.h b/llvm/lib/Target/ARM/Utils/ARMBaseInfo.h index 5562572..dc4f811 100644 --- a/llvm/lib/Target/ARM/Utils/ARMBaseInfo.h +++ b/llvm/lib/Target/ARM/Utils/ARMBaseInfo.h @@ -206,8 +206,8 @@ namespace ARMSysReg { } }; - #define GET_MCLASSSYSREG_DECL - #include "ARMGenSystemRegister.inc" +#define GET_MClassSysRegsList_DECL +#include "ARMGenSystemRegister.inc" // lookup system register using 12-bit SYSm value. // Note: the search is uniqued using M1 mask @@ -228,8 +228,8 @@ namespace ARMBankedReg { const char *Name; uint16_t Encoding; }; - #define GET_BANKEDREG_DECL - #include "ARMGenSystemRegister.inc" +#define GET_BankedRegsList_DECL +#include "ARMGenSystemRegister.inc" } // end namespace ARMBankedReg } // end namespace llvm diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index 5d865a3..62b5b70 100644 --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -42,8 +42,10 @@ def FloatTy : DXILOpParamType; def DoubleTy : DXILOpParamType; def ResRetHalfTy : DXILOpParamType; def ResRetFloatTy : DXILOpParamType; +def ResRetDoubleTy : DXILOpParamType; def ResRetInt16Ty : DXILOpParamType; def ResRetInt32Ty : DXILOpParamType; +def ResRetInt64Ty : DXILOpParamType; def HandleTy : DXILOpParamType; def ResBindTy : DXILOpParamType; def ResPropsTy : DXILOpParamType; @@ -890,6 +892,23 @@ def SplitDouble : DXILOp<102, splitDouble> { let attributes = [Attributes<DXIL1_0, [ReadNone]>]; } +def RawBufferLoad : DXILOp<139, rawBufferLoad> { + let Doc = "reads from a raw buffer and structured buffer"; + // Handle, Coord0, Coord1, Mask, Alignment + let arguments = [HandleTy, Int32Ty, Int32Ty, Int8Ty, Int32Ty]; + let result = OverloadTy; + let overloads = [ + Overloads<DXIL1_2, + [ResRetHalfTy, ResRetFloatTy, ResRetInt16Ty, ResRetInt32Ty]>, + Overloads<DXIL1_3, + [ + ResRetHalfTy, ResRetFloatTy, ResRetDoubleTy, ResRetInt16Ty, + ResRetInt32Ty, ResRetInt64Ty + ]> + ]; + let stages = [Stages<DXIL1_2, [all_stages]>]; +} + def Dot4AddI8Packed : DXILOp<163, dot4AddPacked> { let Doc = "signed dot product of 4 x i8 vectors packed into i32, with " "accumulate to i32"; diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp index 5d5bb3e..9f88ccd 100644 --- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp +++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp @@ -263,10 +263,14 @@ static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx, return getResRetType(Type::getHalfTy(Ctx)); case OpParamType::ResRetFloatTy: return getResRetType(Type::getFloatTy(Ctx)); + case OpParamType::ResRetDoubleTy: + return getResRetType(Type::getDoubleTy(Ctx)); case OpParamType::ResRetInt16Ty: return getResRetType(Type::getInt16Ty(Ctx)); case OpParamType::ResRetInt32Ty: return getResRetType(Type::getInt32Ty(Ctx)); + case OpParamType::ResRetInt64Ty: + return getResRetType(Type::getInt64Ty(Ctx)); case OpParamType::HandleTy: return getHandleType(Ctx); case OpParamType::ResBindTy: diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp index 4e01dd1..f43815b 100644 --- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp +++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp @@ -415,8 +415,16 @@ public: } } - OldResult = cast<Instruction>( - IRB.CreateExtractValue(Op, 0, OldResult->getName())); + if (OldResult->use_empty()) { + // Only the check bit was used, so we're done here. + OldResult->eraseFromParent(); + return Error::success(); + } + + assert(OldResult->hasOneUse() && + isa<ExtractValueInst>(*OldResult->user_begin()) && + "Expected only use to be extract of first element"); + OldResult = cast<Instruction>(*OldResult->user_begin()); OldTy = ST->getElementType(0); } @@ -534,6 +542,48 @@ public: }); } + [[nodiscard]] bool lowerRawBufferLoad(Function &F) { + Triple TT(Triple(M.getTargetTriple())); + VersionTuple DXILVersion = TT.getDXILVersion(); + const DataLayout &DL = F.getDataLayout(); + IRBuilder<> &IRB = OpBuilder.getIRB(); + Type *Int8Ty = IRB.getInt8Ty(); + Type *Int32Ty = IRB.getInt32Ty(); + + return replaceFunction(F, [&](CallInst *CI) -> Error { + IRB.SetInsertPoint(CI); + + Type *OldTy = cast<StructType>(CI->getType())->getElementType(0); + Type *ScalarTy = OldTy->getScalarType(); + Type *NewRetTy = OpBuilder.getResRetType(ScalarTy); + + Value *Handle = + createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType()); + Value *Index0 = CI->getArgOperand(1); + Value *Index1 = CI->getArgOperand(2); + uint64_t NumElements = + DL.getTypeSizeInBits(OldTy) / DL.getTypeSizeInBits(ScalarTy); + Value *Mask = ConstantInt::get(Int8Ty, ~(~0U << NumElements)); + Value *Align = + ConstantInt::get(Int32Ty, DL.getPrefTypeAlign(ScalarTy).value()); + + Expected<CallInst *> OpCall = + DXILVersion >= VersionTuple(1, 2) + ? OpBuilder.tryCreateOp(OpCode::RawBufferLoad, + {Handle, Index0, Index1, Mask, Align}, + CI->getName(), NewRetTy) + : OpBuilder.tryCreateOp(OpCode::BufferLoad, + {Handle, Index0, Index1}, CI->getName(), + NewRetTy); + if (Error E = OpCall.takeError()) + return E; + if (Error E = replaceResRetUses(CI, *OpCall, /*HasCheckBit=*/true)) + return E; + + return Error::success(); + }); + } + [[nodiscard]] bool lowerUpdateCounter(Function &F) { IRBuilder<> &IRB = OpBuilder.getIRB(); Type *Int32Ty = IRB.getInt32Ty(); @@ -723,14 +773,14 @@ public: HasErrors |= lowerGetPointer(F); break; case Intrinsic::dx_resource_load_typedbuffer: - HasErrors |= lowerTypedBufferLoad(F, /*HasCheckBit=*/false); - break; - case Intrinsic::dx_resource_loadchecked_typedbuffer: HasErrors |= lowerTypedBufferLoad(F, /*HasCheckBit=*/true); break; case Intrinsic::dx_resource_store_typedbuffer: HasErrors |= lowerTypedBufferStore(F); break; + case Intrinsic::dx_resource_load_rawbuffer: + HasErrors |= lowerRawBufferLoad(F); + break; case Intrinsic::dx_resource_updatecounter: HasErrors |= lowerUpdateCounter(F); break; diff --git a/llvm/lib/Target/DirectX/DXILResourceAccess.cpp b/llvm/lib/Target/DirectX/DXILResourceAccess.cpp index 1ff8f09..8376249 100644 --- a/llvm/lib/Target/DirectX/DXILResourceAccess.cpp +++ b/llvm/lib/Target/DirectX/DXILResourceAccess.cpp @@ -30,6 +30,9 @@ static void replaceTypedBufferAccess(IntrinsicInst *II, "Unexpected typed buffer type"); Type *ContainedType = HandleType->getTypeParameter(0); + Type *LoadType = + StructType::get(ContainedType, Type::getInt1Ty(II->getContext())); + // We need the size of an element in bytes so that we can calculate the offset // in elements given a total offset in bytes later. Type *ScalarType = ContainedType->getScalarType(); @@ -81,13 +84,15 @@ static void replaceTypedBufferAccess(IntrinsicInst *II, // We're storing a scalar, so we need to load the current value and only // replace the relevant part. auto *Load = Builder.CreateIntrinsic( - ContainedType, Intrinsic::dx_resource_load_typedbuffer, + LoadType, Intrinsic::dx_resource_load_typedbuffer, {II->getOperand(0), II->getOperand(1)}); + auto *Struct = Builder.CreateExtractValue(Load, {0}); + // If we have an offset from seeing a GEP earlier, use it. Value *IndexOp = Current.Index ? Current.Index : ConstantInt::get(Builder.getInt32Ty(), 0); - V = Builder.CreateInsertElement(Load, V, IndexOp); + V = Builder.CreateInsertElement(Struct, V, IndexOp); } else { llvm_unreachable("Store to typed resource has invalid type"); } @@ -101,8 +106,10 @@ static void replaceTypedBufferAccess(IntrinsicInst *II, } else if (auto *LI = dyn_cast<LoadInst>(Current.Access)) { IRBuilder<> Builder(LI); Value *V = Builder.CreateIntrinsic( - ContainedType, Intrinsic::dx_resource_load_typedbuffer, + LoadType, Intrinsic::dx_resource_load_typedbuffer, {II->getOperand(0), II->getOperand(1)}); + V = Builder.CreateExtractValue(V, {0}); + if (Current.Index) V = Builder.CreateExtractElement(V, Current.Index); diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp index ad079f4..5afe6b2 100644 --- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp +++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp @@ -15,14 +15,12 @@ #include "llvm/ADT/Twine.h" #include "llvm/Analysis/DXILMetadataAnalysis.h" #include "llvm/Analysis/DXILResource.h" -#include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/DiagnosticPrinter.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/LLVMContext.h" -#include "llvm/IR/MDBuilder.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/InitializePasses.h" @@ -302,39 +300,6 @@ static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD, return constructEntryMetadata(nullptr, nullptr, RMD, Properties, Ctx); } -// TODO: We might need to refactor this to be more generic, -// in case we need more metadata to be replaced. -static void translateBranchMetadata(Module &M) { - for (Function &F : M) { - for (BasicBlock &BB : F) { - Instruction *BBTerminatorInst = BB.getTerminator(); - - MDNode *HlslControlFlowMD = - BBTerminatorInst->getMetadata("hlsl.controlflow.hint"); - - if (!HlslControlFlowMD) - continue; - - assert(HlslControlFlowMD->getNumOperands() == 2 && - "invalid operands for hlsl.controlflow.hint"); - - MDBuilder MDHelper(M.getContext()); - ConstantInt *Op1 = - mdconst::extract<ConstantInt>(HlslControlFlowMD->getOperand(1)); - - SmallVector<llvm::Metadata *, 2> Vals( - ArrayRef<Metadata *>{MDHelper.createString("dx.controlflow.hints"), - MDHelper.createConstant(Op1)}); - - MDNode *MDNode = llvm::MDNode::get(M.getContext(), Vals); - - BBTerminatorInst->setMetadata("dx.controlflow.hints", MDNode); - BBTerminatorInst->setMetadata("hlsl.controlflow.hint", nullptr); - } - F.clearMetadata(); - } -} - static void translateMetadata(Module &M, DXILBindingMap &DBM, DXILResourceTypeMap &DRTM, const Resources &MDResources, @@ -407,7 +372,6 @@ PreservedAnalyses DXILTranslateMetadata::run(Module &M, const dxil::ModuleMetadataInfo MMDI = MAM.getResult<DXILMetadataAnalysis>(M); translateMetadata(M, DBM, DRTM, MDResources, ShaderFlags, MMDI); - translateBranchMetadata(M); return PreservedAnalyses::all(); } @@ -445,7 +409,6 @@ public: getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata(); translateMetadata(M, DBM, DRTM, MDResources, ShaderFlags, MMDI); - translateBranchMetadata(M); return true; } }; diff --git a/llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp b/llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp index 45aadac..be68d46 100644 --- a/llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp +++ b/llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp @@ -749,8 +749,8 @@ uint64_t DXILBitcodeWriter::getOptimizationFlags(const Value *V) { if (PEO->isExact()) Flags |= 1 << bitc::PEO_EXACT; } else if (const auto *FPMO = dyn_cast<FPMathOperator>(V)) { - if (FPMO->hasAllowReassoc()) - Flags |= bitc::AllowReassoc; + if (FPMO->hasAllowReassoc() || FPMO->hasAllowContract()) + Flags |= bitc::UnsafeAlgebra; if (FPMO->hasNoNaNs()) Flags |= bitc::NoNaNs; if (FPMO->hasNoInfs()) @@ -759,10 +759,6 @@ uint64_t DXILBitcodeWriter::getOptimizationFlags(const Value *V) { Flags |= bitc::NoSignedZeros; if (FPMO->hasAllowReciprocal()) Flags |= bitc::AllowReciprocal; - if (FPMO->hasAllowContract()) - Flags |= bitc::AllowContract; - if (FPMO->hasApproxFunc()) - Flags |= bitc::ApproxFunc; } return Flags; diff --git a/llvm/lib/Target/Hexagon/HexagonLoopIdiomRecognition.cpp b/llvm/lib/Target/Hexagon/HexagonLoopIdiomRecognition.cpp index 46a8ab3..991ee5b 100644 --- a/llvm/lib/Target/Hexagon/HexagonLoopIdiomRecognition.cpp +++ b/llvm/lib/Target/Hexagon/HexagonLoopIdiomRecognition.cpp @@ -1796,6 +1796,8 @@ bool PolynomialMultiplyRecognize::recognize() { IterCount = CV->getValue()->getZExtValue() + 1; Value *CIV = getCountIV(LoopB); + if (CIV == nullptr) + return false; ParsedValues PV; Simplifier PreSimp; PV.IterCount = IterCount; diff --git a/llvm/lib/Target/LoongArch/LoongArchInstrInfo.cpp b/llvm/lib/Target/LoongArch/LoongArchInstrInfo.cpp index 54aeda2..32bc8bb 100644 --- a/llvm/lib/Target/LoongArch/LoongArchInstrInfo.cpp +++ b/llvm/lib/Target/LoongArch/LoongArchInstrInfo.cpp @@ -154,6 +154,9 @@ void LoongArchInstrInfo::loadRegFromStackSlot(MachineBasicBlock &MBB, Register VReg) const { MachineFunction *MF = MBB.getParent(); MachineFrameInfo &MFI = MF->getFrameInfo(); + DebugLoc DL; + if (I != MBB.end()) + DL = I->getDebugLoc(); unsigned Opcode; if (LoongArch::GPRRegClass.hasSubClassEq(RC)) @@ -177,7 +180,7 @@ void LoongArchInstrInfo::loadRegFromStackSlot(MachineBasicBlock &MBB, MachinePointerInfo::getFixedStack(*MF, FI), MachineMemOperand::MOLoad, MFI.getObjectSize(FI), MFI.getObjectAlign(FI)); - BuildMI(MBB, I, DebugLoc(), get(Opcode), DstReg) + BuildMI(MBB, I, DL, get(Opcode), DstReg) .addFrameIndex(FI) .addImm(0) .addMemOperand(MMO); diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp index 65e1893..d34f45f 100644 --- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp +++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp @@ -14,7 +14,7 @@ #include "NVPTX.h" #include "NVPTXUtilities.h" #include "llvm/ADT/StringRef.h" -#include "llvm/IR/NVVMIntrinsicFlags.h" +#include "llvm/IR/NVVMIntrinsicUtils.h" #include "llvm/MC/MCExpr.h" #include "llvm/MC/MCInst.h" #include "llvm/MC/MCInstrInfo.h" diff --git a/llvm/lib/Target/NVPTX/NVPTXCtorDtorLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXCtorDtorLowering.cpp index f940dc0..c03ef8d 100644 --- a/llvm/lib/Target/NVPTX/NVPTXCtorDtorLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXCtorDtorLowering.cpp @@ -14,6 +14,7 @@ #include "MCTargetDesc/NVPTXBaseInfo.h" #include "NVPTX.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/IR/CallingConv.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" #include "llvm/IR/GlobalVariable.h" @@ -49,39 +50,34 @@ static std::string getHash(StringRef Str) { return llvm::utohexstr(Hash.low(), /*LowerCase=*/true); } -static void addKernelMetadata(Module &M, GlobalValue *GV) { +static void addKernelMetadata(Module &M, Function *F) { llvm::LLVMContext &Ctx = M.getContext(); // Get "nvvm.annotations" metadata node. llvm::NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations"); - llvm::Metadata *KernelMDVals[] = { - llvm::ConstantAsMetadata::get(GV), llvm::MDString::get(Ctx, "kernel"), - llvm::ConstantAsMetadata::get( - llvm::ConstantInt::get(llvm::Type::getInt32Ty(Ctx), 1))}; - // This kernel is only to be called single-threaded. llvm::Metadata *ThreadXMDVals[] = { - llvm::ConstantAsMetadata::get(GV), llvm::MDString::get(Ctx, "maxntidx"), + llvm::ConstantAsMetadata::get(F), llvm::MDString::get(Ctx, "maxntidx"), llvm::ConstantAsMetadata::get( llvm::ConstantInt::get(llvm::Type::getInt32Ty(Ctx), 1))}; llvm::Metadata *ThreadYMDVals[] = { - llvm::ConstantAsMetadata::get(GV), llvm::MDString::get(Ctx, "maxntidy"), + llvm::ConstantAsMetadata::get(F), llvm::MDString::get(Ctx, "maxntidy"), llvm::ConstantAsMetadata::get( llvm::ConstantInt::get(llvm::Type::getInt32Ty(Ctx), 1))}; llvm::Metadata *ThreadZMDVals[] = { - llvm::ConstantAsMetadata::get(GV), llvm::MDString::get(Ctx, "maxntidz"), + llvm::ConstantAsMetadata::get(F), llvm::MDString::get(Ctx, "maxntidz"), llvm::ConstantAsMetadata::get( llvm::ConstantInt::get(llvm::Type::getInt32Ty(Ctx), 1))}; llvm::Metadata *BlockMDVals[] = { - llvm::ConstantAsMetadata::get(GV), + llvm::ConstantAsMetadata::get(F), llvm::MDString::get(Ctx, "maxclusterrank"), llvm::ConstantAsMetadata::get( llvm::ConstantInt::get(llvm::Type::getInt32Ty(Ctx), 1))}; // Append metadata to nvvm.annotations. - MD->addOperand(llvm::MDNode::get(Ctx, KernelMDVals)); + F->setCallingConv(CallingConv::PTX_Kernel); MD->addOperand(llvm::MDNode::get(Ctx, ThreadXMDVals)); MD->addOperand(llvm::MDNode::get(Ctx, ThreadYMDVals)); MD->addOperand(llvm::MDNode::get(Ctx, ThreadZMDVals)); diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp index c51729e..ef97844 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp @@ -14,10 +14,11 @@ #include "NVPTXUtilities.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/CodeGen/ISDOpcodes.h" +#include "llvm/CodeGen/SelectionDAGNodes.h" #include "llvm/IR/GlobalValue.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicsNVPTX.h" -#include "llvm/IR/NVVMIntrinsicFlags.h" +#include "llvm/IR/NVVMIntrinsicUtils.h" #include "llvm/Support/AtomicOrdering.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/ErrorHandling.h" @@ -2449,6 +2450,11 @@ bool NVPTXDAGToDAGISel::tryBFE(SDNode *N) { return true; } +static inline bool isAddLike(const SDValue V) { + return V.getOpcode() == ISD::ADD || + (V->getOpcode() == ISD::OR && V->getFlags().hasDisjoint()); +} + // SelectDirectAddr - Match a direct address for DAG. // A direct address could be a globaladdress or externalsymbol. bool NVPTXDAGToDAGISel::SelectDirectAddr(SDValue N, SDValue &Address) { @@ -2475,7 +2481,7 @@ bool NVPTXDAGToDAGISel::SelectDirectAddr(SDValue N, SDValue &Address) { // symbol+offset bool NVPTXDAGToDAGISel::SelectADDRsi_imp( SDNode *OpNode, SDValue Addr, SDValue &Base, SDValue &Offset, MVT mvt) { - if (Addr.getOpcode() == ISD::ADD) { + if (isAddLike(Addr)) { if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(Addr.getOperand(1))) { SDValue base = Addr.getOperand(0); if (SelectDirectAddr(base, Base)) { @@ -2512,7 +2518,7 @@ bool NVPTXDAGToDAGISel::SelectADDRri_imp( Addr.getOpcode() == ISD::TargetGlobalAddress) return false; // direct calls. - if (Addr.getOpcode() == ISD::ADD) { + if (isAddLike(Addr)) { if (SelectDirectAddr(Addr.getOperand(0), Addr)) { return false; } diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h index 4a98fe2..c9b7e87 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h @@ -261,6 +261,9 @@ public: return true; } + bool isFAbsFree(EVT VT) const override { return true; } + bool isFNegFree(EVT VT) const override { return true; } + private: const NVPTXSubtarget &STI; // cache the subtarget here SDValue getParamSymbol(SelectionDAG &DAG, int idx, EVT) const; diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp b/llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp index 42043ad..74ce6a9 100644 --- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp @@ -34,19 +34,18 @@ void NVPTXSubtarget::anchor() {} NVPTXSubtarget &NVPTXSubtarget::initializeSubtargetDependencies(StringRef CPU, StringRef FS) { - // Provide the default CPU if we don't have one. - TargetName = std::string(CPU.empty() ? "sm_30" : CPU); + TargetName = std::string(CPU); - ParseSubtargetFeatures(TargetName, /*TuneCPU*/ TargetName, FS); + ParseSubtargetFeatures(getTargetName(), /*TuneCPU=*/getTargetName(), FS); - // Re-map SM version numbers, SmVersion carries the regular SMs which do - // have relative order, while FullSmVersion allows distinguishing sm_90 from - // sm_90a, which would *not* be a subset of sm_91. - SmVersion = getSmVersion(); + // Re-map SM version numbers, SmVersion carries the regular SMs which do + // have relative order, while FullSmVersion allows distinguishing sm_90 from + // sm_90a, which would *not* be a subset of sm_91. + SmVersion = getSmVersion(); - // Set default to PTX 6.0 (CUDA 9.0) - if (PTXVersion == 0) { - PTXVersion = 60; + // Set default to PTX 6.0 (CUDA 9.0) + if (PTXVersion == 0) { + PTXVersion = 60; } return *this; diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h index 7555a23..bbc1cca 100644 --- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h +++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h @@ -111,7 +111,12 @@ public: // - 0 represents base GPU model, // - non-zero value identifies particular architecture-accelerated variant. bool hasAAFeatures() const { return getFullSmVersion() % 10; } - std::string getTargetName() const { return TargetName; } + + // If the user did not provide a target we default to the `sm_30` target. + std::string getTargetName() const { + return TargetName.empty() ? "sm_30" : TargetName; + } + bool hasTargetName() const { return !TargetName.empty(); } // Get maximum value of required alignments among the supported data types. // From the PTX ISA doc, section 8.2.3: diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp index b3b2880..6d4b82a 100644 --- a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp @@ -255,7 +255,10 @@ void NVPTXTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) { PB.registerPipelineStartEPCallback( [this](ModulePassManager &PM, OptimizationLevel Level) { FunctionPassManager FPM; - FPM.addPass(NVVMReflectPass(Subtarget.getSmVersion())); + // We do not want to fold out calls to nvvm.reflect early if the user + // has not provided a target architecture just yet. + if (Subtarget.hasTargetName()) + FPM.addPass(NVVMReflectPass(Subtarget.getSmVersion())); // Note: NVVMIntrRangePass was causing numerical discrepancies at one // point, if issues crop up, consider disabling. FPM.addPass(NVVMIntrRangePass()); diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp index 98bffd9..0f2bec7 100644 --- a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp @@ -311,11 +311,13 @@ std::optional<unsigned> getMaxNReg(const Function &F) { } bool isKernelFunction(const Function &F) { + if (F.getCallingConv() == CallingConv::PTX_Kernel) + return true; + if (const auto X = findOneNVVMAnnotation(&F, "kernel")) return (*X == 1); - // There is no NVVM metadata, check the calling convention - return F.getCallingConv() == CallingConv::PTX_Kernel; + return false; } MaybeAlign getAlign(const Function &F, unsigned Index) { diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp index 56525a1..0cd584c 100644 --- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp +++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp @@ -21,6 +21,7 @@ #include "NVPTX.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/ConstantFolding.h" +#include "llvm/CodeGen/CommandFlags.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" @@ -219,7 +220,13 @@ bool NVVMReflect::runOnFunction(Function &F) { return runNVVMReflect(F, SmVersion); } -NVVMReflectPass::NVVMReflectPass() : NVVMReflectPass(0) {} +NVVMReflectPass::NVVMReflectPass() { + // Get the CPU string from the command line if not provided. + std::string MCPU = codegen::getMCPU(); + StringRef SM = MCPU; + if (!SM.consume_front("sm_") || SM.consumeInteger(10, SmVersion)) + SmVersion = 0; +} PreservedAnalyses NVVMReflectPass::run(Function &F, FunctionAnalysisManager &AM) { diff --git a/llvm/lib/Target/RISCV/CMakeLists.txt b/llvm/lib/Target/RISCV/CMakeLists.txt index 4466164..98d3615 100644 --- a/llvm/lib/Target/RISCV/CMakeLists.txt +++ b/llvm/lib/Target/RISCV/CMakeLists.txt @@ -15,6 +15,7 @@ tablegen(LLVM RISCVGenRegisterBank.inc -gen-register-bank) tablegen(LLVM RISCVGenRegisterInfo.inc -gen-register-info) tablegen(LLVM RISCVGenSearchableTables.inc -gen-searchable-tables) tablegen(LLVM RISCVGenSubtargetInfo.inc -gen-subtarget) +tablegen(LLVM RISCVGenExegesis.inc -gen-exegesis) set(LLVM_TARGET_DEFINITIONS RISCVGISel.td) tablegen(LLVM RISCVGenGlobalISel.inc -gen-global-isel) diff --git a/llvm/lib/Target/RISCV/Disassembler/RISCVDisassembler.cpp b/llvm/lib/Target/RISCV/Disassembler/RISCVDisassembler.cpp index 3012283..a490910 100644 --- a/llvm/lib/Target/RISCV/Disassembler/RISCVDisassembler.cpp +++ b/llvm/lib/Target/RISCV/Disassembler/RISCVDisassembler.cpp @@ -698,6 +698,8 @@ DecodeStatus RISCVDisassembler::getInstruction32(MCInst &MI, uint64_t &Size, TRY_TO_DECODE_FEATURE( RISCV::FeatureVendorXqcicli, DecoderTableXqcicli32, "Qualcomm uC Conditional Load Immediate custom opcode table"); + TRY_TO_DECODE_FEATURE(RISCV::FeatureVendorXqcicm, DecoderTableXqcicm32, + "Qualcomm uC Conditional Move custom opcode table"); TRY_TO_DECODE(true, DecoderTable32, "RISCV32 table"); return MCDisassembler::Fail; @@ -727,6 +729,9 @@ DecodeStatus RISCVDisassembler::getInstruction16(MCInst &MI, uint64_t &Size, TRY_TO_DECODE_FEATURE( RISCV::FeatureVendorXqciac, DecoderTableXqciac16, "Qualcomm uC Load-Store Address Calculation custom 16bit opcode table"); + TRY_TO_DECODE_FEATURE( + RISCV::FeatureVendorXqcicm, DecoderTableXqcicm16, + "Qualcomm uC Conditional Move custom 16bit opcode table"); TRY_TO_DECODE_AND_ADD_SP(STI.hasFeature(RISCV::FeatureVendorXwchc), DecoderTableXwchc16, "WCH QingKe XW custom opcode table"); diff --git a/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp b/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp index ef85057..3f1539d 100644 --- a/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp +++ b/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp @@ -80,7 +80,6 @@ private: bool selectFPCompare(MachineInstr &MI, MachineIRBuilder &MIB) const; void emitFence(AtomicOrdering FenceOrdering, SyncScope::ID FenceSSID, MachineIRBuilder &MIB) const; - bool selectMergeValues(MachineInstr &MI, MachineIRBuilder &MIB) const; bool selectUnmergeValues(MachineInstr &MI, MachineIRBuilder &MIB) const; ComplexRendererFns selectShiftMask(MachineOperand &Root, @@ -732,8 +731,6 @@ bool RISCVInstructionSelector::select(MachineInstr &MI) { } case TargetOpcode::G_IMPLICIT_DEF: return selectImplicitDef(MI, MIB); - case TargetOpcode::G_MERGE_VALUES: - return selectMergeValues(MI, MIB); case TargetOpcode::G_UNMERGE_VALUES: return selectUnmergeValues(MI, MIB); default: @@ -741,26 +738,13 @@ bool RISCVInstructionSelector::select(MachineInstr &MI) { } } -bool RISCVInstructionSelector::selectMergeValues(MachineInstr &MI, - MachineIRBuilder &MIB) const { - assert(MI.getOpcode() == TargetOpcode::G_MERGE_VALUES); - - // Build a F64 Pair from operands - if (MI.getNumOperands() != 3) - return false; - Register Dst = MI.getOperand(0).getReg(); - Register Lo = MI.getOperand(1).getReg(); - Register Hi = MI.getOperand(2).getReg(); - if (!isRegInFprb(Dst) || !isRegInGprb(Lo) || !isRegInGprb(Hi)) - return false; - MI.setDesc(TII.get(RISCV::BuildPairF64Pseudo)); - return constrainSelectedInstRegOperands(MI, TII, TRI, RBI); -} - bool RISCVInstructionSelector::selectUnmergeValues( MachineInstr &MI, MachineIRBuilder &MIB) const { assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES); + if (!Subtarget->hasStdExtZfa()) + return false; + // Split F64 Src into two s32 parts if (MI.getNumOperands() != 3) return false; @@ -769,8 +753,17 @@ bool RISCVInstructionSelector::selectUnmergeValues( Register Hi = MI.getOperand(1).getReg(); if (!isRegInFprb(Src) || !isRegInGprb(Lo) || !isRegInGprb(Hi)) return false; - MI.setDesc(TII.get(RISCV::SplitF64Pseudo)); - return constrainSelectedInstRegOperands(MI, TII, TRI, RBI); + + MachineInstr *ExtractLo = MIB.buildInstr(RISCV::FMV_X_W_FPR64, {Lo}, {Src}); + if (!constrainSelectedInstRegOperands(*ExtractLo, TII, TRI, RBI)) + return false; + + MachineInstr *ExtractHi = MIB.buildInstr(RISCV::FMVH_X_D, {Hi}, {Src}); + if (!constrainSelectedInstRegOperands(*ExtractHi, TII, TRI, RBI)) + return false; + + MI.eraseFromParent(); + return true; } bool RISCVInstructionSelector::replacePtrWithInt(MachineOperand &Op, diff --git a/llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp b/llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp index 8284737..6f06459 100644 --- a/llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp +++ b/llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp @@ -21,6 +21,7 @@ #include "llvm/CodeGen/MachineConstantPool.h" #include "llvm/CodeGen/MachineJumpTableInfo.h" #include "llvm/CodeGen/MachineMemOperand.h" +#include "llvm/CodeGen/MachineOperand.h" #include "llvm/CodeGen/MachineRegisterInfo.h" #include "llvm/CodeGen/TargetOpcodes.h" #include "llvm/CodeGen/ValueTypes.h" @@ -132,7 +133,14 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST) auto PtrVecTys = {nxv1p0, nxv2p0, nxv4p0, nxv8p0, nxv16p0}; - getActionDefinitionsBuilder({G_ADD, G_SUB, G_AND, G_OR, G_XOR}) + getActionDefinitionsBuilder({G_ADD, G_SUB}) + .legalFor({sXLen}) + .legalIf(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST)) + .customFor(ST.is64Bit(), {s32}) + .widenScalarToNextPow2(0) + .clampScalar(0, sXLen, sXLen); + + getActionDefinitionsBuilder({G_AND, G_OR, G_XOR}) .legalFor({sXLen}) .legalIf(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST)) .widenScalarToNextPow2(0) @@ -1330,6 +1338,24 @@ bool RISCVLegalizerInfo::legalizeCustom( return true; return Helper.lowerConstant(MI); } + case TargetOpcode::G_SUB: + case TargetOpcode::G_ADD: { + Helper.Observer.changingInstr(MI); + Helper.widenScalarSrc(MI, sXLen, 1, TargetOpcode::G_ANYEXT); + Helper.widenScalarSrc(MI, sXLen, 2, TargetOpcode::G_ANYEXT); + + Register DstALU = MRI.createGenericVirtualRegister(sXLen); + + MachineOperand &MO = MI.getOperand(0); + MIRBuilder.setInsertPt(MIRBuilder.getMBB(), ++MIRBuilder.getInsertPt()); + auto DstSext = MIRBuilder.buildSExtInReg(sXLen, DstALU, 32); + + MIRBuilder.buildInstr(TargetOpcode::G_TRUNC, {MO}, {DstSext}); + MO.setReg(DstALU); + + Helper.Observer.changedInstr(MI); + return true; + } case TargetOpcode::G_SEXT_INREG: { LLT DstTy = MRI.getType(MI.getOperand(0).getReg()); int64_t SizeInBits = MI.getOperand(2).getImm(); diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVAsmBackend.cpp b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVAsmBackend.cpp index eab4a5e..0cb1ef0 100644 --- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVAsmBackend.cpp +++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVAsmBackend.cpp @@ -38,9 +38,12 @@ std::optional<MCFixupKind> RISCVAsmBackend::getFixupKind(StringRef Name) const { if (STI.getTargetTriple().isOSBinFormatELF()) { unsigned Type; Type = llvm::StringSwitch<unsigned>(Name) -#define ELF_RELOC(X, Y) .Case(#X, Y) +#define ELF_RELOC(NAME, ID) .Case(#NAME, ID) #include "llvm/BinaryFormat/ELFRelocs/RISCV.def" #undef ELF_RELOC +#define ELF_RISCV_NONSTANDARD_RELOC(_VENDOR, NAME, ID) .Case(#NAME, ID) +#include "llvm/BinaryFormat/ELFRelocs/RISCV_nonstandard.def" +#undef ELF_RISCV_NONSTANDARD_RELOC .Case("BFD_RELOC_NONE", ELF::R_RISCV_NONE) .Case("BFD_RELOC_32", ELF::R_RISCV_32) .Case("BFD_RELOC_64", ELF::R_RISCV_64) diff --git a/llvm/lib/Target/RISCV/RISCV.td b/llvm/lib/Target/RISCV/RISCV.td index 9631241..4e0c64a 100644 --- a/llvm/lib/Target/RISCV/RISCV.td +++ b/llvm/lib/Target/RISCV/RISCV.td @@ -64,6 +64,12 @@ include "RISCVSchedXiangShanNanHu.td" include "RISCVProcessors.td" //===----------------------------------------------------------------------===// +// Pfm Counters +//===----------------------------------------------------------------------===// + +include "RISCVPfmCounters.td" + +//===----------------------------------------------------------------------===// // Define the RISC-V target. //===----------------------------------------------------------------------===// diff --git a/llvm/lib/Target/RISCV/RISCVFeatures.td b/llvm/lib/Target/RISCV/RISCVFeatures.td index 0074be3..01bc5387 100644 --- a/llvm/lib/Target/RISCV/RISCVFeatures.td +++ b/llvm/lib/Target/RISCV/RISCVFeatures.td @@ -1294,6 +1294,14 @@ def HasVendorXqcicli AssemblerPredicate<(all_of FeatureVendorXqcicli), "'Xqcicli' (Qualcomm uC Conditional Load Immediate Extension)">; +def FeatureVendorXqcicm + : RISCVExperimentalExtension<0, 2, "Qualcomm uC Conditional Move Extension", + [FeatureStdExtZca]>; +def HasVendorXqcicm + : Predicate<"Subtarget->hasVendorXqcicm()">, + AssemblerPredicate<(all_of FeatureVendorXqcicm), + "'Xqcicm' (Qualcomm uC Conditional Move Extension)">; + //===----------------------------------------------------------------------===// // LLVM specific features and extensions //===----------------------------------------------------------------------===// diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 4a0304f..6c58989 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -18385,7 +18385,7 @@ bool RISCVTargetLowering::isDesirableToCommuteWithShift( auto *C2 = dyn_cast<ConstantSDNode>(N->getOperand(1)); // Bail if we might break a sh{1,2,3}add pattern. - if (Subtarget.hasStdExtZba() && C2->getZExtValue() >= 1 && + if (Subtarget.hasStdExtZba() && C2 && C2->getZExtValue() >= 1 && C2->getZExtValue() <= 3 && N->hasOneUse() && N->user_begin()->getOpcode() == ISD::ADD && !isUsedByLdSt(*N->user_begin(), nullptr) && @@ -20273,13 +20273,11 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI, for (auto &Reg : RegsToPass) Ops.push_back(DAG.getRegister(Reg.first, Reg.second.getValueType())); - if (!IsTailCall) { - // Add a register mask operand representing the call-preserved registers. - const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo(); - const uint32_t *Mask = TRI->getCallPreservedMask(MF, CallConv); - assert(Mask && "Missing call preserved mask for calling convention"); - Ops.push_back(DAG.getRegisterMask(Mask)); - } + // Add a register mask operand representing the call-preserved registers. + const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo(); + const uint32_t *Mask = TRI->getCallPreservedMask(MF, CallConv); + assert(Mask && "Missing call preserved mask for calling convention"); + Ops.push_back(DAG.getRegisterMask(Mask)); // Glue the call to the argument copies, if any. if (Glue.getNode()) diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoD.td b/llvm/lib/Target/RISCV/RISCVInstrInfoD.td index ae969bff8..349bc36 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoD.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoD.td @@ -23,7 +23,9 @@ def SDT_RISCVSplitF64 : SDTypeProfile<2, 1, [SDTCisVT<0, i32>, SDTCisVT<2, f64>]>; def RISCVBuildPairF64 : SDNode<"RISCVISD::BuildPairF64", SDT_RISCVBuildPairF64>; +def : GINodeEquiv<G_MERGE_VALUES, RISCVBuildPairF64>; def RISCVSplitF64 : SDNode<"RISCVISD::SplitF64", SDT_RISCVSplitF64>; +def : GINodeEquiv<G_UNMERGE_VALUES, RISCVSplitF64>; def AddrRegImmINX : ComplexPattern<iPTR, 2, "SelectAddrRegImmRV32Zdinx">; diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoXqci.td b/llvm/lib/Target/RISCV/RISCVInstrInfoXqci.td index 5e6722c..6f15646 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoXqci.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoXqci.td @@ -150,6 +150,22 @@ class QCILICC<bits<3> funct3, bits<2> funct2, DAGOperand InTyRs2, string opcodes let Inst{31-25} = {simm, funct2}; } +let hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +class QCIMVCC<bits<3> funct3, string opcodestr> + : RVInstR4<0b00, funct3, OPC_CUSTOM_2, (outs GPRNoX0:$rd), + (ins GPRNoX0:$rs1, GPRNoX0:$rs2, GPRNoX0:$rs3), + opcodestr, "$rd, $rs1, $rs2, $rs3">; + +let hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +class QCIMVCCI<bits<3> funct3, string opcodestr, DAGOperand immType> + : RVInstR4<0b10, funct3, OPC_CUSTOM_2, (outs GPRNoX0:$rd), + (ins GPRNoX0:$rs1, immType:$imm, GPRNoX0:$rs3), + opcodestr, "$rd, $rs1, $imm, $rs3"> { + bits<5> imm; + + let rs2 = imm; +} + //===----------------------------------------------------------------------===// // Instructions //===----------------------------------------------------------------------===// @@ -270,6 +286,32 @@ let Predicates = [HasVendorXqcicli, IsRV32], DecoderNamespace = "Xqcicli" in { def QC_LIGEUI : QCILICC<0b111, 0b11, uimm5, "qc.ligeui">; } // Predicates = [HasVendorXqcicli, IsRV32], DecoderNamespace = "Xqcicli" +let Predicates = [HasVendorXqcicm, IsRV32], DecoderNamespace = "Xqcicm" in { +let hasSideEffects = 0, mayLoad = 0, mayStore = 0 in + def QC_C_MVEQZ : RVInst16CL<0b101, 0b10, (outs GPRC:$rd_wb), + (ins GPRC:$rd, GPRC:$rs1), + "qc.c.mveqz", "$rd, $rs1"> { + let Constraints = "$rd = $rd_wb"; + + let Inst{12-10} = 0b011; + let Inst{6-5} = 0b00; + } + + def QC_MVEQ : QCIMVCC<0b000, "qc.mveq">; + def QC_MVNE : QCIMVCC<0b001, "qc.mvne">; + def QC_MVLT : QCIMVCC<0b100, "qc.mvlt">; + def QC_MVGE : QCIMVCC<0b101, "qc.mvge">; + def QC_MVLTU : QCIMVCC<0b110, "qc.mvltu">; + def QC_MVGEU : QCIMVCC<0b111, "qc.mvgeu">; + + def QC_MVEQI : QCIMVCCI<0b000, "qc.mveqi", simm5>; + def QC_MVNEI : QCIMVCCI<0b001, "qc.mvnei", simm5>; + def QC_MVLTI : QCIMVCCI<0b100, "qc.mvlti", simm5>; + def QC_MVGEI : QCIMVCCI<0b101, "qc.mvgei", simm5>; + def QC_MVLTUI : QCIMVCCI<0b110, "qc.mvltui", uimm5>; + def QC_MVGEUI : QCIMVCCI<0b111, "qc.mvgeui", uimm5>; +} // Predicates = [HasVendorXqcicm, IsRV32], DecoderNamespace = "Xqcicm" + //===----------------------------------------------------------------------===// // Aliases //===----------------------------------------------------------------------===// diff --git a/llvm/lib/Target/RISCV/RISCVPfmCounters.td b/llvm/lib/Target/RISCV/RISCVPfmCounters.td new file mode 100644 index 0000000..013e789 --- /dev/null +++ b/llvm/lib/Target/RISCV/RISCVPfmCounters.td @@ -0,0 +1,18 @@ +//===---- RISCVPfmCounters.td - RISC-V Hardware Counters ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This describes the available hardware counters for RISC-V. +// +//===----------------------------------------------------------------------===// + +def CpuCyclesPfmCounter : PfmCounter<"CYCLES">; + +def DefaultPfmCounters : ProcPfmCounters { + let CycleCounter = CpuCyclesPfmCounter; +} +def : PfmCountersDefaultBinding<DefaultPfmCounters>; diff --git a/llvm/lib/Target/RISCV/RISCVProcessors.td b/llvm/lib/Target/RISCV/RISCVProcessors.td index 61c7c21..6dfed7dd 100644 --- a/llvm/lib/Target/RISCV/RISCVProcessors.td +++ b/llvm/lib/Target/RISCV/RISCVProcessors.td @@ -321,6 +321,25 @@ def SIFIVE_P470 : RISCVProcessorModel<"sifive-p470", SiFiveP400Model, [TuneNoSinkSplatOperands, TuneVXRMPipelineFlush])>; +defvar SiFiveP500TuneFeatures = [TuneNoDefaultUnroll, + TuneConditionalCompressedMoveFusion, + TuneLUIADDIFusion, + TuneAUIPCADDIFusion, + TunePostRAScheduler]; + +def SIFIVE_P550 : RISCVProcessorModel<"sifive-p550", NoSchedModel, + [Feature64Bit, + FeatureStdExtI, + FeatureStdExtZifencei, + FeatureStdExtM, + FeatureStdExtA, + FeatureStdExtF, + FeatureStdExtD, + FeatureStdExtC, + FeatureStdExtZba, + FeatureStdExtZbb], + SiFiveP500TuneFeatures>; + def SIFIVE_P670 : RISCVProcessorModel<"sifive-p670", SiFiveP600Model, !listconcat(RVA22U64Features, [FeatureStdExtV, diff --git a/llvm/lib/Target/RISCV/RISCVSchedSiFiveP400.td b/llvm/lib/Target/RISCV/RISCVSchedSiFiveP400.td index a86c255..396cbe2c 100644 --- a/llvm/lib/Target/RISCV/RISCVSchedSiFiveP400.td +++ b/llvm/lib/Target/RISCV/RISCVSchedSiFiveP400.td @@ -182,7 +182,7 @@ def P400WriteCMOV : SchedWriteRes<[SiFiveP400Branch, SiFiveP400IEXQ1]> { } def : InstRW<[P400WriteCMOV], (instrs PseudoCCMOVGPRNoX0)>; -let Latency = 3 in { +let Latency = 2 in { // Integer multiplication def : WriteRes<WriteIMul, [SiFiveP400MulDiv]>; def : WriteRes<WriteIMul32, [SiFiveP400MulDiv]>; diff --git a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp index 32d5526..ad61a77 100644 --- a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp +++ b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp @@ -50,7 +50,10 @@ public: StringRef getPassName() const override { return PASS_NAME; } private: - bool checkUsers(const MachineOperand *&CommonVL, MachineInstr &MI); + std::optional<MachineOperand> getMinimumVLForUser(MachineOperand &UserOp); + /// Returns the largest common VL MachineOperand that may be used to optimize + /// MI. Returns std::nullopt if it failed to find a suitable VL. + std::optional<MachineOperand> checkUsers(MachineInstr &MI); bool tryReduceVL(MachineInstr &MI); bool isCandidate(const MachineInstr &MI) const; }; @@ -76,11 +79,6 @@ static bool isVectorRegClass(Register R, const MachineRegisterInfo *MRI) { /// Represents the EMUL and EEW of a MachineOperand. struct OperandInfo { - enum class State { - Unknown, - Known, - } S; - // Represent as 1,2,4,8, ... and fractional indicator. This is because // EMUL can take on values that don't map to RISCVII::VLMUL values exactly. // For example, a mask operand can have an EMUL less than MF8. @@ -89,34 +87,32 @@ struct OperandInfo { unsigned Log2EEW; OperandInfo(RISCVII::VLMUL EMUL, unsigned Log2EEW) - : S(State::Known), EMUL(RISCVVType::decodeVLMUL(EMUL)), Log2EEW(Log2EEW) { - } + : EMUL(RISCVVType::decodeVLMUL(EMUL)), Log2EEW(Log2EEW) {} OperandInfo(std::pair<unsigned, bool> EMUL, unsigned Log2EEW) - : S(State::Known), EMUL(EMUL), Log2EEW(Log2EEW) {} + : EMUL(EMUL), Log2EEW(Log2EEW) {} - OperandInfo() : S(State::Unknown) {} + OperandInfo(unsigned Log2EEW) : Log2EEW(Log2EEW) {} - bool isUnknown() const { return S == State::Unknown; } - bool isKnown() const { return S == State::Known; } + OperandInfo() = delete; static bool EMULAndEEWAreEqual(const OperandInfo &A, const OperandInfo &B) { - assert(A.isKnown() && B.isKnown() && "Both operands must be known"); - return A.Log2EEW == B.Log2EEW && A.EMUL->first == B.EMUL->first && A.EMUL->second == B.EMUL->second; } + static bool EEWAreEqual(const OperandInfo &A, const OperandInfo &B) { + return A.Log2EEW == B.Log2EEW; + } + void print(raw_ostream &OS) const { - if (isUnknown()) { - OS << "Unknown"; - return; - } - assert(EMUL && "Expected EMUL to have value"); - OS << "EMUL: m"; - if (EMUL->second) - OS << "f"; - OS << EMUL->first; + if (EMUL) { + OS << "EMUL: m"; + if (EMUL->second) + OS << "f"; + OS << EMUL->first; + } else + OS << "EMUL: unknown\n"; OS << ", EEW: " << (1 << Log2EEW); } }; @@ -127,30 +123,18 @@ static raw_ostream &operator<<(raw_ostream &OS, const OperandInfo &OI) { return OS; } -namespace llvm { -namespace RISCVVType { -/// Return the RISCVII::VLMUL that is two times VLMul. -/// Precondition: VLMul is not LMUL_RESERVED or LMUL_8. -static RISCVII::VLMUL twoTimesVLMUL(RISCVII::VLMUL VLMul) { - switch (VLMul) { - case RISCVII::VLMUL::LMUL_F8: - return RISCVII::VLMUL::LMUL_F4; - case RISCVII::VLMUL::LMUL_F4: - return RISCVII::VLMUL::LMUL_F2; - case RISCVII::VLMUL::LMUL_F2: - return RISCVII::VLMUL::LMUL_1; - case RISCVII::VLMUL::LMUL_1: - return RISCVII::VLMUL::LMUL_2; - case RISCVII::VLMUL::LMUL_2: - return RISCVII::VLMUL::LMUL_4; - case RISCVII::VLMUL::LMUL_4: - return RISCVII::VLMUL::LMUL_8; - case RISCVII::VLMUL::LMUL_8: - default: - llvm_unreachable("Could not multiply VLMul by 2"); - } +LLVM_ATTRIBUTE_UNUSED +static raw_ostream &operator<<(raw_ostream &OS, + const std::optional<OperandInfo> &OI) { + if (OI) + OI->print(OS); + else + OS << "nullopt"; + return OS; } +namespace llvm { +namespace RISCVVType { /// Return EMUL = (EEW / SEW) * LMUL where EEW comes from Log2EEW and LMUL and /// SEW are from the TSFlags of MI. static std::pair<unsigned, bool> @@ -180,24 +164,22 @@ getEMULEqualsEEWDivSEWTimesLMUL(unsigned Log2EEW, const MachineInstr &MI) { } // end namespace RISCVVType } // end namespace llvm -/// Dest has EEW=SEW and EMUL=LMUL. Source EEW=SEW/Factor (i.e. F2 => EEW/2). -/// Source has EMUL=(EEW/SEW)*LMUL. LMUL and SEW comes from TSFlags of MI. -static OperandInfo getIntegerExtensionOperandInfo(unsigned Factor, - const MachineInstr &MI, - const MachineOperand &MO) { - RISCVII::VLMUL MIVLMul = RISCVII::getLMul(MI.getDesc().TSFlags); +/// Dest has EEW=SEW. Source EEW=SEW/Factor (i.e. F2 => EEW/2). +/// SEW comes from TSFlags of MI. +static unsigned getIntegerExtensionOperandEEW(unsigned Factor, + const MachineInstr &MI, + const MachineOperand &MO) { unsigned MILog2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm(); if (MO.getOperandNo() == 0) - return OperandInfo(MIVLMul, MILog2SEW); + return MILog2SEW; unsigned MISEW = 1 << MILog2SEW; unsigned EEW = MISEW / Factor; unsigned Log2EEW = Log2_32(EEW); - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(Log2EEW, MI), - Log2EEW); + return Log2EEW; } /// Check whether MO is a mask operand of MI. @@ -211,18 +193,15 @@ static bool isMaskOperand(const MachineInstr &MI, const MachineOperand &MO, return Desc.operands()[MO.getOperandNo()].RegClass == RISCV::VMV0RegClassID; } -/// Return the OperandInfo for MO. -static OperandInfo getOperandInfo(const MachineOperand &MO, - const MachineRegisterInfo *MRI) { +static std::optional<unsigned> +getOperandLog2EEW(const MachineOperand &MO, const MachineRegisterInfo *MRI) { const MachineInstr &MI = *MO.getParent(); const RISCVVPseudosTable::PseudoInfo *RVV = RISCVVPseudosTable::getPseudoInfo(MI.getOpcode()); assert(RVV && "Could not find MI in PseudoTable"); - // MI has a VLMUL and SEW associated with it. The RVV specification defines - // the LMUL and SEW of each operand and definition in relation to MI.VLMUL and - // MI.SEW. - RISCVII::VLMUL MIVLMul = RISCVII::getLMul(MI.getDesc().TSFlags); + // MI has a SEW associated with it. The RVV specification defines + // the EEW of each operand and definition in relation to MI.SEW. unsigned MILog2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm(); @@ -233,13 +212,13 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, // since they must preserve the entire register content. if (HasPassthru && MO.getOperandNo() == MI.getNumExplicitDefs() && (MO.getReg() != RISCV::NoRegister)) - return {}; + return std::nullopt; bool IsMODef = MO.getOperandNo() == 0; - // All mask operands have EEW=1, EMUL=(EEW/SEW)*LMUL + // All mask operands have EEW=1 if (isMaskOperand(MI, MO, MRI)) - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(0, MI), 0); + return 0; // switch against BaseInstr to reduce number of cases that need to be // considered. @@ -256,55 +235,65 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, // Vector Loads and Stores // Vector Unit-Stride Instructions // Vector Strided Instructions - /// Dest EEW encoded in the instruction and EMUL=(EEW/SEW)*LMUL + /// Dest EEW encoded in the instruction + case RISCV::VLM_V: + case RISCV::VSM_V: + return 0; + case RISCV::VLE8_V: case RISCV::VSE8_V: + case RISCV::VLSE8_V: case RISCV::VSSE8_V: - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(3, MI), 3); + return 3; + case RISCV::VLE16_V: case RISCV::VSE16_V: + case RISCV::VLSE16_V: case RISCV::VSSE16_V: - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(4, MI), 4); + return 4; + case RISCV::VLE32_V: case RISCV::VSE32_V: + case RISCV::VLSE32_V: case RISCV::VSSE32_V: - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(5, MI), 5); + return 5; + case RISCV::VLE64_V: case RISCV::VSE64_V: + case RISCV::VLSE64_V: case RISCV::VSSE64_V: - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(6, MI), 6); + return 6; // Vector Indexed Instructions // vs(o|u)xei<eew>.v - // Dest/Data (operand 0) EEW=SEW, EMUL=LMUL. Source EEW=<eew> and - // EMUL=(EEW/SEW)*LMUL. + // Dest/Data (operand 0) EEW=SEW. Source EEW=<eew>. case RISCV::VLUXEI8_V: case RISCV::VLOXEI8_V: case RISCV::VSUXEI8_V: case RISCV::VSOXEI8_V: { if (MO.getOperandNo() == 0) - return OperandInfo(MIVLMul, MILog2SEW); - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(3, MI), 3); + return MILog2SEW; + return 3; } case RISCV::VLUXEI16_V: case RISCV::VLOXEI16_V: case RISCV::VSUXEI16_V: case RISCV::VSOXEI16_V: { if (MO.getOperandNo() == 0) - return OperandInfo(MIVLMul, MILog2SEW); - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(4, MI), 4); + return MILog2SEW; + return 4; } case RISCV::VLUXEI32_V: case RISCV::VLOXEI32_V: case RISCV::VSUXEI32_V: case RISCV::VSOXEI32_V: { if (MO.getOperandNo() == 0) - return OperandInfo(MIVLMul, MILog2SEW); - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(5, MI), 5); + return MILog2SEW; + return 5; } case RISCV::VLUXEI64_V: case RISCV::VLOXEI64_V: case RISCV::VSUXEI64_V: case RISCV::VSOXEI64_V: { if (MO.getOperandNo() == 0) - return OperandInfo(MIVLMul, MILog2SEW); - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(6, MI), 6); + return MILog2SEW; + return 6; } // Vector Integer Arithmetic Instructions @@ -318,7 +307,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VRSUB_VX: // Vector Bitwise Logical Instructions // Vector Single-Width Shift Instructions - // EEW=SEW. EMUL=LMUL. + // EEW=SEW. case RISCV::VAND_VI: case RISCV::VAND_VV: case RISCV::VAND_VX: @@ -338,7 +327,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VSRA_VV: case RISCV::VSRA_VX: // Vector Integer Min/Max Instructions - // EEW=SEW. EMUL=LMUL. + // EEW=SEW. case RISCV::VMINU_VV: case RISCV::VMINU_VX: case RISCV::VMIN_VV: @@ -348,7 +337,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VMAX_VV: case RISCV::VMAX_VX: // Vector Single-Width Integer Multiply Instructions - // Source and Dest EEW=SEW and EMUL=LMUL. + // Source and Dest EEW=SEW. case RISCV::VMUL_VV: case RISCV::VMUL_VX: case RISCV::VMULH_VV: @@ -358,7 +347,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VMULHSU_VV: case RISCV::VMULHSU_VX: // Vector Integer Divide Instructions - // EEW=SEW. EMUL=LMUL. + // EEW=SEW. case RISCV::VDIVU_VV: case RISCV::VDIVU_VX: case RISCV::VDIV_VV: @@ -368,7 +357,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VREM_VV: case RISCV::VREM_VX: // Vector Single-Width Integer Multiply-Add Instructions - // EEW=SEW. EMUL=LMUL. + // EEW=SEW. case RISCV::VMACC_VV: case RISCV::VMACC_VX: case RISCV::VNMSAC_VV: @@ -379,8 +368,8 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VNMSUB_VX: // Vector Integer Merge Instructions // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions - // EEW=SEW and EMUL=LMUL, except the mask operand has EEW=1 and EMUL= - // (EEW/SEW)*LMUL. Mask operand is handled before this switch. + // EEW=SEW, except the mask operand has EEW=1. Mask operand is handled + // before this switch. case RISCV::VMERGE_VIM: case RISCV::VMERGE_VVM: case RISCV::VMERGE_VXM: @@ -393,7 +382,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, // Vector Fixed-Point Arithmetic Instructions // Vector Single-Width Saturating Add and Subtract // Vector Single-Width Averaging Add and Subtract - // EEW=SEW. EMUL=LMUL. + // EEW=SEW. case RISCV::VMV_V_I: case RISCV::VMV_V_V: case RISCV::VMV_V_X: @@ -416,12 +405,12 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VASUB_VV: case RISCV::VASUB_VX: // Vector Single-Width Fractional Multiply with Rounding and Saturation - // EEW=SEW. EMUL=LMUL. The instruction produces 2*SEW product internally but + // EEW=SEW. The instruction produces 2*SEW product internally but // saturates to fit into SEW bits. case RISCV::VSMUL_VV: case RISCV::VSMUL_VX: // Vector Single-Width Scaling Shift Instructions - // EEW=SEW. EMUL=LMUL. + // EEW=SEW. case RISCV::VSSRL_VI: case RISCV::VSSRL_VV: case RISCV::VSSRL_VX: @@ -431,13 +420,13 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, // Vector Permutation Instructions // Integer Scalar Move Instructions // Floating-Point Scalar Move Instructions - // EMUL=LMUL. EEW=SEW. + // EEW=SEW. case RISCV::VMV_X_S: case RISCV::VMV_S_X: case RISCV::VFMV_F_S: case RISCV::VFMV_S_F: // Vector Slide Instructions - // EMUL=LMUL. EEW=SEW. + // EEW=SEW. case RISCV::VSLIDEUP_VI: case RISCV::VSLIDEUP_VX: case RISCV::VSLIDEDOWN_VI: @@ -447,12 +436,12 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VSLIDE1DOWN_VX: case RISCV::VFSLIDE1DOWN_VF: // Vector Register Gather Instructions - // EMUL=LMUL. EEW=SEW. For mask operand, EMUL=1 and EEW=1. + // EEW=SEW. For mask operand, EEW=1. case RISCV::VRGATHER_VI: case RISCV::VRGATHER_VV: case RISCV::VRGATHER_VX: // Vector Compress Instruction - // EMUL=LMUL. EEW=SEW. + // EEW=SEW. case RISCV::VCOMPRESS_VM: // Vector Element Index Instruction case RISCV::VID_V: @@ -499,10 +488,10 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VFCVT_F_X_V: // Vector Floating-Point Merge Instruction case RISCV::VFMERGE_VFM: - return OperandInfo(MIVLMul, MILog2SEW); + return MILog2SEW; // Vector Widening Integer Add/Subtract - // Def uses EEW=2*SEW and EMUL=2*LMUL. Operands use EEW=SEW and EMUL=LMUL. + // Def uses EEW=2*SEW . Operands use EEW=SEW. case RISCV::VWADDU_VV: case RISCV::VWADDU_VX: case RISCV::VWSUBU_VV: @@ -513,7 +502,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VWSUB_VX: case RISCV::VWSLL_VI: // Vector Widening Integer Multiply Instructions - // Source and Destination EMUL=LMUL. Destination EEW=2*SEW. Source EEW=SEW. + // Destination EEW=2*SEW. Source EEW=SEW. case RISCV::VWMUL_VV: case RISCV::VWMUL_VX: case RISCV::VWMULSU_VV: @@ -521,7 +510,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VWMULU_VV: case RISCV::VWMULU_VX: // Vector Widening Integer Multiply-Add Instructions - // Destination EEW=2*SEW and EMUL=2*LMUL. Source EEW=SEW and EMUL=LMUL. + // Destination EEW=2*SEW. Source EEW=SEW. // A SEW-bit*SEW-bit multiply of the sources forms a 2*SEW-bit value, which // is then added to the 2*SEW-bit Dest. These instructions never have a // passthru operand. @@ -542,7 +531,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VFWNMSAC_VF: case RISCV::VFWNMSAC_VV: // Vector Widening Floating-Point Add/Subtract Instructions - // Dest EEW=2*SEW and EMUL=2*LMUL. Source EEW=SEW and EMUL=LMUL. + // Dest EEW=2*SEW. Source EEW=SEW. case RISCV::VFWADD_VV: case RISCV::VFWADD_VF: case RISCV::VFWSUB_VV: @@ -559,12 +548,10 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VFWCVT_F_X_V: case RISCV::VFWCVT_F_F_V: { unsigned Log2EEW = IsMODef ? MILog2SEW + 1 : MILog2SEW; - RISCVII::VLMUL EMUL = - IsMODef ? RISCVVType::twoTimesVLMUL(MIVLMul) : MIVLMul; - return OperandInfo(EMUL, Log2EEW); + return Log2EEW; } - // Def and Op1 uses EEW=2*SEW and EMUL=2*LMUL. Op2 uses EEW=SEW and EMUL=LMUL + // Def and Op1 uses EEW=2*SEW. Op2 uses EEW=SEW. case RISCV::VWADDU_WV: case RISCV::VWADDU_WX: case RISCV::VWSUBU_WV: @@ -581,25 +568,22 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, bool IsOp1 = HasPassthru ? MO.getOperandNo() == 2 : MO.getOperandNo() == 1; bool TwoTimes = IsMODef || IsOp1; unsigned Log2EEW = TwoTimes ? MILog2SEW + 1 : MILog2SEW; - RISCVII::VLMUL EMUL = - TwoTimes ? RISCVVType::twoTimesVLMUL(MIVLMul) : MIVLMul; - return OperandInfo(EMUL, Log2EEW); + return Log2EEW; } // Vector Integer Extension case RISCV::VZEXT_VF2: case RISCV::VSEXT_VF2: - return getIntegerExtensionOperandInfo(2, MI, MO); + return getIntegerExtensionOperandEEW(2, MI, MO); case RISCV::VZEXT_VF4: case RISCV::VSEXT_VF4: - return getIntegerExtensionOperandInfo(4, MI, MO); + return getIntegerExtensionOperandEEW(4, MI, MO); case RISCV::VZEXT_VF8: case RISCV::VSEXT_VF8: - return getIntegerExtensionOperandInfo(8, MI, MO); + return getIntegerExtensionOperandEEW(8, MI, MO); // Vector Narrowing Integer Right Shift Instructions - // Destination EEW=SEW and EMUL=LMUL, Op 1 has EEW=2*SEW EMUL=2*LMUL. Op2 has - // EEW=SEW EMUL=LMUL. + // Destination EEW=SEW, Op 1 has EEW=2*SEW. Op2 has EEW=SEW case RISCV::VNSRL_WX: case RISCV::VNSRL_WI: case RISCV::VNSRL_WV: @@ -607,7 +591,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VNSRA_WV: case RISCV::VNSRA_WX: // Vector Narrowing Fixed-Point Clip Instructions - // Destination and Op1 EEW=SEW and EMUL=LMUL. Op2 EEW=2*SEW and EMUL=2*LMUL + // Destination and Op1 EEW=SEW. Op2 EEW=2*SEW. case RISCV::VNCLIPU_WI: case RISCV::VNCLIPU_WV: case RISCV::VNCLIPU_WX: @@ -626,9 +610,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, bool IsOp1 = HasPassthru ? MO.getOperandNo() == 2 : MO.getOperandNo() == 1; bool TwoTimes = IsOp1; unsigned Log2EEW = TwoTimes ? MILog2SEW + 1 : MILog2SEW; - RISCVII::VLMUL EMUL = - TwoTimes ? RISCVVType::twoTimesVLMUL(MIVLMul) : MIVLMul; - return OperandInfo(EMUL, Log2EEW); + return Log2EEW; } // Vector Mask Instructions @@ -636,7 +618,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, // vmsbf.m set-before-first mask bit // vmsif.m set-including-first mask bit // vmsof.m set-only-first mask bit - // EEW=1 and EMUL=(EEW/SEW)*LMUL + // EEW=1 // We handle the cases when operand is a v0 mask operand above the switch, // but these instructions may use non-v0 mask operands and need to be handled // specifically. @@ -651,20 +633,20 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VMSBF_M: case RISCV::VMSIF_M: case RISCV::VMSOF_M: { - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(0, MI), 0); + return 0; } // Vector Iota Instruction - // EEW=SEW and EMUL=LMUL, except the mask operand has EEW=1 and EMUL= - // (EEW/SEW)*LMUL. Mask operand is not handled before this switch. + // EEW=SEW, except the mask operand has EEW=1. Mask operand is not handled + // before this switch. case RISCV::VIOTA_M: { if (IsMODef || MO.getOperandNo() == 1) - return OperandInfo(MIVLMul, MILog2SEW); - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(0, MI), 0); + return MILog2SEW; + return 0; } // Vector Integer Compare Instructions - // Dest EEW=1 and EMUL=(EEW/SEW)*LMUL. Source EEW=SEW and EMUL=LMUL. + // Dest EEW=1. Source EEW=SEW. case RISCV::VMSEQ_VI: case RISCV::VMSEQ_VV: case RISCV::VMSEQ_VX: @@ -686,21 +668,20 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VMSGT_VI: case RISCV::VMSGT_VX: // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions - // Dest EEW=1 and EMUL=(EEW/SEW)*LMUL. Source EEW=SEW and EMUL=LMUL. Mask - // source operand handled above this switch. + // Dest EEW=1. Source EEW=SEW. Mask source operand handled above this switch. case RISCV::VMADC_VIM: case RISCV::VMADC_VVM: case RISCV::VMADC_VXM: case RISCV::VMSBC_VVM: case RISCV::VMSBC_VXM: - // Dest EEW=1 and EMUL=(EEW/SEW)*LMUL. Source EEW=SEW and EMUL=LMUL. + // Dest EEW=1. Source EEW=SEW. case RISCV::VMADC_VV: case RISCV::VMADC_VI: case RISCV::VMADC_VX: case RISCV::VMSBC_VV: case RISCV::VMSBC_VX: // 13.13. Vector Floating-Point Compare Instructions - // Dest EEW=1 and EMUL=(EEW/SEW)*LMUL. Source EEW=SEW EMUL=LMUL. + // Dest EEW=1. Source EEW=SEW case RISCV::VMFEQ_VF: case RISCV::VMFEQ_VV: case RISCV::VMFNE_VF: @@ -712,15 +693,62 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VMFGT_VF: case RISCV::VMFGE_VF: { if (IsMODef) - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(0, MI), 0); - return OperandInfo(MIVLMul, MILog2SEW); + return 0; + return MILog2SEW; + } + + // Vector Reduction Operations + // Vector Single-Width Integer Reduction Instructions + case RISCV::VREDAND_VS: + case RISCV::VREDMAX_VS: + case RISCV::VREDMAXU_VS: + case RISCV::VREDMIN_VS: + case RISCV::VREDMINU_VS: + case RISCV::VREDOR_VS: + case RISCV::VREDSUM_VS: + case RISCV::VREDXOR_VS: { + return MILog2SEW; } default: - return {}; + return std::nullopt; } } +static std::optional<OperandInfo> +getOperandInfo(const MachineOperand &MO, const MachineRegisterInfo *MRI) { + const MachineInstr &MI = *MO.getParent(); + const RISCVVPseudosTable::PseudoInfo *RVV = + RISCVVPseudosTable::getPseudoInfo(MI.getOpcode()); + assert(RVV && "Could not find MI in PseudoTable"); + + std::optional<unsigned> Log2EEW = getOperandLog2EEW(MO, MRI); + if (!Log2EEW) + return std::nullopt; + + switch (RVV->BaseInstr) { + // Vector Reduction Operations + // Vector Single-Width Integer Reduction Instructions + // The Dest and VS1 only read element 0 of the vector register. Return just + // the EEW for these. + case RISCV::VREDAND_VS: + case RISCV::VREDMAX_VS: + case RISCV::VREDMAXU_VS: + case RISCV::VREDMIN_VS: + case RISCV::VREDMINU_VS: + case RISCV::VREDOR_VS: + case RISCV::VREDSUM_VS: + case RISCV::VREDXOR_VS: + if (MO.getOperandNo() != 2) + return OperandInfo(*Log2EEW); + break; + }; + + // All others have EMUL=EEW/SEW*LMUL + return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(*Log2EEW, MI), + *Log2EEW); +} + /// Return true if this optimization should consider MI for VL reduction. This /// white-list approach simplifies this optimization for instructions that may /// have more complex semantics with relation to how it uses VL. @@ -732,6 +760,32 @@ static bool isSupportedInstr(const MachineInstr &MI) { return false; switch (RVV->BaseInstr) { + // Vector Unit-Stride Instructions + // Vector Strided Instructions + case RISCV::VLM_V: + case RISCV::VLE8_V: + case RISCV::VLSE8_V: + case RISCV::VLE16_V: + case RISCV::VLSE16_V: + case RISCV::VLE32_V: + case RISCV::VLSE32_V: + case RISCV::VLE64_V: + case RISCV::VLSE64_V: + // Vector Indexed Instructions + case RISCV::VLUXEI8_V: + case RISCV::VLOXEI8_V: + case RISCV::VLUXEI16_V: + case RISCV::VLOXEI16_V: + case RISCV::VLUXEI32_V: + case RISCV::VLOXEI32_V: + case RISCV::VLUXEI64_V: + case RISCV::VLOXEI64_V: { + for (const MachineMemOperand *MMO : MI.memoperands()) + if (MMO->isVolatile()) + return false; + return true; + } + // Vector Single-Width Integer Add and Subtract case RISCV::VADD_VI: case RISCV::VADD_VV: @@ -901,6 +955,30 @@ static bool isSupportedInstr(const MachineInstr &MI) { case RISCV::VMSOF_M: case RISCV::VIOTA_M: case RISCV::VID_V: + // Single-Width Floating-Point/Integer Type-Convert Instructions + case RISCV::VFCVT_XU_F_V: + case RISCV::VFCVT_X_F_V: + case RISCV::VFCVT_RTZ_XU_F_V: + case RISCV::VFCVT_RTZ_X_F_V: + case RISCV::VFCVT_F_XU_V: + case RISCV::VFCVT_F_X_V: + // Widening Floating-Point/Integer Type-Convert Instructions + case RISCV::VFWCVT_XU_F_V: + case RISCV::VFWCVT_X_F_V: + case RISCV::VFWCVT_RTZ_XU_F_V: + case RISCV::VFWCVT_RTZ_X_F_V: + case RISCV::VFWCVT_F_XU_V: + case RISCV::VFWCVT_F_X_V: + case RISCV::VFWCVT_F_F_V: + // Narrowing Floating-Point/Integer Type-Convert Instructions + case RISCV::VFNCVT_XU_F_W: + case RISCV::VFNCVT_X_F_W: + case RISCV::VFNCVT_RTZ_XU_F_W: + case RISCV::VFNCVT_RTZ_X_F_W: + case RISCV::VFNCVT_F_XU_W: + case RISCV::VFNCVT_F_X_W: + case RISCV::VFNCVT_F_F_W: + case RISCV::VFNCVT_ROD_F_F_W: return true; } @@ -1007,6 +1085,11 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const { return false; } + if (MI.mayRaiseFPException()) { + LLVM_DEBUG(dbgs() << "Not a candidate because may raise FP exception\n"); + return false; + } + // Some instructions that produce vectors have semantics that make it more // difficult to determine whether the VL can be reduced. For example, some // instructions, such as reductions, may write lanes past VL to a scalar @@ -1028,79 +1111,103 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const { return true; } -bool RISCVVLOptimizer::checkUsers(const MachineOperand *&CommonVL, - MachineInstr &MI) { +std::optional<MachineOperand> +RISCVVLOptimizer::getMinimumVLForUser(MachineOperand &UserOp) { + const MachineInstr &UserMI = *UserOp.getParent(); + const MCInstrDesc &Desc = UserMI.getDesc(); + + if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags)) { + LLVM_DEBUG(dbgs() << " Abort due to lack of VL, assume that" + " use VLMAX\n"); + return std::nullopt; + } + + // Instructions like reductions may use a vector register as a scalar + // register. In this case, we should treat it as only reading the first lane. + if (isVectorOpUsedAsScalarOp(UserOp)) { + [[maybe_unused]] Register R = UserOp.getReg(); + [[maybe_unused]] const TargetRegisterClass *RC = MRI->getRegClass(R); + assert(RISCV::VRRegClass.hasSubClassEq(RC) && + "Expect LMUL 1 register class for vector as scalar operands!"); + LLVM_DEBUG(dbgs() << " Used this operand as a scalar operand\n"); + + return MachineOperand::CreateImm(1); + } + + unsigned VLOpNum = RISCVII::getVLOpNum(Desc); + const MachineOperand &VLOp = UserMI.getOperand(VLOpNum); + // Looking for an immediate or a register VL that isn't X0. + assert((!VLOp.isReg() || VLOp.getReg() != RISCV::X0) && + "Did not expect X0 VL"); + return VLOp; +} + +std::optional<MachineOperand> RISCVVLOptimizer::checkUsers(MachineInstr &MI) { // FIXME: Avoid visiting each user for each time we visit something on the // worklist, combined with an extra visit from the outer loop. Restructure // along lines of an instcombine style worklist which integrates the outer // pass. - bool CanReduceVL = true; + std::optional<MachineOperand> CommonVL; for (auto &UserOp : MRI->use_operands(MI.getOperand(0).getReg())) { const MachineInstr &UserMI = *UserOp.getParent(); LLVM_DEBUG(dbgs() << " Checking user: " << UserMI << "\n"); - - // Instructions like reductions may use a vector register as a scalar - // register. In this case, we should treat it like a scalar register which - // does not impact the decision on whether to optimize VL. - // TODO: Treat it like a scalar register instead of bailing out. - if (isVectorOpUsedAsScalarOp(UserOp)) { - CanReduceVL = false; - break; - } - if (mayReadPastVL(UserMI)) { LLVM_DEBUG(dbgs() << " Abort because used by unsafe instruction\n"); - CanReduceVL = false; - break; + return std::nullopt; } // Tied operands might pass through. if (UserOp.isTied()) { LLVM_DEBUG(dbgs() << " Abort because user used as tied operand\n"); - CanReduceVL = false; - break; + return std::nullopt; } - const MCInstrDesc &Desc = UserMI.getDesc(); - if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags)) { - LLVM_DEBUG(dbgs() << " Abort due to lack of VL or SEW, assume that" - " use VLMAX\n"); - CanReduceVL = false; - break; - } - - unsigned VLOpNum = RISCVII::getVLOpNum(Desc); - const MachineOperand &VLOp = UserMI.getOperand(VLOpNum); - - // Looking for an immediate or a register VL that isn't X0. - assert((!VLOp.isReg() || VLOp.getReg() != RISCV::X0) && - "Did not expect X0 VL"); + auto VLOp = getMinimumVLForUser(UserOp); + if (!VLOp) + return std::nullopt; // Use the largest VL among all the users. If we cannot determine this // statically, then we cannot optimize the VL. - if (!CommonVL || RISCV::isVLKnownLE(*CommonVL, VLOp)) { - CommonVL = &VLOp; + if (!CommonVL || RISCV::isVLKnownLE(*CommonVL, *VLOp)) { + CommonVL = *VLOp; LLVM_DEBUG(dbgs() << " User VL is: " << VLOp << "\n"); - } else if (!RISCV::isVLKnownLE(VLOp, *CommonVL)) { + } else if (!RISCV::isVLKnownLE(*VLOp, *CommonVL)) { LLVM_DEBUG(dbgs() << " Abort because cannot determine a common VL\n"); - CanReduceVL = false; - break; + return std::nullopt; + } + + if (!RISCVII::hasSEWOp(UserMI.getDesc().TSFlags)) { + LLVM_DEBUG(dbgs() << " Abort due to lack of SEW operand\n"); + return std::nullopt; + } + + std::optional<OperandInfo> ConsumerInfo = getOperandInfo(UserOp, MRI); + std::optional<OperandInfo> ProducerInfo = + getOperandInfo(MI.getOperand(0), MRI); + if (!ConsumerInfo || !ProducerInfo) { + LLVM_DEBUG(dbgs() << " Abort due to unknown operand information.\n"); + LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n"); + LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n"); + return std::nullopt; } - // The SEW and LMUL of destination and source registers need to match. - OperandInfo ConsumerInfo = getOperandInfo(UserOp, MRI); - OperandInfo ProducerInfo = getOperandInfo(MI.getOperand(0), MRI); - if (ConsumerInfo.isUnknown() || ProducerInfo.isUnknown() || - !OperandInfo::EMULAndEEWAreEqual(ConsumerInfo, ProducerInfo)) { - LLVM_DEBUG(dbgs() << " Abort due to incompatible or unknown " - "information for EMUL or EEW.\n"); + // If the operand is used as a scalar operand, then the EEW must be + // compatible. Otherwise, the EMUL *and* EEW must be compatible. + bool IsVectorOpUsedAsScalarOp = isVectorOpUsedAsScalarOp(UserOp); + if ((IsVectorOpUsedAsScalarOp && + !OperandInfo::EEWAreEqual(*ConsumerInfo, *ProducerInfo)) || + (!IsVectorOpUsedAsScalarOp && + !OperandInfo::EMULAndEEWAreEqual(*ConsumerInfo, *ProducerInfo))) { + LLVM_DEBUG( + dbgs() + << " Abort due to incompatible information for EMUL or EEW.\n"); LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n"); LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n"); - CanReduceVL = false; - break; + return std::nullopt; } } - return CanReduceVL; + + return CommonVL; } bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) { @@ -1112,12 +1219,11 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) { MachineInstr &MI = *Worklist.pop_back_val(); LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n"); - const MachineOperand *CommonVL = nullptr; - bool CanReduceVL = true; - if (isVectorRegClass(MI.getOperand(0).getReg(), MRI)) - CanReduceVL = checkUsers(CommonVL, MI); + if (!isVectorRegClass(MI.getOperand(0).getReg(), MRI)) + continue; - if (!CanReduceVL || !CommonVL) + auto CommonVL = checkUsers(MI); + if (!CommonVL) continue; assert((CommonVL->isImm() || CommonVL->getReg().isVirtual()) && diff --git a/llvm/lib/Target/SPIRV/CMakeLists.txt b/llvm/lib/Target/SPIRV/CMakeLists.txt index aa83d99..a79e19f 100644 --- a/llvm/lib/Target/SPIRV/CMakeLists.txt +++ b/llvm/lib/Target/SPIRV/CMakeLists.txt @@ -20,7 +20,6 @@ add_llvm_target(SPIRVCodeGen SPIRVCallLowering.cpp SPIRVInlineAsmLowering.cpp SPIRVCommandLine.cpp - SPIRVDuplicatesTracker.cpp SPIRVEmitIntrinsics.cpp SPIRVGlobalRegistry.cpp SPIRVInstrInfo.cpp diff --git a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp index 4012bd7..78add92 100644 --- a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp @@ -274,7 +274,7 @@ void SPIRVAsmPrinter::emitInstruction(const MachineInstr *MI) { } void SPIRVAsmPrinter::outputModuleSection(SPIRV::ModuleSectionType MSType) { - for (MachineInstr *MI : MAI->getMSInstrs(MSType)) + for (const MachineInstr *MI : MAI->getMSInstrs(MSType)) outputInstruction(MI); } @@ -326,7 +326,7 @@ void SPIRVAsmPrinter::outputOpMemoryModel() { void SPIRVAsmPrinter::outputEntryPoints() { // Find all OpVariable IDs with required StorageClass. DenseSet<Register> InterfaceIDs; - for (MachineInstr *MI : MAI->GlobalVarList) { + for (const MachineInstr *MI : MAI->GlobalVarList) { assert(MI->getOpcode() == SPIRV::OpVariable); auto SC = static_cast<SPIRV::StorageClass::StorageClass>( MI->getOperand(2).getImm()); @@ -336,14 +336,14 @@ void SPIRVAsmPrinter::outputEntryPoints() { // declaring all global variables referenced by the entry point call tree. if (ST->isAtLeastSPIRVVer(VersionTuple(1, 4)) || SC == SPIRV::StorageClass::Input || SC == SPIRV::StorageClass::Output) { - MachineFunction *MF = MI->getMF(); + const MachineFunction *MF = MI->getMF(); Register Reg = MAI->getRegisterAlias(MF, MI->getOperand(0).getReg()); InterfaceIDs.insert(Reg); } } // Output OpEntryPoints adding interface args to all of them. - for (MachineInstr *MI : MAI->getMSInstrs(SPIRV::MB_EntryPoints)) { + for (const MachineInstr *MI : MAI->getMSInstrs(SPIRV::MB_EntryPoints)) { SPIRVMCInstLower MCInstLowering; MCInst TmpInst; MCInstLowering.lower(MI, TmpInst, MAI); @@ -381,9 +381,8 @@ void SPIRVAsmPrinter::outputGlobalRequirements() { void SPIRVAsmPrinter::outputExtFuncDecls() { // Insert OpFunctionEnd after each declaration. - SmallVectorImpl<MachineInstr *>::iterator - I = MAI->getMSInstrs(SPIRV::MB_ExtFuncDecls).begin(), - E = MAI->getMSInstrs(SPIRV::MB_ExtFuncDecls).end(); + auto I = MAI->getMSInstrs(SPIRV::MB_ExtFuncDecls).begin(), + E = MAI->getMSInstrs(SPIRV::MB_ExtFuncDecls).end(); for (; I != E; ++I) { outputInstruction(*I); if ((I + 1) == E || (*(I + 1))->getOpcode() == SPIRV::OpFunction) diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp index fa37313f..44b6f5f8 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp @@ -418,6 +418,7 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, .addImm(FuncControl) .addUse(GR->getSPIRVTypeID(FuncTy)); GR->recordFunctionDefinition(&F, &MB.getInstr()->getOperand(0)); + GR->addGlobalObject(&F, &MIRBuilder.getMF(), FuncVReg); // Add OpFunctionParameter instructions int i = 0; @@ -431,6 +432,7 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, .addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i])); if (F.isDeclaration()) GR->add(&Arg, &MIRBuilder.getMF(), ArgReg); + GR->addGlobalObject(&Arg, &MIRBuilder.getMF(), ArgReg); i++; } // Name the function. diff --git a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp deleted file mode 100644 index 48df845..0000000 --- a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp +++ /dev/null @@ -1,136 +0,0 @@ -//===-- SPIRVDuplicatesTracker.cpp - SPIR-V Duplicates Tracker --*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// General infrastructure for keeping track of the values that according to -// the SPIR-V binary layout should be global to the whole module. -// -//===----------------------------------------------------------------------===// - -#include "SPIRVDuplicatesTracker.h" -#include "SPIRVInstrInfo.h" - -#define DEBUG_TYPE "build-dep-graph" - -using namespace llvm; - -template <typename T> -void SPIRVGeneralDuplicatesTracker::prebuildReg2Entry( - SPIRVDuplicatesTracker<T> &DT, SPIRVReg2EntryTy &Reg2Entry, - const SPIRVInstrInfo *TII) { - for (auto &TPair : DT.getAllUses()) { - for (auto &RegPair : TPair.second) { - const MachineFunction *MF = RegPair.first; - Register R = RegPair.second; - MachineInstr *MI = MF->getRegInfo().getUniqueVRegDef(R); - if (!MI || (TPair.second.getIsConst() && !TII->isConstantInstr(*MI))) - continue; - Reg2Entry[&MI->getOperand(0)] = &TPair.second; - } - } -} - -void SPIRVGeneralDuplicatesTracker::buildDepsGraph( - std::vector<SPIRV::DTSortableEntry *> &Graph, const SPIRVInstrInfo *TII, - MachineModuleInfo *MMI = nullptr) { - SPIRVReg2EntryTy Reg2Entry; - prebuildReg2Entry(TT, Reg2Entry, TII); - prebuildReg2Entry(CT, Reg2Entry, TII); - prebuildReg2Entry(GT, Reg2Entry, TII); - prebuildReg2Entry(FT, Reg2Entry, TII); - prebuildReg2Entry(AT, Reg2Entry, TII); - prebuildReg2Entry(MT, Reg2Entry, TII); - prebuildReg2Entry(ST, Reg2Entry, TII); - - for (auto &Op2E : Reg2Entry) { - SPIRV::DTSortableEntry *E = Op2E.second; - Graph.push_back(E); - for (auto &U : *E) { - const MachineRegisterInfo &MRI = U.first->getRegInfo(); - MachineInstr *MI = MRI.getUniqueVRegDef(U.second); - if (!MI) - continue; - assert(MI && MI->getParent() && "No MachineInstr created yet"); - for (auto i = MI->getNumDefs(); i < MI->getNumOperands(); i++) { - MachineOperand &Op = MI->getOperand(i); - if (!Op.isReg()) - continue; - MachineInstr *VRegDef = MRI.getVRegDef(Op.getReg()); - // References to a function via function pointers generate virtual - // registers without a definition. We are able to resolve this - // reference using Globar Register info into an OpFunction instruction - // but do not expect to find it in Reg2Entry. - if (MI->getOpcode() == SPIRV::OpConstantFunctionPointerINTEL && i == 2) - continue; - MachineOperand *RegOp = &VRegDef->getOperand(0); - if (Reg2Entry.count(RegOp) == 0 && - (MI->getOpcode() != SPIRV::OpVariable || i != 3)) { - // try to repair the unexpected code pattern - bool IsFixed = false; - if (VRegDef->getOpcode() == TargetOpcode::G_CONSTANT && - RegOp->isReg() && MRI.getType(RegOp->getReg()).isScalar()) { - const Constant *C = VRegDef->getOperand(1).getCImm(); - add(C, MI->getParent()->getParent(), RegOp->getReg()); - auto Iter = CT.Storage.find(C); - if (Iter != CT.Storage.end()) { - SPIRV::DTSortableEntry &MissedEntry = Iter->second; - Reg2Entry[RegOp] = &MissedEntry; - IsFixed = true; - } - } - if (!IsFixed) { - std::string DiagMsg; - raw_string_ostream OS(DiagMsg); - OS << "Unexpected pattern while building a dependency " - "graph.\nInstruction: "; - MI->print(OS); - OS << "Operand: "; - Op.print(OS); - OS << "\nOperand definition: "; - VRegDef->print(OS); - report_fatal_error(DiagMsg.c_str()); - } - } - if (Reg2Entry.count(RegOp)) - E->addDep(Reg2Entry[RegOp]); - } - - if (E->getIsFunc()) { - MachineInstr *Next = MI->getNextNode(); - if (Next && (Next->getOpcode() == SPIRV::OpFunction || - Next->getOpcode() == SPIRV::OpFunctionParameter)) { - E->addDep(Reg2Entry[&Next->getOperand(0)]); - } - } - } - } - -#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) - if (MMI) { - const Module *M = MMI->getModule(); - for (auto F = M->begin(), E = M->end(); F != E; ++F) { - const MachineFunction *MF = MMI->getMachineFunction(*F); - if (!MF) - continue; - for (const MachineBasicBlock &MBB : *MF) { - for (const MachineInstr &CMI : MBB) { - MachineInstr &MI = const_cast<MachineInstr &>(CMI); - MI.dump(); - if (MI.getNumExplicitDefs() > 0 && - Reg2Entry.count(&MI.getOperand(0))) { - dbgs() << "\t["; - for (SPIRV::DTSortableEntry *D : - Reg2Entry.lookup(&MI.getOperand(0))->getDeps()) - dbgs() << Register::virtReg2Index(D->lookup(MF)) << ", "; - dbgs() << "]\n"; - } - } - } - } - } -#endif -} diff --git a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h index 6847da0..e574892 100644 --- a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h +++ b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h @@ -211,23 +211,7 @@ class SPIRVGeneralDuplicatesTracker { SPIRVDuplicatesTracker<MachineInstr> MT; SPIRVDuplicatesTracker<SPIRV::SpecialTypeDescriptor> ST; - // NOTE: using MOs instead of regs to get rid of MF dependency to be able - // to use flat data structure. - // NOTE: replacing DenseMap with MapVector doesn't affect overall correctness - // but makes LITs more stable, should prefer DenseMap still due to - // significant perf difference. - using SPIRVReg2EntryTy = - MapVector<MachineOperand *, SPIRV::DTSortableEntry *>; - - template <typename T> - void prebuildReg2Entry(SPIRVDuplicatesTracker<T> &DT, - SPIRVReg2EntryTy &Reg2Entry, - const SPIRVInstrInfo *TII); - public: - void buildDepsGraph(std::vector<SPIRV::DTSortableEntry *> &Graph, - const SPIRVInstrInfo *TII, MachineModuleInfo *MMI); - void add(const Type *Ty, const MachineFunction *MF, Register R) { TT.add(unifyPtrType(Ty), MF, R); } diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp index 77b5421..d2b14d6 100644 --- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp @@ -1841,20 +1841,20 @@ void SPIRVEmitIntrinsics::processGlobalValue(GlobalVariable &GV, // Skip special artifical variable llvm.global.annotations. if (GV.getName() == "llvm.global.annotations") return; - if (GV.hasInitializer() && !isa<UndefValue>(GV.getInitializer())) { + Constant *Init = nullptr; + if (hasInitializer(&GV)) { // Deduce element type and store results in Global Registry. // Result is ignored, because TypedPointerType is not supported // by llvm IR general logic. deduceElementTypeHelper(&GV, false); - Constant *Init = GV.getInitializer(); + Init = GV.getInitializer(); Type *Ty = isAggrConstForceInt32(Init) ? B.getInt32Ty() : Init->getType(); Constant *Const = isAggrConstForceInt32(Init) ? B.getInt32(1) : Init; auto *InitInst = B.CreateIntrinsic(Intrinsic::spv_init_global, {GV.getType(), Ty}, {&GV, Const}); InitInst->setArgOperand(1, Init); } - if ((!GV.hasInitializer() || isa<UndefValue>(GV.getInitializer())) && - GV.getNumUses() == 0) + if (!Init && GV.getNumUses() == 0) B.CreateIntrinsic(Intrinsic::spv_unref_global, GV.getType(), &GV); } diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp index 0c424477..a06c62e 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -721,6 +721,7 @@ Register SPIRVGlobalRegistry::buildGlobalVariable( } Reg = MIB->getOperand(0).getReg(); DT.add(GVar, &MIRBuilder.getMF(), Reg); + addGlobalObject(GVar, &MIRBuilder.getMF(), Reg); // Set to Reg the same type as ResVReg has. auto MRI = MIRBuilder.getMRI(); diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h index ec2386fa..528baf5 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h @@ -89,6 +89,9 @@ class SPIRVGlobalRegistry { // Intrinsic::spv_assign_ptr_type instructions. DenseMap<Value *, CallInst *> AssignPtrTypeInstr; + // Maps OpVariable and OpFunction-related v-regs to its LLVM IR definition. + DenseMap<std::pair<const MachineFunction *, Register>, const Value *> Reg2GO; + // Add a new OpTypeXXX instruction without checking for duplicates. SPIRVType *createSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder, SPIRV::AccessQualifier::AccessQualifier AQ = @@ -151,15 +154,17 @@ public: return DT.find(F, MF); } - void buildDepsGraph(std::vector<SPIRV::DTSortableEntry *> &Graph, - const SPIRVInstrInfo *TII, - MachineModuleInfo *MMI = nullptr) { - DT.buildDepsGraph(Graph, TII, MMI); - } - void setBound(unsigned V) { Bound = V; } unsigned getBound() { return Bound; } + void addGlobalObject(const Value *V, const MachineFunction *MF, Register R) { + Reg2GO[std::make_pair(MF, R)] = V; + } + const Value *getGlobalObject(const MachineFunction *MF, Register R) { + auto It = Reg2GO.find(std::make_pair(MF, R)); + return It == Reg2GO.end() ? nullptr : It->second; + } + // Add a record to the map of function return pointer types. void addReturnType(const Function *ArgF, TypedPointerType *DerivedTy) { FunResPointerTypes[ArgF] = DerivedTy; diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp index bd9e77e..9a140e7 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp @@ -47,6 +47,19 @@ bool SPIRVInstrInfo::isConstantInstr(const MachineInstr &MI) const { } } +bool SPIRVInstrInfo::isSpecConstantInstr(const MachineInstr &MI) const { + switch (MI.getOpcode()) { + case SPIRV::OpSpecConstantTrue: + case SPIRV::OpSpecConstantFalse: + case SPIRV::OpSpecConstant: + case SPIRV::OpSpecConstantComposite: + case SPIRV::OpSpecConstantOp: + return true; + default: + return false; + } +} + bool SPIRVInstrInfo::isInlineAsmDefInstr(const MachineInstr &MI) const { switch (MI.getOpcode()) { case SPIRV::OpAsmTargetINTEL: diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.h b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.h index 67d2d97..4e5059b 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.h +++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.h @@ -30,6 +30,7 @@ public: const SPIRVRegisterInfo &getRegisterInfo() const { return RI; } bool isHeaderInstr(const MachineInstr &MI) const; bool isConstantInstr(const MachineInstr &MI) const; + bool isSpecConstantInstr(const MachineInstr &MI) const; bool isInlineAsmDefInstr(const MachineInstr &MI) const; bool isTypeDeclInstr(const MachineInstr &MI) const; bool isDecorationInstr(const MachineInstr &MI) const; diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 4815685..28c9b81 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -33,7 +33,6 @@ #include "llvm/CodeGen/TargetOpcodes.h" #include "llvm/IR/IntrinsicsSPIRV.h" #include "llvm/Support/Debug.h" -#include "llvm/Support/ErrorHandling.h" #define DEBUG_TYPE "spirv-isel" @@ -46,17 +45,6 @@ using ExtInstList = namespace { -llvm::SPIRV::SelectionControl::SelectionControl -getSelectionOperandForImm(int Imm) { - if (Imm == 2) - return SPIRV::SelectionControl::Flatten; - if (Imm == 1) - return SPIRV::SelectionControl::DontFlatten; - if (Imm == 0) - return SPIRV::SelectionControl::None; - llvm_unreachable("Invalid immediate"); -} - #define GET_GLOBALISEL_PREDICATE_BITSET #include "SPIRVGenGlobalISel.inc" #undef GET_GLOBALISEL_PREDICATE_BITSET @@ -1117,6 +1105,7 @@ bool SPIRVInstructionSelector::selectMemOperation(Register ResVReg, Constant::getNullValue(LLVMArrTy)); Register VarReg = MRI->createGenericVirtualRegister(LLT::scalar(64)); GR.add(GV, GR.CurMF, VarReg); + GR.addGlobalObject(GV, GR.CurMF, VarReg); Result &= BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpVariable)) @@ -2829,8 +2818,12 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, } return MIB.constrainAllUses(TII, TRI, RBI); } - case Intrinsic::spv_loop_merge: { - auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpLoopMerge)); + case Intrinsic::spv_loop_merge: + case Intrinsic::spv_selection_merge: { + const auto Opcode = IID == Intrinsic::spv_selection_merge + ? SPIRV::OpSelectionMerge + : SPIRV::OpLoopMerge; + auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode)); for (unsigned i = 1; i < I.getNumExplicitOperands(); ++i) { assert(I.getOperand(i).isMBB()); MIB.addMBB(I.getOperand(i).getMBB()); @@ -2838,15 +2831,6 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, MIB.addImm(SPIRV::SelectionControl::None); return MIB.constrainAllUses(TII, TRI, RBI); } - case Intrinsic::spv_selection_merge: { - auto MIB = - BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpSelectionMerge)); - assert(I.getOperand(1).isMBB() && - "operand 1 to spv_selection_merge must be a basic block"); - MIB.addMBB(I.getOperand(1).getMBB()); - MIB.addImm(getSelectionOperandForImm(I.getOperand(2).getImm())); - return MIB.constrainAllUses(TII, TRI, RBI); - } case Intrinsic::spv_cmpxchg: return selectAtomicCmpXchg(ResVReg, ResType, I); case Intrinsic::spv_unreachable: @@ -3477,7 +3461,7 @@ bool SPIRVInstructionSelector::selectGlobalValue( ID = UnnamedGlobalIDs.size(); GlobalIdent = "__unnamed_" + Twine(ID).str(); } else { - GlobalIdent = GV->getGlobalIdentifier(); + GlobalIdent = GV->getName(); } // Behaviour of functions as operands depends on availability of the @@ -3509,18 +3493,25 @@ bool SPIRVInstructionSelector::selectGlobalValue( // References to a function via function pointers generate virtual // registers without a definition. We will resolve it later, during // module analysis stage. + Register ResTypeReg = GR.getSPIRVTypeID(ResType); MachineRegisterInfo *MRI = MIRBuilder.getMRI(); - Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(64)); - MRI->setRegClass(FuncVReg, &SPIRV::iIDRegClass); - MachineInstrBuilder MB = + Register FuncVReg = + MRI->createGenericVirtualRegister(GR.getRegType(ResType)); + MRI->setRegClass(FuncVReg, &SPIRV::pIDRegClass); + MachineInstrBuilder MIB1 = + BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpUndef)) + .addDef(FuncVReg) + .addUse(ResTypeReg); + MachineInstrBuilder MIB2 = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantFunctionPointerINTEL)) .addDef(NewReg) - .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(ResTypeReg) .addUse(FuncVReg); // mapping the function pointer to the used Function - GR.recordFunctionPointer(&MB.getInstr()->getOperand(2), GVFun); - return MB.constrainAllUses(TII, TRI, RBI); + GR.recordFunctionPointer(&MIB2.getInstr()->getOperand(2), GVFun); + return MIB1.constrainAllUses(TII, TRI, RBI) && + MIB2.constrainAllUses(TII, TRI, RBI); } return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull)) .addDef(NewReg) @@ -3533,18 +3524,16 @@ bool SPIRVInstructionSelector::selectGlobalValue( auto GlobalVar = cast<GlobalVariable>(GV); assert(GlobalVar->getName() != "llvm.global.annotations"); - bool HasInit = GlobalVar->hasInitializer() && - !isa<UndefValue>(GlobalVar->getInitializer()); - // Skip empty declaration for GVs with initilaizers till we get the decl with + // Skip empty declaration for GVs with initializers till we get the decl with // passed initializer. - if (HasInit && !Init) + if (hasInitializer(GlobalVar) && !Init) return true; - bool HasLnkTy = GV->getLinkage() != GlobalValue::InternalLinkage; + bool HasLnkTy = !GV->hasInternalLinkage() && !GV->hasPrivateLinkage(); SPIRV::LinkageType::LinkageType LnkType = - (GV->isDeclaration() || GV->hasAvailableExternallyLinkage()) + GV->isDeclarationForLinker() ? SPIRV::LinkageType::Import - : (GV->getLinkage() == GlobalValue::LinkOnceODRLinkage && + : (GV->hasLinkOnceODRLinkage() && STI.canUseExtension(SPIRV::Extension::SPV_KHR_linkonce_odr) ? SPIRV::LinkageType::LinkOnceODR : SPIRV::LinkageType::Export); diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp index 6371c67..63adf54 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -216,102 +216,262 @@ void SPIRVModuleAnalysis::setBaseInfo(const Module &M) { } } -// Collect MI which defines the register in the given machine function. -static void collectDefInstr(Register Reg, const MachineFunction *MF, - SPIRV::ModuleAnalysisInfo *MAI, - SPIRV::ModuleSectionType MSType, - bool DoInsert = true) { - assert(MAI->hasRegisterAlias(MF, Reg) && "Cannot find register alias"); - MachineInstr *MI = MF->getRegInfo().getUniqueVRegDef(Reg); - assert(MI && "There should be an instruction that defines the register"); - MAI->setSkipEmission(MI); - if (DoInsert) - MAI->MS[MSType].push_back(MI); +// Returns a representation of an instruction as a vector of MachineOperand +// hash values, see llvm::hash_value(const MachineOperand &MO) for details. +// This creates a signature of the instruction with the same content +// that MachineOperand::isIdenticalTo uses for comparison. +static InstrSignature instrToSignature(const MachineInstr &MI, + SPIRV::ModuleAnalysisInfo &MAI, + bool UseDefReg) { + InstrSignature Signature{MI.getOpcode()}; + for (unsigned i = 0; i < MI.getNumOperands(); ++i) { + const MachineOperand &MO = MI.getOperand(i); + size_t h; + if (MO.isReg()) { + if (!UseDefReg && MO.isDef()) + continue; + Register RegAlias = MAI.getRegisterAlias(MI.getMF(), MO.getReg()); + if (!RegAlias.isValid()) { + LLVM_DEBUG({ + dbgs() << "Unexpectedly, no global id found for the operand "; + MO.print(dbgs()); + dbgs() << "\nInstruction: "; + MI.print(dbgs()); + dbgs() << "\n"; + }); + report_fatal_error("All v-regs must have been mapped to global id's"); + } + // mimic llvm::hash_value(const MachineOperand &MO) + h = hash_combine(MO.getType(), (unsigned)RegAlias, MO.getSubReg(), + MO.isDef()); + } else { + h = hash_value(MO); + } + Signature.push_back(h); + } + return Signature; } -void SPIRVModuleAnalysis::collectGlobalEntities( - const std::vector<SPIRV::DTSortableEntry *> &DepsGraph, - SPIRV::ModuleSectionType MSType, - std::function<bool(const SPIRV::DTSortableEntry *)> Pred, - bool UsePreOrder = false) { - DenseSet<const SPIRV::DTSortableEntry *> Visited; - for (const auto *E : DepsGraph) { - std::function<void(const SPIRV::DTSortableEntry *)> RecHoistUtil; - // NOTE: here we prefer recursive approach over iterative because - // we don't expect depchains long enough to cause SO. - RecHoistUtil = [MSType, UsePreOrder, &Visited, &Pred, - &RecHoistUtil](const SPIRV::DTSortableEntry *E) { - if (Visited.count(E) || !Pred(E)) - return; - Visited.insert(E); - - // Traversing deps graph in post-order allows us to get rid of - // register aliases preprocessing. - // But pre-order is required for correct processing of function - // declaration and arguments processing. - if (!UsePreOrder) - for (auto *S : E->getDeps()) - RecHoistUtil(S); - - Register GlobalReg = Register::index2VirtReg(MAI.getNextID()); - bool IsFirst = true; - for (auto &U : *E) { - const MachineFunction *MF = U.first; - Register Reg = U.second; - MAI.setRegisterAlias(MF, Reg, GlobalReg); - if (!MF->getRegInfo().getUniqueVRegDef(Reg)) - continue; - collectDefInstr(Reg, MF, &MAI, MSType, IsFirst); - IsFirst = false; - if (E->getIsGV()) - MAI.GlobalVarList.push_back(MF->getRegInfo().getUniqueVRegDef(Reg)); - } +bool SPIRVModuleAnalysis::isDeclSection(const MachineRegisterInfo &MRI, + const MachineInstr &MI) { + unsigned Opcode = MI.getOpcode(); + switch (Opcode) { + case SPIRV::OpTypeForwardPointer: + // omit now, collect later + return false; + case SPIRV::OpVariable: + return static_cast<SPIRV::StorageClass::StorageClass>( + MI.getOperand(2).getImm()) != SPIRV::StorageClass::Function; + case SPIRV::OpFunction: + case SPIRV::OpFunctionParameter: + return true; + } + if (GR->hasConstFunPtr() && Opcode == SPIRV::OpUndef) { + Register DefReg = MI.getOperand(0).getReg(); + for (MachineInstr &UseMI : MRI.use_instructions(DefReg)) { + if (UseMI.getOpcode() != SPIRV::OpConstantFunctionPointerINTEL) + continue; + // it's a dummy definition, FP constant refers to a function, + // and this is resolved in another way; let's skip this definition + assert(UseMI.getOperand(2).isReg() && + UseMI.getOperand(2).getReg() == DefReg); + MAI.setSkipEmission(&MI); + return false; + } + } + return TII->isTypeDeclInstr(MI) || TII->isConstantInstr(MI) || + TII->isInlineAsmDefInstr(MI); +} - if (UsePreOrder) - for (auto *S : E->getDeps()) - RecHoistUtil(S); - }; - RecHoistUtil(E); +// This is a special case of a function pointer refering to a possibly +// forward function declaration. The operand is a dummy OpUndef that +// requires a special treatment. +void SPIRVModuleAnalysis::visitFunPtrUse( + Register OpReg, InstrGRegsMap &SignatureToGReg, + std::map<const Value *, unsigned> &GlobalToGReg, const MachineFunction *MF, + const MachineInstr &MI) { + const MachineOperand *OpFunDef = + GR->getFunctionDefinitionByUse(&MI.getOperand(2)); + assert(OpFunDef && OpFunDef->isReg()); + // find the actual function definition and number it globally in advance + const MachineInstr *OpDefMI = OpFunDef->getParent(); + assert(OpDefMI && OpDefMI->getOpcode() == SPIRV::OpFunction); + const MachineFunction *FunDefMF = OpDefMI->getParent()->getParent(); + const MachineRegisterInfo &FunDefMRI = FunDefMF->getRegInfo(); + do { + visitDecl(FunDefMRI, SignatureToGReg, GlobalToGReg, FunDefMF, *OpDefMI); + OpDefMI = OpDefMI->getNextNode(); + } while (OpDefMI && (OpDefMI->getOpcode() == SPIRV::OpFunction || + OpDefMI->getOpcode() == SPIRV::OpFunctionParameter)); + // associate the function pointer with the newly assigned global number + Register GlobalFunDefReg = MAI.getRegisterAlias(FunDefMF, OpFunDef->getReg()); + assert(GlobalFunDefReg.isValid() && + "Function definition must refer to a global register"); + MAI.setRegisterAlias(MF, OpReg, GlobalFunDefReg); +} + +// Depth first recursive traversal of dependencies. Repeated visits are guarded +// by MAI.hasRegisterAlias(). +void SPIRVModuleAnalysis::visitDecl( + const MachineRegisterInfo &MRI, InstrGRegsMap &SignatureToGReg, + std::map<const Value *, unsigned> &GlobalToGReg, const MachineFunction *MF, + const MachineInstr &MI) { + unsigned Opcode = MI.getOpcode(); + DenseSet<Register> Deps; + + // Process each operand of the instruction to resolve dependencies + for (const MachineOperand &MO : MI.operands()) { + if (!MO.isReg() || MO.isDef()) + continue; + Register OpReg = MO.getReg(); + // Handle function pointers special case + if (Opcode == SPIRV::OpConstantFunctionPointerINTEL && + MRI.getRegClass(OpReg) == &SPIRV::pIDRegClass) { + visitFunPtrUse(OpReg, SignatureToGReg, GlobalToGReg, MF, MI); + continue; + } + // Skip already processed instructions + if (MAI.hasRegisterAlias(MF, MO.getReg())) + continue; + // Recursively visit dependencies + if (const MachineInstr *OpDefMI = MRI.getUniqueVRegDef(OpReg)) { + if (isDeclSection(MRI, *OpDefMI)) + visitDecl(MRI, SignatureToGReg, GlobalToGReg, MF, *OpDefMI); + continue; + } + // Handle the unexpected case of no unique definition for the SPIR-V + // instruction + LLVM_DEBUG({ + dbgs() << "Unexpectedly, no unique definition for the operand "; + MO.print(dbgs()); + dbgs() << "\nInstruction: "; + MI.print(dbgs()); + dbgs() << "\n"; + }); + report_fatal_error( + "No unique definition is found for the virtual register"); } + + Register GReg; + bool IsFunDef = false; + if (TII->isSpecConstantInstr(MI)) { + GReg = Register::index2VirtReg(MAI.getNextID()); + MAI.MS[SPIRV::MB_TypeConstVars].push_back(&MI); + } else if (Opcode == SPIRV::OpFunction || + Opcode == SPIRV::OpFunctionParameter) { + GReg = handleFunctionOrParameter(MF, MI, GlobalToGReg, IsFunDef); + } else if (TII->isTypeDeclInstr(MI) || TII->isConstantInstr(MI) || + TII->isInlineAsmDefInstr(MI)) { + GReg = handleTypeDeclOrConstant(MI, SignatureToGReg); + } else if (Opcode == SPIRV::OpVariable) { + GReg = handleVariable(MF, MI, GlobalToGReg); + } else { + LLVM_DEBUG({ + dbgs() << "\nInstruction: "; + MI.print(dbgs()); + dbgs() << "\n"; + }); + llvm_unreachable("Unexpected instruction is visited"); + } + MAI.setRegisterAlias(MF, MI.getOperand(0).getReg(), GReg); + if (!IsFunDef) + MAI.setSkipEmission(&MI); } -// The function initializes global register alias table for types, consts, -// global vars and func decls and collects these instruction for output -// at module level. Also it collects explicit OpExtension/OpCapability -// instructions. -void SPIRVModuleAnalysis::processDefInstrs(const Module &M) { - std::vector<SPIRV::DTSortableEntry *> DepsGraph; +Register SPIRVModuleAnalysis::handleFunctionOrParameter( + const MachineFunction *MF, const MachineInstr &MI, + std::map<const Value *, unsigned> &GlobalToGReg, bool &IsFunDef) { + const Value *GObj = GR->getGlobalObject(MF, MI.getOperand(0).getReg()); + assert(GObj && "Unregistered global definition"); + const Function *F = dyn_cast<Function>(GObj); + if (!F) + F = dyn_cast<Argument>(GObj)->getParent(); + assert(F && "Expected a reference to a function or an argument"); + IsFunDef = !F->isDeclaration(); + auto It = GlobalToGReg.find(GObj); + if (It != GlobalToGReg.end()) + return It->second; + Register GReg = Register::index2VirtReg(MAI.getNextID()); + GlobalToGReg[GObj] = GReg; + if (!IsFunDef) + MAI.MS[SPIRV::MB_ExtFuncDecls].push_back(&MI); + return GReg; +} - GR->buildDepsGraph(DepsGraph, TII, SPVDumpDeps ? MMI : nullptr); +Register +SPIRVModuleAnalysis::handleTypeDeclOrConstant(const MachineInstr &MI, + InstrGRegsMap &SignatureToGReg) { + InstrSignature MISign = instrToSignature(MI, MAI, false); + auto It = SignatureToGReg.find(MISign); + if (It != SignatureToGReg.end()) + return It->second; + Register GReg = Register::index2VirtReg(MAI.getNextID()); + SignatureToGReg[MISign] = GReg; + MAI.MS[SPIRV::MB_TypeConstVars].push_back(&MI); + return GReg; +} - collectGlobalEntities( - DepsGraph, SPIRV::MB_TypeConstVars, - [](const SPIRV::DTSortableEntry *E) { return !E->getIsFunc(); }); +Register SPIRVModuleAnalysis::handleVariable( + const MachineFunction *MF, const MachineInstr &MI, + std::map<const Value *, unsigned> &GlobalToGReg) { + MAI.GlobalVarList.push_back(&MI); + const Value *GObj = GR->getGlobalObject(MF, MI.getOperand(0).getReg()); + assert(GObj && "Unregistered global definition"); + auto It = GlobalToGReg.find(GObj); + if (It != GlobalToGReg.end()) + return It->second; + Register GReg = Register::index2VirtReg(MAI.getNextID()); + GlobalToGReg[GObj] = GReg; + MAI.MS[SPIRV::MB_TypeConstVars].push_back(&MI); + return GReg; +} +void SPIRVModuleAnalysis::collectDeclarations(const Module &M) { + InstrGRegsMap SignatureToGReg; + std::map<const Value *, unsigned> GlobalToGReg; for (auto F = M.begin(), E = M.end(); F != E; ++F) { MachineFunction *MF = MMI->getMachineFunction(*F); if (!MF) continue; - // Iterate through and collect OpExtension/OpCapability instructions. + const MachineRegisterInfo &MRI = MF->getRegInfo(); + unsigned PastHeader = 0; for (MachineBasicBlock &MBB : *MF) { for (MachineInstr &MI : MBB) { - if (MI.getOpcode() == SPIRV::OpExtension) { - // Here, OpExtension just has a single enum operand, not a string. - auto Ext = SPIRV::Extension::Extension(MI.getOperand(0).getImm()); - MAI.Reqs.addExtension(Ext); + if (MI.getNumOperands() == 0) + continue; + unsigned Opcode = MI.getOpcode(); + if (Opcode == SPIRV::OpFunction) { + if (PastHeader == 0) { + PastHeader = 1; + continue; + } + } else if (Opcode == SPIRV::OpFunctionParameter) { + if (PastHeader < 2) + continue; + } else if (PastHeader > 0) { + PastHeader = 2; + } + + const MachineOperand &DefMO = MI.getOperand(0); + switch (Opcode) { + case SPIRV::OpExtension: + MAI.Reqs.addExtension(SPIRV::Extension::Extension(DefMO.getImm())); MAI.setSkipEmission(&MI); - } else if (MI.getOpcode() == SPIRV::OpCapability) { - auto Cap = SPIRV::Capability::Capability(MI.getOperand(0).getImm()); - MAI.Reqs.addCapability(Cap); + break; + case SPIRV::OpCapability: + MAI.Reqs.addCapability(SPIRV::Capability::Capability(DefMO.getImm())); MAI.setSkipEmission(&MI); + if (PastHeader > 0) + PastHeader = 2; + break; + default: + if (DefMO.isReg() && isDeclSection(MRI, MI) && + !MAI.hasRegisterAlias(MF, DefMO.getReg())) + visitDecl(MRI, SignatureToGReg, GlobalToGReg, MF, MI); } } } } - - collectGlobalEntities( - DepsGraph, SPIRV::MB_ExtFuncDecls, - [](const SPIRV::DTSortableEntry *E) { return E->getIsFunc(); }, true); } // Look for IDs declared with Import linkage, and map the corresponding function @@ -342,58 +502,6 @@ void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI, } } -// References to a function via function pointers generate virtual -// registers without a definition. We are able to resolve this -// reference using Globar Register info into an OpFunction instruction -// and replace dummy operands by the corresponding global register references. -void SPIRVModuleAnalysis::collectFuncPtrs() { - for (auto &MI : MAI.MS[SPIRV::MB_TypeConstVars]) - if (MI->getOpcode() == SPIRV::OpConstantFunctionPointerINTEL) - collectFuncPtrs(MI); -} - -void SPIRVModuleAnalysis::collectFuncPtrs(MachineInstr *MI) { - const MachineOperand *FunUse = &MI->getOperand(2); - if (const MachineOperand *FunDef = GR->getFunctionDefinitionByUse(FunUse)) { - const MachineInstr *FunDefMI = FunDef->getParent(); - assert(FunDefMI->getOpcode() == SPIRV::OpFunction && - "Constant function pointer must refer to function definition"); - Register FunDefReg = FunDef->getReg(); - Register GlobalFunDefReg = - MAI.getRegisterAlias(FunDefMI->getMF(), FunDefReg); - assert(GlobalFunDefReg.isValid() && - "Function definition must refer to a global register"); - Register FunPtrReg = FunUse->getReg(); - MAI.setRegisterAlias(MI->getMF(), FunPtrReg, GlobalFunDefReg); - } -} - -using InstrSignature = SmallVector<size_t>; -using InstrTraces = std::set<InstrSignature>; - -// Returns a representation of an instruction as a vector of MachineOperand -// hash values, see llvm::hash_value(const MachineOperand &MO) for details. -// This creates a signature of the instruction with the same content -// that MachineOperand::isIdenticalTo uses for comparison. -static InstrSignature instrToSignature(MachineInstr &MI, - SPIRV::ModuleAnalysisInfo &MAI) { - InstrSignature Signature; - for (unsigned i = 0; i < MI.getNumOperands(); ++i) { - const MachineOperand &MO = MI.getOperand(i); - size_t h; - if (MO.isReg()) { - Register RegAlias = MAI.getRegisterAlias(MI.getMF(), MO.getReg()); - // mimic llvm::hash_value(const MachineOperand &MO) - h = hash_combine(MO.getType(), (unsigned)RegAlias, MO.getSubReg(), - MO.isDef()); - } else { - h = hash_value(MO); - } - Signature.push_back(h); - } - return Signature; -} - // Collect the given instruction in the specified MS. We assume global register // numbering has already occurred by this point. We can directly compare reg // arguments when detecting duplicates. @@ -401,7 +509,7 @@ static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI, SPIRV::ModuleSectionType MSType, InstrTraces &IS, bool Append = true) { MAI.setSkipEmission(&MI); - InstrSignature MISign = instrToSignature(MI, MAI); + InstrSignature MISign = instrToSignature(MI, MAI, true); auto FoundMI = IS.insert(MISign); if (!FoundMI.second) return; // insert failed, so we found a duplicate; don't add it to MAI.MS @@ -465,7 +573,7 @@ void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) { // Number registers in all functions globally from 0 onwards and store // the result in global register alias table. Some registers are already -// numbered in collectGlobalEntities. +// numbered. void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) { for (auto F = M.begin(), E = M.end(); F != E; ++F) { if ((*F).isDeclaration()) @@ -1835,15 +1943,11 @@ bool SPIRVModuleAnalysis::runOnModule(Module &M) { // Process type/const/global var/func decl instructions, number their // destination registers from 0 to N, collect Extensions and Capabilities. - processDefInstrs(M); + collectDeclarations(M); // Number rest of registers from N+1 onwards. numberRegistersGlobally(M); - // Update references to OpFunction instructions to use Global Registers - if (GR->hasConstFunPtr()) - collectFuncPtrs(); - // Collect OpName, OpEntryPoint, OpDecorate etc, process other instructions. processOtherInstrs(M); diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h index ee2aaf1..79b5444 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h @@ -124,7 +124,7 @@ public: const Capability::Capability IfPresent); }; -using InstrList = SmallVector<MachineInstr *>; +using InstrList = SmallVector<const MachineInstr *>; // Maps a local register to the corresponding global alias. using LocalToGlobalRegTable = std::map<Register, Register>; using RegisterAliasMapTy = @@ -142,12 +142,12 @@ struct ModuleAnalysisInfo { // Maps ExtInstSet to corresponding ID register. DenseMap<unsigned, Register> ExtInstSetMap; // Contains the list of all global OpVariables in the module. - SmallVector<MachineInstr *, 4> GlobalVarList; + SmallVector<const MachineInstr *, 4> GlobalVarList; // Maps functions to corresponding function ID registers. DenseMap<const Function *, Register> FuncMap; // The set contains machine instructions which are necessary // for correct MIR but will not be emitted in function bodies. - DenseSet<MachineInstr *> InstrsToDelete; + DenseSet<const MachineInstr *> InstrsToDelete; // The table contains global aliases of local registers for each machine // function. The aliases are used to substitute local registers during // code emission. @@ -167,7 +167,7 @@ struct ModuleAnalysisInfo { } Register getExtInstSetReg(unsigned SetNum) { return ExtInstSetMap[SetNum]; } InstrList &getMSInstrs(unsigned MSType) { return MS[MSType]; } - void setSkipEmission(MachineInstr *MI) { InstrsToDelete.insert(MI); } + void setSkipEmission(const MachineInstr *MI) { InstrsToDelete.insert(MI); } bool getSkipEmission(const MachineInstr *MI) { return InstrsToDelete.contains(MI); } @@ -204,6 +204,10 @@ struct ModuleAnalysisInfo { }; } // namespace SPIRV +using InstrSignature = SmallVector<size_t>; +using InstrTraces = std::set<InstrSignature>; +using InstrGRegsMap = std::map<SmallVector<size_t>, unsigned>; + struct SPIRVModuleAnalysis : public ModulePass { static char ID; @@ -216,17 +220,27 @@ public: private: void setBaseInfo(const Module &M); - void collectGlobalEntities( - const std::vector<SPIRV::DTSortableEntry *> &DepsGraph, - SPIRV::ModuleSectionType MSType, - std::function<bool(const SPIRV::DTSortableEntry *)> Pred, - bool UsePreOrder); - void processDefInstrs(const Module &M); void collectFuncNames(MachineInstr &MI, const Function *F); void processOtherInstrs(const Module &M); void numberRegistersGlobally(const Module &M); - void collectFuncPtrs(); - void collectFuncPtrs(MachineInstr *MI); + + // analyze dependencies to collect module scope definitions + void collectDeclarations(const Module &M); + void visitDecl(const MachineRegisterInfo &MRI, InstrGRegsMap &SignatureToGReg, + std::map<const Value *, unsigned> &GlobalToGReg, + const MachineFunction *MF, const MachineInstr &MI); + Register handleVariable(const MachineFunction *MF, const MachineInstr &MI, + std::map<const Value *, unsigned> &GlobalToGReg); + Register handleTypeDeclOrConstant(const MachineInstr &MI, + InstrGRegsMap &SignatureToGReg); + Register + handleFunctionOrParameter(const MachineFunction *MF, const MachineInstr &MI, + std::map<const Value *, unsigned> &GlobalToGReg, + bool &IsFunDef); + void visitFunPtrUse(Register OpReg, InstrGRegsMap &SignatureToGReg, + std::map<const Value *, unsigned> &GlobalToGReg, + const MachineFunction *MF, const MachineInstr &MI); + bool isDeclSection(const MachineRegisterInfo &MRI, const MachineInstr &MI); const SPIRVSubtarget *ST; SPIRVGlobalRegistry *GR; diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp index 8357c30..5b4c849 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp @@ -58,9 +58,10 @@ addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR, ->getValue()); if (auto *GV = dyn_cast<GlobalValue>(Const)) { Register Reg = GR->find(GV, &MF); - if (!Reg.isValid()) + if (!Reg.isValid()) { GR->add(GV, &MF, SrcReg); - else + GR->addGlobalObject(GV, &MF, SrcReg); + } else RegsAlreadyAddedToDT[&MI] = Reg; } else { Register Reg = GR->find(Const, &MF); diff --git a/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp b/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp index 2e4343c..336cde4 100644 --- a/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp @@ -18,16 +18,14 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/CodeGen/IntrinsicLowering.h" +#include "llvm/IR/Analysis.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/IntrinsicsSPIRV.h" -#include "llvm/IR/LegacyPassManager.h" #include "llvm/InitializePasses.h" -#include "llvm/PassRegistry.h" -#include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/LoopSimplify.h" #include "llvm/Transforms/Utils/LowerMemIntrinsics.h" @@ -648,7 +646,8 @@ class SPIRVStructurizer : public FunctionPass { Builder.SetInsertPoint(Header->getTerminator()); auto MergeAddress = BlockAddress::get(BB.getParent(), &BB); - createOpSelectMerge(&Builder, MergeAddress); + SmallVector<Value *, 1> Args = {MergeAddress}; + Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args}); Modified = true; } @@ -770,9 +769,10 @@ class SPIRVStructurizer : public FunctionPass { BasicBlock *Merge = Candidates[0]; auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge); + SmallVector<Value *, 1> Args = {MergeAddress}; IRBuilder<> Builder(&BB); Builder.SetInsertPoint(BB.getTerminator()); - createOpSelectMerge(&Builder, MergeAddress); + Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args}); } return Modified; @@ -1105,7 +1105,8 @@ class SPIRVStructurizer : public FunctionPass { Builder.SetInsertPoint(Header->getTerminator()); auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge); - createOpSelectMerge(&Builder, MergeAddress); + SmallVector<Value *, 1> Args = {MergeAddress}; + Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args}); continue; } @@ -1119,7 +1120,8 @@ class SPIRVStructurizer : public FunctionPass { Builder.SetInsertPoint(Header->getTerminator()); auto MergeAddress = BlockAddress::get(NewMerge->getParent(), NewMerge); - createOpSelectMerge(&Builder, MergeAddress); + SmallVector<Value *, 1> Args = {MergeAddress}; + Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args}); } return Modified; @@ -1206,27 +1208,6 @@ public: AU.addPreserved<SPIRVConvergenceRegionAnalysisWrapperPass>(); FunctionPass::getAnalysisUsage(AU); } - - void createOpSelectMerge(IRBuilder<> *Builder, BlockAddress *MergeAddress) { - Instruction *BBTerminatorInst = Builder->GetInsertBlock()->getTerminator(); - - MDNode *MDNode = BBTerminatorInst->getMetadata("hlsl.controlflow.hint"); - - ConstantInt *BranchHint = llvm::ConstantInt::get(Builder->getInt32Ty(), 0); - - if (MDNode) { - assert(MDNode->getNumOperands() == 2 && - "invalid metadata hlsl.controlflow.hint"); - BranchHint = mdconst::extract<ConstantInt>(MDNode->getOperand(1)); - - assert(BranchHint && "invalid metadata value for hlsl.controlflow.hint"); - } - - llvm::SmallVector<llvm::Value *, 2> Args = {MergeAddress, BranchHint}; - - Builder->CreateIntrinsic(Intrinsic::spv_selection_merge, - {MergeAddress->getType()}, {Args}); - } }; } // namespace llvm @@ -1248,11 +1229,8 @@ FunctionPass *llvm::createSPIRVStructurizerPass() { PreservedAnalyses SPIRVStructurizerWrapper::run(Function &F, FunctionAnalysisManager &AF) { - - auto FPM = legacy::FunctionPassManager(F.getParent()); - FPM.add(createSPIRVStructurizerPass()); - - if (!FPM.run(F)) + FunctionPass *StructurizerPass = createSPIRVStructurizerPass(); + if (!StructurizerPass->runOnFunction(F)) return PreservedAnalyses::all(); PreservedAnalyses PA; PA.preserveSet<CFGAnalyses>(); diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h index da2e24c..60649ea 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.h +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h @@ -17,6 +17,7 @@ #include "llvm/Analysis/LoopInfo.h" #include "llvm/CodeGen/MachineBasicBlock.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/TypedPointerType.h" #include <queue> @@ -236,6 +237,10 @@ Type *parseBasicTypeName(StringRef &TypeName, LLVMContext &Ctx); // Returns true if the function was changed. bool sortBlocks(Function &F); +inline bool hasInitializer(const GlobalVariable *GV) { + return GV->hasInitializer() && !isa<UndefValue>(GV->getInitializer()); +} + // True if this is an instance of TypedPointerType. inline bool isTypedPointerTy(const Type *T) { return T && T->getTypeID() == Type::TypedPointerTyID; diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 68bdeb1..6b0eb38 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -94,7 +94,7 @@ static cl::opt<int> BrMergingCcmpBias( static cl::opt<bool> WidenShift("x86-widen-shift", cl::init(true), - cl::desc("Replacte narrow shifts with wider shifts."), + cl::desc("Replace narrow shifts with wider shifts."), cl::Hidden); static cl::opt<int> BrMergingLikelyBias( @@ -41694,6 +41694,8 @@ static SDValue combineTargetShuffle(SDValue N, const SDLoc &DL, TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) { MVT VT = N.getSimpleValueType(); + unsigned NumElts = VT.getVectorNumElements(); + SmallVector<int, 4> Mask; unsigned Opcode = N.getOpcode(); const TargetLowering &TLI = DAG.getTargetLoweringInfo(); @@ -41979,7 +41981,7 @@ static SDValue combineTargetShuffle(SDValue N, const SDLoc &DL, APInt Mask = APInt::getHighBitsSet(64, 32); if (DAG.MaskedValueIsZero(In, Mask)) { SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, In); - MVT VecVT = MVT::getVectorVT(MVT::i32, VT.getVectorNumElements() * 2); + MVT VecVT = MVT::getVectorVT(MVT::i32, NumElts * 2); SDValue SclVec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VecVT, Trunc); SDValue Movl = DAG.getNode(X86ISD::VZEXT_MOVL, DL, VecVT, SclVec); return DAG.getBitcast(VT, Movl); @@ -41994,7 +41996,6 @@ static SDValue combineTargetShuffle(SDValue N, const SDLoc &DL, // Create a vector constant - scalar constant followed by zeros. EVT ScalarVT = N0.getOperand(0).getValueType(); Type *ScalarTy = ScalarVT.getTypeForEVT(*DAG.getContext()); - unsigned NumElts = VT.getVectorNumElements(); Constant *Zero = ConstantInt::getNullValue(ScalarTy); SmallVector<Constant *, 32> ConstantVec(NumElts, Zero); ConstantVec[0] = const_cast<ConstantInt *>(C->getConstantIntValue()); @@ -42045,9 +42046,8 @@ static SDValue combineTargetShuffle(SDValue N, const SDLoc &DL, MVT SrcVT = N0.getOperand(0).getSimpleValueType(); unsigned SrcBits = SrcVT.getScalarSizeInBits(); if ((EltBits % SrcBits) == 0 && SrcBits >= 32) { - unsigned Size = VT.getVectorNumElements(); unsigned NewSize = SrcVT.getVectorNumElements(); - APInt BlendMask = N.getConstantOperandAPInt(2).zextOrTrunc(Size); + APInt BlendMask = N.getConstantOperandAPInt(2).zextOrTrunc(NumElts); APInt NewBlendMask = APIntOps::ScaleBitMask(BlendMask, NewSize); return DAG.getBitcast( VT, DAG.getNode(X86ISD::BLENDI, DL, SrcVT, N0.getOperand(0), @@ -42460,7 +42460,7 @@ static SDValue combineTargetShuffle(SDValue N, const SDLoc &DL, int DOffset = N.getOpcode() == X86ISD::PSHUFLW ? 0 : 2; DMask[DOffset + 0] = DOffset + 1; DMask[DOffset + 1] = DOffset + 0; - MVT DVT = MVT::getVectorVT(MVT::i32, VT.getVectorNumElements() / 2); + MVT DVT = MVT::getVectorVT(MVT::i32, NumElts / 2); V = DAG.getBitcast(DVT, V); V = DAG.getNode(X86ISD::PSHUFD, DL, DVT, V, getV4X86ShuffleImm8ForMask(DMask, DL, DAG)); diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp index 7a7554c..c19bcfc 100644 --- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp +++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp @@ -1650,6 +1650,13 @@ InstructionCost X86TTIImpl::getShuffleCost( return MatchingTypes ? TTI::TCC_Free : SubLT.first; } + // Attempt to match MOVSS (Idx == 0) or INSERTPS pattern. This will have + // been matched by improveShuffleKindFromMask as a SK_InsertSubvector of + // v1f32 (legalised to f32) into a v4f32. + if (LT.first == 1 && LT.second == MVT::v4f32 && SubLT.first == 1 && + SubLT.second == MVT::f32 && (Index == 0 || ST->hasSSE41())) + return 1; + // If the insertion isn't aligned, treat it like a 2-op shuffle. Kind = TTI::SK_PermuteTwoSrc; } @@ -4797,9 +4804,12 @@ InstructionCost X86TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val, MVT MScalarTy = LT.second.getScalarType(); auto IsCheapPInsrPExtrInsertPS = [&]() { // Assume pinsr/pextr XMM <-> GPR is relatively cheap on all targets. + // Inserting f32 into index0 is just movss. // Also, assume insertps is relatively cheap on all >= SSE41 targets. return (MScalarTy == MVT::i16 && ST->hasSSE2()) || (MScalarTy.isInteger() && ST->hasSSE41()) || + (MScalarTy == MVT::f32 && ST->hasSSE1() && Index == 0 && + Opcode == Instruction::InsertElement) || (MScalarTy == MVT::f32 && ST->hasSSE41() && Opcode == Instruction::InsertElement); }; |