Diffstat (limited to 'llvm/lib/Target')
69 files changed, 1330 insertions, 494 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td index 6c46b18..9f8a257 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td @@ -1053,13 +1053,6 @@ def AArch64umaxv : SDNode<"AArch64ISD::UMAXV", SDT_AArch64UnaryVec>; def AArch64uaddlv : SDNode<"AArch64ISD::UADDLV", SDT_AArch64uaddlp>; def AArch64saddlv : SDNode<"AArch64ISD::SADDLV", SDT_AArch64uaddlp>; -def AArch64uabd : PatFrags<(ops node:$lhs, node:$rhs), - [(abdu node:$lhs, node:$rhs), - (int_aarch64_neon_uabd node:$lhs, node:$rhs)]>; -def AArch64sabd : PatFrags<(ops node:$lhs, node:$rhs), - [(abds node:$lhs, node:$rhs), - (int_aarch64_neon_sabd node:$lhs, node:$rhs)]>; - // Add Pairwise of two vectors def AArch64addp_n : SDNode<"AArch64ISD::ADDP", SDT_AArch64Zip>; // Add Long Pairwise @@ -5667,8 +5660,7 @@ let Predicates = [HasFullFP16] in { // Advanced SIMD two vector instructions. //===----------------------------------------------------------------------===// -defm UABDL : SIMDLongThreeVectorBHSabdl<1, 0b0111, "uabdl", - AArch64uabd>; +defm UABDL : SIMDLongThreeVectorBHSabdl<1, 0b0111, "uabdl", abdu>; // Match UABDL in log2-shuffle patterns. def : Pat<(abs (v8i16 (sub (zext (v8i8 V64:$opA)), (zext (v8i8 V64:$opB))))), @@ -6018,8 +6010,8 @@ defm MLS : SIMDThreeSameVectorBHSTied<1, 0b10010, "mls", null_frag>; defm MUL : SIMDThreeSameVectorBHS<0, 0b10011, "mul", mul>; defm PMUL : SIMDThreeSameVectorB<1, 0b10011, "pmul", int_aarch64_neon_pmul>; defm SABA : SIMDThreeSameVectorBHSTied<0, 0b01111, "saba", - TriOpFrag<(add node:$LHS, (AArch64sabd node:$MHS, node:$RHS))> >; -defm SABD : SIMDThreeSameVectorBHS<0,0b01110,"sabd", AArch64sabd>; + TriOpFrag<(add node:$LHS, (abds node:$MHS, node:$RHS))> >; +defm SABD : SIMDThreeSameVectorBHS<0,0b01110,"sabd", abds>; defm SHADD : SIMDThreeSameVectorBHS<0,0b00000,"shadd", avgfloors>; defm SHSUB : SIMDThreeSameVectorBHS<0,0b00100,"shsub", int_aarch64_neon_shsub>; defm SMAXP : SIMDThreeSameVectorBHS<0,0b10100,"smaxp", int_aarch64_neon_smaxp>; @@ -6037,8 +6029,8 @@ defm SRSHL : SIMDThreeSameVector<0,0b01010,"srshl", int_aarch64_neon_srshl>; defm SSHL : SIMDThreeSameVector<0,0b01000,"sshl", int_aarch64_neon_sshl>; defm SUB : SIMDThreeSameVector<1,0b10000,"sub", sub>; defm UABA : SIMDThreeSameVectorBHSTied<1, 0b01111, "uaba", - TriOpFrag<(add node:$LHS, (AArch64uabd node:$MHS, node:$RHS))> >; -defm UABD : SIMDThreeSameVectorBHS<1,0b01110,"uabd", AArch64uabd>; + TriOpFrag<(add node:$LHS, (abdu node:$MHS, node:$RHS))> >; +defm UABD : SIMDThreeSameVectorBHS<1,0b01110,"uabd", abdu>; defm UHADD : SIMDThreeSameVectorBHS<1,0b00000,"uhadd", avgflooru>; defm UHSUB : SIMDThreeSameVectorBHS<1,0b00100,"uhsub", int_aarch64_neon_uhsub>; defm UMAXP : SIMDThreeSameVectorBHS<1,0b10100,"umaxp", int_aarch64_neon_umaxp>; @@ -6759,10 +6751,8 @@ defm SUBHN : SIMDNarrowThreeVectorBHS<0,0b0110,"subhn", int_aarch64_neon_subhn> defm RADDHN : SIMDNarrowThreeVectorBHS<1,0b0100,"raddhn",int_aarch64_neon_raddhn>; defm RSUBHN : SIMDNarrowThreeVectorBHS<1,0b0110,"rsubhn",int_aarch64_neon_rsubhn>; defm PMULL : SIMDDifferentThreeVectorBD<0,0b1110,"pmull", AArch64pmull>; -defm SABAL : SIMDLongThreeVectorTiedBHSabal<0,0b0101,"sabal", - AArch64sabd>; -defm SABDL : SIMDLongThreeVectorBHSabdl<0, 0b0111, "sabdl", - AArch64sabd>; +defm SABAL : SIMDLongThreeVectorTiedBHSabal<0,0b0101,"sabal", abds>; +defm SABDL : SIMDLongThreeVectorBHSabdl<0, 0b0111, "sabdl", abds>; defm SADDL : SIMDLongThreeVectorBHS< 0, 0b0000, "saddl", BinOpFrag<(add 
(sext node:$LHS), (sext node:$RHS))>>; defm SADDW : SIMDWideThreeVectorBHS< 0, 0b0001, "saddw", @@ -6780,8 +6770,7 @@ defm SSUBL : SIMDLongThreeVectorBHS<0, 0b0010, "ssubl", BinOpFrag<(sub (sext node:$LHS), (sext node:$RHS))>>; defm SSUBW : SIMDWideThreeVectorBHS<0, 0b0011, "ssubw", BinOpFrag<(sub node:$LHS, (sext node:$RHS))>>; -defm UABAL : SIMDLongThreeVectorTiedBHSabal<1, 0b0101, "uabal", - AArch64uabd>; +defm UABAL : SIMDLongThreeVectorTiedBHSabal<1, 0b0101, "uabal", abdu>; defm UADDL : SIMDLongThreeVectorBHS<1, 0b0000, "uaddl", BinOpFrag<(add (zanyext node:$LHS), (zanyext node:$RHS))>>; defm UADDW : SIMDWideThreeVectorBHS<1, 0b0001, "uaddw", diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp index 473ba5e..bb0f667b 100644 --- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp +++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp @@ -287,6 +287,10 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST) .moreElementsToNextPow2(0) .lower(); + getActionDefinitionsBuilder({G_ABDS, G_ABDU}) + .legalFor({v8s8, v16s8, v4s16, v8s16, v2s32, v4s32}) + .lower(); + getActionDefinitionsBuilder( {G_SADDE, G_SSUBE, G_UADDE, G_USUBE, G_SADDO, G_SSUBO, G_UADDO, G_USUBO}) .legalFor({{s32, s32}, {s64, s32}}) @@ -1794,6 +1798,10 @@ bool AArch64LegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper, return LowerBinOp(AArch64::G_SMULL); case Intrinsic::aarch64_neon_umull: return LowerBinOp(AArch64::G_UMULL); + case Intrinsic::aarch64_neon_sabd: + return LowerBinOp(TargetOpcode::G_ABDS); + case Intrinsic::aarch64_neon_uabd: + return LowerBinOp(TargetOpcode::G_ABDU); case Intrinsic::aarch64_neon_abs: { // Lower the intrinsic to G_ABS. MIB.buildInstr(TargetOpcode::G_ABS, {MI.getOperand(0)}, {MI.getOperand(2)}); diff --git a/llvm/lib/Target/AMDGPU/AMDGPU.td b/llvm/lib/Target/AMDGPU/AMDGPU.td index 0e0e83b..6076ac4 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPU.td +++ b/llvm/lib/Target/AMDGPU/AMDGPU.td @@ -1848,7 +1848,8 @@ def FeatureISAVersion11_Common : FeatureSet< FeatureImageInsts, FeaturePackedTID, FeatureVcmpxPermlaneHazard, - FeatureMemoryAtomicFAddF32DenormalSupport]>; + FeatureMemoryAtomicFAddF32DenormalSupport, + FeatureRealTrue16Insts]>; // There are few workarounds that need to be // added to all targets. 
This pessimizes codegen @@ -1868,8 +1869,7 @@ def FeatureISAVersion11_0_Common : FeatureSet< [FeatureMSAALoadDstSelBug, FeatureVALUTransUseHazard, FeatureMADIntraFwdBug, - FeaturePrivEnabledTrap2NopBug, - FeatureRealTrue16Insts])>; + FeaturePrivEnabledTrap2NopBug])>; def FeatureISAVersion11_0_0 : FeatureSet< !listconcat(FeatureISAVersion11_0_Common.Features, diff --git a/llvm/lib/Target/AMDGPU/AMDGPUAsmPrinter.cpp b/llvm/lib/Target/AMDGPU/AMDGPUAsmPrinter.cpp index 749b9ef..4b3dc37 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUAsmPrinter.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUAsmPrinter.cpp @@ -1415,6 +1415,7 @@ static void EmitPALMetadataCommon(AMDGPUPALMetadata *MD, MD->setHwStage(CC, ".wgp_mode", (bool)CurrentProgramInfo.WgpMode); MD->setHwStage(CC, ".mem_ordered", (bool)CurrentProgramInfo.MemOrdered); + MD->setHwStage(CC, ".forward_progress", (bool)CurrentProgramInfo.FwdProgress); if (AMDGPU::isCompute(CC)) { MD->setHwStage(CC, ".trap_present", diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp index 14101e5..3d8d274 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp @@ -374,8 +374,10 @@ bool AMDGPUCallLowering::lowerReturn(MachineIRBuilder &B, const Value *Val, return true; } - unsigned ReturnOpc = - IsShader ? AMDGPU::SI_RETURN_TO_EPILOG : AMDGPU::SI_RETURN; + const bool IsWholeWave = MFI->isWholeWaveFunction(); + unsigned ReturnOpc = IsWholeWave ? AMDGPU::G_AMDGPU_WHOLE_WAVE_FUNC_RETURN + : IsShader ? AMDGPU::SI_RETURN_TO_EPILOG + : AMDGPU::SI_RETURN; auto Ret = B.buildInstrNoInsert(ReturnOpc); if (!FLI.CanLowerReturn) @@ -383,6 +385,9 @@ bool AMDGPUCallLowering::lowerReturn(MachineIRBuilder &B, const Value *Val, else if (!lowerReturnVal(B, Val, VRegs, Ret)) return false; + if (IsWholeWave) + addOriginalExecToReturn(B.getMF(), Ret); + // TODO: Handle CalleeSavedRegsViaCopy. B.insertInstr(Ret); @@ -632,6 +637,17 @@ bool AMDGPUCallLowering::lowerFormalArguments( if (DL.getTypeStoreSize(Arg.getType()) == 0) continue; + if (Info->isWholeWaveFunction() && Idx == 0) { + assert(VRegs[Idx].size() == 1 && "Expected only one register"); + + // The first argument for whole wave functions is the original EXEC value. + B.buildInstr(AMDGPU::G_AMDGPU_WHOLE_WAVE_FUNC_SETUP) + .addDef(VRegs[Idx][0]); + + ++Idx; + continue; + } + const bool InReg = Arg.hasAttribute(Attribute::InReg); if (Arg.hasAttribute(Attribute::SwiftSelf) || @@ -1347,6 +1363,7 @@ bool AMDGPUCallLowering::lowerTailCall( SmallVector<std::pair<MCRegister, Register>, 12> ImplicitArgRegs; if (Info.CallConv != CallingConv::AMDGPU_Gfx && + Info.CallConv != CallingConv::AMDGPU_Gfx_WholeWave && !AMDGPU::isChainCC(Info.CallConv)) { // With a fixed ABI, allocate fixed registers before user arguments. if (!passSpecialInputs(MIRBuilder, CCInfo, ImplicitArgRegs, Info)) @@ -1524,7 +1541,8 @@ bool AMDGPUCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, // after the ordinary user argument registers. SmallVector<std::pair<MCRegister, Register>, 12> ImplicitArgRegs; - if (Info.CallConv != CallingConv::AMDGPU_Gfx) { + if (Info.CallConv != CallingConv::AMDGPU_Gfx && + Info.CallConv != CallingConv::AMDGPU_Gfx_WholeWave) { // With a fixed ABI, allocate fixed registers before user arguments. 
if (!passSpecialInputs(MIRBuilder, CCInfo, ImplicitArgRegs, Info)) return false; @@ -1592,3 +1610,11 @@ bool AMDGPUCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, return true; } + +void AMDGPUCallLowering::addOriginalExecToReturn( + MachineFunction &MF, MachineInstrBuilder &Ret) const { + const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>(); + const SIInstrInfo *TII = ST.getInstrInfo(); + const MachineInstr *Setup = TII->getWholeWaveFunctionSetup(MF); + Ret.addReg(Setup->getOperand(0).getReg()); +} diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.h b/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.h index a6e801f..e0033d5 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.h +++ b/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.h @@ -37,6 +37,9 @@ class AMDGPUCallLowering final : public CallLowering { bool lowerReturnVal(MachineIRBuilder &B, const Value *Val, ArrayRef<Register> VRegs, MachineInstrBuilder &Ret) const; + void addOriginalExecToReturn(MachineFunction &MF, + MachineInstrBuilder &Ret) const; + public: AMDGPUCallLowering(const AMDGPUTargetLowering &TLI); diff --git a/llvm/lib/Target/AMDGPU/AMDGPUGISel.td b/llvm/lib/Target/AMDGPU/AMDGPUGISel.td index 2bfd56f..891d362 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUGISel.td +++ b/llvm/lib/Target/AMDGPU/AMDGPUGISel.td @@ -315,6 +315,10 @@ def : GINodeEquiv<G_AMDGPU_S_BUFFER_LOAD_SSHORT, SIsbuffer_load_short>; def : GINodeEquiv<G_AMDGPU_S_BUFFER_LOAD_USHORT, SIsbuffer_load_ushort>; def : GINodeEquiv<G_AMDGPU_S_BUFFER_PREFETCH, SIsbuffer_prefetch>; +def : GINodeEquiv<G_AMDGPU_WHOLE_WAVE_FUNC_SETUP, AMDGPUwhole_wave_setup>; +// G_AMDGPU_WHOLE_WAVE_FUNC_RETURN is simpler than AMDGPUwhole_wave_return, +// so we don't mark it as equivalent. + class GISelSop2Pat < SDPatternOperator node, Instruction inst, diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp index 3d040fb..e3ca09e 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp @@ -375,7 +375,6 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM, setTruncStoreAction(MVT::v16f64, MVT::v16bf16, Expand); setTruncStoreAction(MVT::v16f64, MVT::v16f16, Expand); setTruncStoreAction(MVT::v16i64, MVT::v16i16, Expand); - setTruncStoreAction(MVT::v16i64, MVT::v16i16, Expand); setTruncStoreAction(MVT::v16i64, MVT::v16i8, Expand); setTruncStoreAction(MVT::v16i64, MVT::v16i8, Expand); setTruncStoreAction(MVT::v16i64, MVT::v16i1, Expand); @@ -1143,6 +1142,7 @@ CCAssignFn *AMDGPUCallLowering::CCAssignFnForCall(CallingConv::ID CC, case CallingConv::Cold: return CC_AMDGPU_Func; case CallingConv::AMDGPU_Gfx: + case CallingConv::AMDGPU_Gfx_WholeWave: return CC_SI_Gfx; case CallingConv::AMDGPU_KERNEL: case CallingConv::SPIR_KERNEL: @@ -1168,6 +1168,7 @@ CCAssignFn *AMDGPUCallLowering::CCAssignFnForReturn(CallingConv::ID CC, case CallingConv::AMDGPU_LS: return RetCC_SI_Shader; case CallingConv::AMDGPU_Gfx: + case CallingConv::AMDGPU_Gfx_WholeWave: return RetCC_SI_Gfx; case CallingConv::C: case CallingConv::Fast: @@ -5875,6 +5876,8 @@ const char* AMDGPUTargetLowering::getTargetNodeName(unsigned Opcode) const { NODE_NAME_CASE(BUFFER_ATOMIC_FMIN) NODE_NAME_CASE(BUFFER_ATOMIC_FMAX) NODE_NAME_CASE(BUFFER_ATOMIC_COND_SUB_U32) + NODE_NAME_CASE(WHOLE_WAVE_SETUP) + NODE_NAME_CASE(WHOLE_WAVE_RETURN) } return nullptr; } diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h index 4e8c6c7..39bb0ad 100644 --- 
a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h +++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h @@ -608,6 +608,12 @@ enum NodeType : unsigned { BUFFER_ATOMIC_FMAX, BUFFER_ATOMIC_COND_SUB_U32, LAST_MEMORY_OPCODE = BUFFER_ATOMIC_COND_SUB_U32, + + // Set up a whole wave function. + WHOLE_WAVE_SETUP, + + // Return from a whole wave function. + WHOLE_WAVE_RETURN, }; } // End namespace AMDGPUISD diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp index e2c2e89..f2207ff 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp @@ -1694,6 +1694,47 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const { NewII->takeName(&II); return IC.replaceInstUsesWith(II, NewII); } + case Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4: { + Value *Src0 = II.getArgOperand(1); + Value *Src1 = II.getArgOperand(3); + unsigned FmtA = cast<ConstantInt>(II.getArgOperand(0))->getZExtValue(); + uint64_t FmtB = cast<ConstantInt>(II.getArgOperand(2))->getZExtValue(); + auto *Src0Ty = cast<FixedVectorType>(Src0->getType()); + auto *Src1Ty = cast<FixedVectorType>(Src1->getType()); + + bool MadeChange = false; + unsigned Src0NumElts = AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(FmtA); + unsigned Src1NumElts = AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(FmtB); + + // Depending on the used format, fewer registers are required so shrink the + // vector type. + if (Src0Ty->getNumElements() > Src0NumElts) { + Src0 = IC.Builder.CreateExtractVector( + FixedVectorType::get(Src0Ty->getElementType(), Src0NumElts), Src0, + IC.Builder.getInt64(0)); + MadeChange = true; + } + + if (Src1Ty->getNumElements() > Src1NumElts) { + Src1 = IC.Builder.CreateExtractVector( + FixedVectorType::get(Src1Ty->getElementType(), Src1NumElts), Src1, + IC.Builder.getInt64(0)); + MadeChange = true; + } + + if (!MadeChange) + return std::nullopt; + + SmallVector<Value *, 13> Args(II.args()); + Args[1] = Src0; + Args[3] = Src1; + + CallInst *NewII = IC.Builder.CreateIntrinsic( + IID, {II.getArgOperand(5)->getType(), Src0->getType(), Src1->getType()}, + Args, &II); + NewII->takeName(&II); + return IC.replaceInstUsesWith(II, NewII); + } } if (const AMDGPU::ImageDimIntrinsicInfo *ImageDimIntr = AMDGPU::getImageDimIntrinsicInfo(II.getIntrinsicID())) { diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstrInfo.td b/llvm/lib/Target/AMDGPU/AMDGPUInstrInfo.td index ce58e93..e305f08 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUInstrInfo.td +++ b/llvm/lib/Target/AMDGPU/AMDGPUInstrInfo.td @@ -348,6 +348,17 @@ def AMDGPUfdot2_impl : SDNode<"AMDGPUISD::FDOT2", def AMDGPUperm_impl : SDNode<"AMDGPUISD::PERM", AMDGPUDTIntTernaryOp, []>; +// Marks the entry into a whole wave function. +def AMDGPUwhole_wave_setup : SDNode< + "AMDGPUISD::WHOLE_WAVE_SETUP", SDTypeProfile<1, 0, [SDTCisInt<0>]>, + [SDNPHasChain, SDNPSideEffect]>; + +// Marks the return from a whole wave function. 
+def AMDGPUwhole_wave_return : SDNode< + "AMDGPUISD::WHOLE_WAVE_RETURN", SDTNone, + [SDNPHasChain, SDNPOptInGlue, SDNPVariadic] +>; + // SI+ export def AMDGPUExportOp : SDTypeProfile<0, 8, [ SDTCisInt<0>, // i8 tgt diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp index d161c03..8975486 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp @@ -4160,6 +4160,10 @@ bool AMDGPUInstructionSelector::select(MachineInstr &I) { return true; case AMDGPU::G_AMDGPU_WAVE_ADDRESS: return selectWaveAddress(I); + case AMDGPU::G_AMDGPU_WHOLE_WAVE_FUNC_RETURN: { + I.setDesc(TII.get(AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN)); + return true; + } case AMDGPU::G_STACKRESTORE: return selectStackRestore(I); case AMDGPU::G_PHI: diff --git a/llvm/lib/Target/AMDGPU/AMDGPULowerBufferFatPointers.cpp b/llvm/lib/Target/AMDGPU/AMDGPULowerBufferFatPointers.cpp index fa8af68..304e91e 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPULowerBufferFatPointers.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPULowerBufferFatPointers.cpp @@ -1583,15 +1583,13 @@ void SplitPtrStructs::killAndReplaceSplitInstructions( if (!SplitUsers.contains(I)) continue; - SmallVector<DbgValueInst *> Dbgs; - findDbgValues(Dbgs, I); - for (auto *Dbg : Dbgs) { - IRB.SetInsertPoint(Dbg); + SmallVector<DbgVariableRecord *> Dbgs; + findDbgValues(I, Dbgs); + for (DbgVariableRecord *Dbg : Dbgs) { auto &DL = I->getDataLayout(); assert(isSplitFatPtr(I->getType()) && "We should've RAUW'd away loads, stores, etc. at this point"); - auto *OffDbg = cast<DbgValueInst>(Dbg->clone()); - copyMetadata(OffDbg, Dbg); + DbgVariableRecord *OffDbg = Dbg->clone(); auto [Rsrc, Off] = getPtrParts(I); int64_t RsrcSz = DL.getTypeSizeInBits(Rsrc->getType()); @@ -1606,9 +1604,9 @@ void SplitPtrStructs::killAndReplaceSplitInstructions( if (OffExpr) { OffDbg->setExpression(*OffExpr); OffDbg->replaceVariableLocationOp(I, Off); - IRB.Insert(OffDbg); + OffDbg->insertBefore(Dbg); } else { - OffDbg->deleteValue(); + OffDbg->eraseFromParent(); } if (RsrcExpr) { Dbg->setExpression(*RsrcExpr); diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp index bf2f37b..f1caf24 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp @@ -4714,6 +4714,7 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { case Intrinsic::amdgcn_wmma_f32_16x16x128_bf8_fp8: case Intrinsic::amdgcn_wmma_f32_16x16x128_bf8_bf8: case Intrinsic::amdgcn_wmma_i32_16x16x64_iu8: + case Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4: case Intrinsic::amdgcn_wmma_f32_32x16x128_f4: case Intrinsic::amdgcn_swmmac_f16_16x16x64_f16: case Intrinsic::amdgcn_swmmac_bf16_16x16x64_bf16: @@ -5540,6 +5541,10 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { case AMDGPU::G_PREFETCH: OpdsMapping[0] = getSGPROpMapping(MI.getOperand(0).getReg(), MRI, *TRI); break; + case AMDGPU::G_AMDGPU_WHOLE_WAVE_FUNC_SETUP: + case AMDGPU::G_AMDGPU_WHOLE_WAVE_FUNC_RETURN: + OpdsMapping[0] = AMDGPU::getValueMapping(AMDGPU::VCCRegBankID, 1); + break; } return getInstructionMapping(/*ID*/1, /*Cost*/1, diff --git a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp index de17fcc..dc83230 100644 --- a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp +++ b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp @@ -176,6 +176,8 
@@ public: ImmTyWaitVAVDst, ImmTyWaitVMVSrc, ImmTyBitOp3, + ImmTyMatrixAFMT, + ImmTyMatrixBFMT, ImmTyMatrixAReuse, ImmTyMatrixBReuse, ImmTyByteSel, @@ -423,6 +425,8 @@ public: bool isIndexKey8bit() const { return isImmTy(ImmTyIndexKey8bit); } bool isIndexKey16bit() const { return isImmTy(ImmTyIndexKey16bit); } bool isIndexKey32bit() const { return isImmTy(ImmTyIndexKey32bit); } + bool isMatrixAFMT() const { return isImmTy(ImmTyMatrixAFMT); } + bool isMatrixBFMT() const { return isImmTy(ImmTyMatrixBFMT); } bool isMatrixAReuse() const { return isImmTy(ImmTyMatrixAReuse); } bool isMatrixBReuse() const { return isImmTy(ImmTyMatrixBReuse); } bool isTFE() const { return isImmTy(ImmTyTFE); } @@ -1174,6 +1178,8 @@ public: case ImmTyWaitVAVDst: OS << "WaitVAVDst"; break; case ImmTyWaitVMVSrc: OS << "WaitVMVSrc"; break; case ImmTyBitOp3: OS << "BitOp3"; break; + case ImmTyMatrixAFMT: OS << "ImmTyMatrixAFMT"; break; + case ImmTyMatrixBFMT: OS << "ImmTyMatrixBFMT"; break; case ImmTyMatrixAReuse: OS << "ImmTyMatrixAReuse"; break; case ImmTyMatrixBReuse: OS << "ImmTyMatrixBReuse"; break; case ImmTyByteSel: OS << "ByteSel" ; break; @@ -1714,6 +1720,10 @@ public: ParseStatus parseIndexKey8bit(OperandVector &Operands); ParseStatus parseIndexKey16bit(OperandVector &Operands); ParseStatus parseIndexKey32bit(OperandVector &Operands); + ParseStatus tryParseMatrixFMT(OperandVector &Operands, StringRef Name, + AMDGPUOperand::ImmTy Type); + ParseStatus parseMatrixAFMT(OperandVector &Operands); + ParseStatus parseMatrixBFMT(OperandVector &Operands); ParseStatus parseDfmtNfmt(int64_t &Format); ParseStatus parseUfmt(int64_t &Format); @@ -1849,6 +1859,7 @@ private: const unsigned CPol); bool validateTFE(const MCInst &Inst, const OperandVector &Operands); std::optional<StringRef> validateLdsDirect(const MCInst &Inst); + bool validateWMMA(const MCInst &Inst, const OperandVector &Operands); unsigned getConstantBusLimit(unsigned Opcode) const; bool usesConstantBus(const MCInst &Inst, unsigned OpIdx); bool isInlineConstant(const MCInst &Inst, unsigned OpIdx) const; @@ -5409,6 +5420,37 @@ bool AMDGPUAsmParser::validateTFE(const MCInst &Inst, return true; } +bool AMDGPUAsmParser::validateWMMA(const MCInst &Inst, + const OperandVector &Operands) { + unsigned Opc = Inst.getOpcode(); + const MCRegisterInfo *TRI = getContext().getRegisterInfo(); + const MCInstrDesc &Desc = MII.get(Opc); + + auto validateFmt = [&](AMDGPU::OpName FmtOp, AMDGPU::OpName SrcOp) -> bool { + int FmtIdx = AMDGPU::getNamedOperandIdx(Opc, FmtOp); + if (FmtIdx == -1) + return true; + unsigned Fmt = Inst.getOperand(FmtIdx).getImm(); + int SrcIdx = AMDGPU::getNamedOperandIdx(Opc, SrcOp); + unsigned RegSize = + TRI->getRegClass(Desc.operands()[SrcIdx].RegClass).getSizeInBits(); + + if (RegSize == AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(Fmt) * 32) + return true; + + static const char *FmtNames[] = {"MATRIX_FMT_FP8", "MATRIX_FMT_BF8", + "MATRIX_FMT_FP6", "MATRIX_FMT_BF6", + "MATRIX_FMT_FP4"}; + + Error(getRegLoc(mc2PseudoReg(Inst.getOperand(SrcIdx).getReg()), Operands), + "wrong register tuple size for " + Twine(FmtNames[Fmt])); + return false; + }; + + return validateFmt(AMDGPU::OpName::matrix_a_fmt, AMDGPU::OpName::src0) && + validateFmt(AMDGPU::OpName::matrix_b_fmt, AMDGPU::OpName::src1); +} + bool AMDGPUAsmParser::validateInstruction(const MCInst &Inst, const SMLoc &IDLoc, const OperandVector &Operands) { @@ -5542,6 +5584,9 @@ bool AMDGPUAsmParser::validateInstruction(const MCInst &Inst, if (!validateTFE(Inst, Operands)) { return false; } + if 
(!validateWMMA(Inst, Operands)) { + return false; + } return true; } @@ -7215,6 +7260,26 @@ ParseStatus AMDGPUAsmParser::parseIndexKey32bit(OperandVector &Operands) { return tryParseIndexKey(Operands, AMDGPUOperand::ImmTyIndexKey32bit); } +ParseStatus AMDGPUAsmParser::tryParseMatrixFMT(OperandVector &Operands, + StringRef Name, + AMDGPUOperand::ImmTy Type) { + return parseStringOrIntWithPrefix(Operands, Name, + {"MATRIX_FMT_FP8", "MATRIX_FMT_BF8", + "MATRIX_FMT_FP6", "MATRIX_FMT_BF6", + "MATRIX_FMT_FP4"}, + Type); +} + +ParseStatus AMDGPUAsmParser::parseMatrixAFMT(OperandVector &Operands) { + return tryParseMatrixFMT(Operands, "matrix_a_fmt", + AMDGPUOperand::ImmTyMatrixAFMT); +} + +ParseStatus AMDGPUAsmParser::parseMatrixBFMT(OperandVector &Operands) { + return tryParseMatrixFMT(Operands, "matrix_b_fmt", + AMDGPUOperand::ImmTyMatrixBFMT); +} + // dfmt and nfmt (in a tbuffer instruction) are parsed as one to allow their // values to live in a joint format operand in the MCInst encoding. ParseStatus AMDGPUAsmParser::parseDfmtNfmt(int64_t &Format) { @@ -9316,6 +9381,20 @@ void AMDGPUAsmParser::cvtVOP3P(MCInst &Inst, const OperandVector &Operands, DefaultVal); } + int MatrixAFMTIdx = + AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_a_fmt); + if (MatrixAFMTIdx != -1) { + addOptionalImmOperand(Inst, Operands, OptIdx, + AMDGPUOperand::ImmTyMatrixAFMT, 0); + } + + int MatrixBFMTIdx = + AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_b_fmt); + if (MatrixBFMTIdx != -1) { + addOptionalImmOperand(Inst, Operands, OptIdx, + AMDGPUOperand::ImmTyMatrixBFMT, 0); + } + if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::matrix_a_reuse)) addOptionalImmOperand(Inst, Operands, OptIdx, AMDGPUOperand::ImmTyMatrixAReuse, 0); diff --git a/llvm/lib/Target/AMDGPU/BUFInstructions.td b/llvm/lib/Target/AMDGPU/BUFInstructions.td index e994aee..f99e716 100644 --- a/llvm/lib/Target/AMDGPU/BUFInstructions.td +++ b/llvm/lib/Target/AMDGPU/BUFInstructions.td @@ -1488,7 +1488,6 @@ defm : MUBUF_StoreIntrinsicPat<SIbuffer_store_format, f32, "BUFFER_STORE_FORMAT_ defm : MUBUF_StoreIntrinsicPat<SIbuffer_store_format, i32, "BUFFER_STORE_FORMAT_X">; defm : MUBUF_StoreIntrinsicPat<SIbuffer_store_format, v2f32, "BUFFER_STORE_FORMAT_XY">; defm : MUBUF_StoreIntrinsicPat<SIbuffer_store_format, v2i32, "BUFFER_STORE_FORMAT_XY">; -defm : MUBUF_StoreIntrinsicPat<SIbuffer_store_format, v2i32, "BUFFER_STORE_FORMAT_XY">; defm : MUBUF_StoreIntrinsicPat<SIbuffer_store_format, v3f32, "BUFFER_STORE_FORMAT_XYZ">; defm : MUBUF_StoreIntrinsicPat<SIbuffer_store_format, v3i32, "BUFFER_STORE_FORMAT_XYZ">; defm : MUBUF_StoreIntrinsicPat<SIbuffer_store_format, v4f32, "BUFFER_STORE_FORMAT_XYZW">; diff --git a/llvm/lib/Target/AMDGPU/DSInstructions.td b/llvm/lib/Target/AMDGPU/DSInstructions.td index e219fe0..319cc9d 100644 --- a/llvm/lib/Target/AMDGPU/DSInstructions.td +++ b/llvm/lib/Target/AMDGPU/DSInstructions.td @@ -886,7 +886,6 @@ defm : DSReadPat_mc <DS_READ_I8, i32, "sextloadi8_local">; defm : DSReadPat_mc <DS_READ_U8, i32, "extloadi8_local">; defm : DSReadPat_mc <DS_READ_U8, i32, "zextloadi8_local">; defm : DSReadPat_mc <DS_READ_I16, i32, "sextloadi16_local">; -defm : DSReadPat_mc <DS_READ_I16, i32, "sextloadi16_local">; defm : DSReadPat_mc <DS_READ_U16, i32, "extloadi16_local">; defm : DSReadPat_mc <DS_READ_U16, i32, "zextloadi16_local">; defm : DSReadPat_t16 <DS_READ_I8, i16, "sextloadi8_local">; diff --git a/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp b/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp index 
98f7e17..5c1989b 100644 --- a/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp +++ b/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp @@ -877,6 +877,9 @@ DecodeStatus AMDGPUDisassembler::getInstruction(MCInst &MI, uint64_t &Size, if (MCII->get(MI.getOpcode()).TSFlags & SIInstrFlags::IsMAI) convertMAIInst(MI); + if (MCII->get(MI.getOpcode()).TSFlags & SIInstrFlags::IsWMMA) + convertWMMAInst(MI); + int VDstIn_Idx = AMDGPU::getNamedOperandIdx(MI.getOpcode(), AMDGPU::OpName::vdst_in); if (VDstIn_Idx != -1) { @@ -974,10 +977,23 @@ static void adjustMFMA_F8F6F4OpRegClass(const MCRegisterInfo &MRI, return MO.setReg( MRI.getSubReg(MO.getReg(), AMDGPU::sub0_sub1_sub2_sub3_sub4_sub5)); case 8: + if (MCRegister NewReg = MRI.getSubReg( + MO.getReg(), AMDGPU::sub0_sub1_sub2_sub3_sub4_sub5_sub6_sub7)) { + MO.setReg(NewReg); + } + return; + case 12: { + // There is no 384-bit subreg index defined. + MCRegister BaseReg = MRI.getSubReg(MO.getReg(), AMDGPU::sub0); + MCRegister NewReg = MRI.getMatchingSuperReg( + BaseReg, AMDGPU::sub0, &MRI.getRegClass(AMDGPU::VReg_384RegClassID)); + return MO.setReg(NewReg); + } + case 16: // No-op in cases where one operand is still f8/bf8. return; default: - llvm_unreachable("Unexpected size for mfma f8f6f4 operand"); + llvm_unreachable("Unexpected size for mfma/wmma f8f6f4 operand"); } } @@ -1015,6 +1031,35 @@ void AMDGPUDisassembler::convertMAIInst(MCInst &MI) const { AdjustedRegClassOpcode->NumRegsSrcB); } +void AMDGPUDisassembler::convertWMMAInst(MCInst &MI) const { + int FmtAIdx = + AMDGPU::getNamedOperandIdx(MI.getOpcode(), AMDGPU::OpName::matrix_a_fmt); + if (FmtAIdx == -1) + return; + + int FmtBIdx = + AMDGPU::getNamedOperandIdx(MI.getOpcode(), AMDGPU::OpName::matrix_b_fmt); + + unsigned FmtA = MI.getOperand(FmtAIdx).getImm(); + unsigned FmtB = MI.getOperand(FmtBIdx).getImm(); + + const AMDGPU::MFMA_F8F6F4_Info *AdjustedRegClassOpcode = + AMDGPU::getWMMA_F8F6F4_WithFormatArgs(FmtA, FmtB, MI.getOpcode()); + if (!AdjustedRegClassOpcode || + AdjustedRegClassOpcode->Opcode == MI.getOpcode()) + return; + + MI.setOpcode(AdjustedRegClassOpcode->Opcode); + int Src0Idx = + AMDGPU::getNamedOperandIdx(MI.getOpcode(), AMDGPU::OpName::src0); + int Src1Idx = + AMDGPU::getNamedOperandIdx(MI.getOpcode(), AMDGPU::OpName::src1); + adjustMFMA_F8F6F4OpRegClass(MRI, MI.getOperand(Src0Idx), + AdjustedRegClassOpcode->NumRegsSrcA); + adjustMFMA_F8F6F4OpRegClass(MRI, MI.getOperand(Src1Idx), + AdjustedRegClassOpcode->NumRegsSrcB); +} + struct VOPModifiers { unsigned OpSel = 0; unsigned OpSelHi = 0; diff --git a/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.h b/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.h index 8404100..f4d164b 100644 --- a/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.h +++ b/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.h @@ -161,6 +161,7 @@ public: void convertFMAanyK(MCInst &MI) const; void convertSDWAInst(MCInst &MI) const; void convertMAIInst(MCInst &MI) const; + void convertWMMAInst(MCInst &MI) const; void convertDPP8Inst(MCInst &MI) const; void convertMIMGInst(MCInst &MI) const; void convertVOP3DPPInst(MCInst &MI) const; diff --git a/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp b/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp index bbed828..c4a3be4 100644 --- a/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp +++ b/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp @@ -3206,7 +3206,7 @@ bool GCNHazardRecognizer::fixRequiredExportPriority(MachineInstr *MI) { // Check entry priority at each export (as there will 
only be a few). // Note: amdgpu_gfx can only be a callee, so defer to caller setprio. bool Changed = false; - if (CC != CallingConv::AMDGPU_Gfx) + if (CC != CallingConv::AMDGPU_Gfx && CC != CallingConv::AMDGPU_Gfx_WholeWave) Changed = ensureEntrySetPrio(MF, NormalPriority, TII); auto NextMI = std::next(It); diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp index 44d2f94..197bb3f 100644 --- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp +++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp @@ -1345,6 +1345,48 @@ void AMDGPUInstPrinter::printIndexKey32bit(const MCInst *MI, unsigned OpNo, O << " index_key:" << Imm; } +void AMDGPUInstPrinter::printMatrixFMT(const MCInst *MI, unsigned OpNo, + const MCSubtargetInfo &STI, + raw_ostream &O, char AorB) { + auto Imm = MI->getOperand(OpNo).getImm() & 0x7; + if (Imm == 0) + return; + + O << " matrix_" << AorB << "_fmt:"; + switch (Imm) { + default: + O << Imm; + break; + case WMMA::MatrixFMT::MATRIX_FMT_FP8: + O << "MATRIX_FMT_FP8"; + break; + case WMMA::MatrixFMT::MATRIX_FMT_BF8: + O << "MATRIX_FMT_BF8"; + break; + case WMMA::MatrixFMT::MATRIX_FMT_FP6: + O << "MATRIX_FMT_FP6"; + break; + case WMMA::MatrixFMT::MATRIX_FMT_BF6: + O << "MATRIX_FMT_BF6"; + break; + case WMMA::MatrixFMT::MATRIX_FMT_FP4: + O << "MATRIX_FMT_FP4"; + break; + } +} + +void AMDGPUInstPrinter::printMatrixAFMT(const MCInst *MI, unsigned OpNo, + const MCSubtargetInfo &STI, + raw_ostream &O) { + printMatrixFMT(MI, OpNo, STI, O, 'a'); +} + +void AMDGPUInstPrinter::printMatrixBFMT(const MCInst *MI, unsigned OpNo, + const MCSubtargetInfo &STI, + raw_ostream &O) { + printMatrixFMT(MI, OpNo, STI, O, 'b'); +} + void AMDGPUInstPrinter::printInterpSlot(const MCInst *MI, unsigned OpNum, const MCSubtargetInfo &STI, raw_ostream &O) { diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h index e3299a6..e0b7aa5 100644 --- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h +++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h @@ -134,6 +134,12 @@ private: const MCSubtargetInfo &STI, raw_ostream &O); void printIndexKey32bit(const MCInst *MI, unsigned OpNo, const MCSubtargetInfo &STI, raw_ostream &O); + void printMatrixFMT(const MCInst *MI, unsigned OpNo, + const MCSubtargetInfo &STI, raw_ostream &O, char AorB); + void printMatrixAFMT(const MCInst *MI, unsigned OpNo, + const MCSubtargetInfo &STI, raw_ostream &O); + void printMatrixBFMT(const MCInst *MI, unsigned OpNo, + const MCSubtargetInfo &STI, raw_ostream &O); void printInterpSlot(const MCInst *MI, unsigned OpNo, const MCSubtargetInfo &STI, raw_ostream &O); void printInterpAttr(const MCInst *MI, unsigned OpNo, diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp index f48739f..c49ad79 100644 --- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp +++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp @@ -384,6 +384,8 @@ void AMDGPUMCCodeEmitter::encodeInstruction(const MCInst &MI, if (((Desc.TSFlags & SIInstrFlags::VOP3P) || Opcode == AMDGPU::V_ACCVGPR_READ_B32_vi || Opcode == AMDGPU::V_ACCVGPR_WRITE_B32_vi) && + // Matrix B format operand reuses op_sel_hi. + !AMDGPU::hasNamedOperand(Opcode, AMDGPU::OpName::matrix_b_fmt) && // Matrix B reuse operand reuses op_sel_hi. 
!AMDGPU::hasNamedOperand(Opcode, AMDGPU::OpName::matrix_b_reuse)) { Encoding |= getImplicitOpSelHiEncoding(Opcode); diff --git a/llvm/lib/Target/AMDGPU/SIDefines.h b/llvm/lib/Target/AMDGPU/SIDefines.h index edc74605..d379088 100644 --- a/llvm/lib/Target/AMDGPU/SIDefines.h +++ b/llvm/lib/Target/AMDGPU/SIDefines.h @@ -1005,6 +1005,16 @@ enum Target : unsigned { } // namespace Exp +namespace WMMA { +enum MatrixFMT : unsigned { + MATRIX_FMT_FP8 = 0, + MATRIX_FMT_BF8 = 1, + MATRIX_FMT_FP6 = 2, + MATRIX_FMT_BF6 = 3, + MATRIX_FMT_FP4 = 4 +}; +} // namespace WMMA + namespace VOP3PEncoding { enum OpSel : uint64_t { diff --git a/llvm/lib/Target/AMDGPU/SIFrameLowering.cpp b/llvm/lib/Target/AMDGPU/SIFrameLowering.cpp index 6a38679..11552b3 100644 --- a/llvm/lib/Target/AMDGPU/SIFrameLowering.cpp +++ b/llvm/lib/Target/AMDGPU/SIFrameLowering.cpp @@ -946,8 +946,18 @@ static Register buildScratchExecCopy(LiveRegUnits &LiveUnits, initLiveUnits(LiveUnits, TRI, FuncInfo, MF, MBB, MBBI, IsProlog); - ScratchExecCopy = findScratchNonCalleeSaveRegister( - MRI, LiveUnits, *TRI.getWaveMaskRegClass()); + if (FuncInfo->isWholeWaveFunction()) { + // Whole wave functions already have a copy of the original EXEC mask that + // we can use. + assert(IsProlog && "Epilog should look at return, not setup"); + ScratchExecCopy = + TII->getWholeWaveFunctionSetup(MF)->getOperand(0).getReg(); + assert(ScratchExecCopy && "Couldn't find copy of EXEC"); + } else { + ScratchExecCopy = findScratchNonCalleeSaveRegister( + MRI, LiveUnits, *TRI.getWaveMaskRegClass()); + } + if (!ScratchExecCopy) report_fatal_error("failed to find free scratch register"); @@ -996,10 +1006,15 @@ void SIFrameLowering::emitCSRSpillStores( }; StoreWWMRegisters(WWMScratchRegs); + + auto EnableAllLanes = [&]() { + unsigned MovOpc = ST.isWave32() ? AMDGPU::S_MOV_B32 : AMDGPU::S_MOV_B64; + BuildMI(MBB, MBBI, DL, TII->get(MovOpc), TRI.getExec()).addImm(-1); + }; + if (!WWMCalleeSavedRegs.empty()) { if (ScratchExecCopy) { - unsigned MovOpc = ST.isWave32() ? AMDGPU::S_MOV_B32 : AMDGPU::S_MOV_B64; - BuildMI(MBB, MBBI, DL, TII->get(MovOpc), TRI.getExec()).addImm(-1); + EnableAllLanes(); } else { ScratchExecCopy = buildScratchExecCopy(LiveUnits, MF, MBB, MBBI, DL, /*IsProlog*/ true, @@ -1008,7 +1023,18 @@ void SIFrameLowering::emitCSRSpillStores( } StoreWWMRegisters(WWMCalleeSavedRegs); - if (ScratchExecCopy) { + if (FuncInfo->isWholeWaveFunction()) { + // SI_WHOLE_WAVE_FUNC_SETUP has outlived its purpose, so we can remove + // it now. If we have already saved some WWM CSR registers, then the EXEC is + // already -1 and we don't need to do anything else. Otherwise, set EXEC to + // -1 here. + if (!ScratchExecCopy) + buildScratchExecCopy(LiveUnits, MF, MBB, MBBI, DL, /*IsProlog*/ true, + /*EnableInactiveLanes*/ true); + else if (WWMCalleeSavedRegs.empty()) + EnableAllLanes(); + TII->getWholeWaveFunctionSetup(MF)->eraseFromParent(); + } else if (ScratchExecCopy) { // FIXME: Split block and make terminator. unsigned ExecMov = ST.isWave32() ? 
AMDGPU::S_MOV_B32 : AMDGPU::S_MOV_B64; BuildMI(MBB, MBBI, DL, TII->get(ExecMov), TRI.getExec()) @@ -1083,11 +1109,6 @@ void SIFrameLowering::emitCSRSpillRestores( Register ScratchExecCopy; SmallVector<std::pair<Register, int>, 2> WWMCalleeSavedRegs, WWMScratchRegs; FuncInfo->splitWWMSpillRegisters(MF, WWMCalleeSavedRegs, WWMScratchRegs); - if (!WWMScratchRegs.empty()) - ScratchExecCopy = - buildScratchExecCopy(LiveUnits, MF, MBB, MBBI, DL, - /*IsProlog*/ false, /*EnableInactiveLanes*/ true); - auto RestoreWWMRegisters = [&](SmallVectorImpl<std::pair<Register, int>> &WWMRegs) { for (const auto &Reg : WWMRegs) { @@ -1098,6 +1119,36 @@ void SIFrameLowering::emitCSRSpillRestores( } }; + if (FuncInfo->isWholeWaveFunction()) { + // For whole wave functions, the EXEC is already -1 at this point. + // Therefore, we can restore the CSR WWM registers right away. + RestoreWWMRegisters(WWMCalleeSavedRegs); + + // The original EXEC is the first operand of the return instruction. + const MachineInstr &Return = MBB.instr_back(); + assert(Return.getOpcode() == AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN && + "Unexpected return inst"); + Register OrigExec = Return.getOperand(0).getReg(); + + if (!WWMScratchRegs.empty()) { + unsigned XorOpc = ST.isWave32() ? AMDGPU::S_XOR_B32 : AMDGPU::S_XOR_B64; + BuildMI(MBB, MBBI, DL, TII->get(XorOpc), TRI.getExec()) + .addReg(OrigExec) + .addImm(-1); + RestoreWWMRegisters(WWMScratchRegs); + } + + // Restore original EXEC. + unsigned MovOpc = ST.isWave32() ? AMDGPU::S_MOV_B32 : AMDGPU::S_MOV_B64; + BuildMI(MBB, MBBI, DL, TII->get(MovOpc), TRI.getExec()).addReg(OrigExec); + return; + } + + if (!WWMScratchRegs.empty()) { + ScratchExecCopy = + buildScratchExecCopy(LiveUnits, MF, MBB, MBBI, DL, + /*IsProlog=*/false, /*EnableInactiveLanes=*/true); + } RestoreWWMRegisters(WWMScratchRegs); if (!WWMCalleeSavedRegs.empty()) { if (ScratchExecCopy) { @@ -1634,6 +1685,7 @@ void SIFrameLowering::determineCalleeSaves(MachineFunction &MF, NeedExecCopyReservedReg = true; else if (MI.getOpcode() == AMDGPU::SI_RETURN || MI.getOpcode() == AMDGPU::SI_RETURN_TO_EPILOG || + MI.getOpcode() == AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN || (MFI->isChainFunction() && TII->isChainCallOpcode(MI.getOpcode()))) { // We expect all return to be the same size. @@ -1662,6 +1714,21 @@ void SIFrameLowering::determineCalleeSaves(MachineFunction &MF, if (MFI->isEntryFunction()) return; + if (MFI->isWholeWaveFunction()) { + // In practice, all the VGPRs are WWM registers, and we will need to save at + // least their inactive lanes. Add them to WWMReservedRegs. + assert(!NeedExecCopyReservedReg && + "Whole wave functions can use the reg mapped for their i1 argument"); + + // FIXME: Be more efficient! + for (MCRegister Reg : AMDGPU::VGPR_32RegClass) + if (MF.getRegInfo().isPhysRegModified(Reg)) { + MFI->reserveWWMRegister(Reg); + MF.begin()->addLiveIn(Reg); + } + MF.begin()->sortUniqueLiveIns(); + } + // Remove any VGPRs used in the return value because these do not need to be saved. // This prevents CSR restore from clobbering return VGPRs. if (ReturnMI) { diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp index 0c76ff2..bc0fd8d 100644 --- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp +++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp @@ -618,6 +618,10 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM, ISD::FSIN, ISD::FROUND}, MVT::f16, Custom); + // BF16 - VOP1 Actions. 
+ if (Subtarget->hasBF16TransInsts()) + setOperationAction({ISD::FCOS, ISD::FSIN, ISD::FDIV}, MVT::bf16, Custom); + setOperationAction({ISD::FP_TO_SINT, ISD::FP_TO_UINT}, MVT::f16, Promote); setOperationAction({ISD::FP_TO_SINT, ISD::FP_TO_UINT}, MVT::bf16, Promote); @@ -2260,7 +2264,8 @@ SDValue SITargetLowering::getPreloadedValue( const ArgDescriptor WorkGroupIDZ = ArgDescriptor::createRegister(AMDGPU::TTMP7, 0xFFFF0000u); if (Subtarget->hasArchitectedSGPRs() && - (AMDGPU::isCompute(CC) || CC == CallingConv::AMDGPU_Gfx)) { + (AMDGPU::isCompute(CC) || CC == CallingConv::AMDGPU_Gfx || + CC == CallingConv::AMDGPU_Gfx_WholeWave)) { switch (PVID) { case AMDGPUFunctionArgInfo::WORKGROUP_ID_X: Reg = &WorkGroupIDX; @@ -2942,12 +2947,15 @@ SDValue SITargetLowering::LowerFormalArguments( if (!Subtarget->enableFlatScratch()) assert(!UserSGPRInfo.hasFlatScratchInit()); if ((CallConv != CallingConv::AMDGPU_CS && - CallConv != CallingConv::AMDGPU_Gfx) || + CallConv != CallingConv::AMDGPU_Gfx && + CallConv != CallingConv::AMDGPU_Gfx_WholeWave) || !Subtarget->hasArchitectedSGPRs()) assert(!Info->hasWorkGroupIDX() && !Info->hasWorkGroupIDY() && !Info->hasWorkGroupIDZ()); } + bool IsWholeWaveFunc = Info->isWholeWaveFunction(); + if (CallConv == CallingConv::AMDGPU_PS) { processPSInputArgs(Splits, CallConv, Ins, Skipped, FType, Info); @@ -2988,7 +2996,8 @@ SDValue SITargetLowering::LowerFormalArguments( } else if (IsKernel) { assert(Info->hasWorkGroupIDX() && Info->hasWorkItemIDX()); } else { - Splits.append(Ins.begin(), Ins.end()); + Splits.append(IsWholeWaveFunc ? std::next(Ins.begin()) : Ins.begin(), + Ins.end()); } if (IsKernel) @@ -3019,6 +3028,13 @@ SDValue SITargetLowering::LowerFormalArguments( SmallVector<SDValue, 16> Chains; + if (IsWholeWaveFunc) { + SDValue Setup = DAG.getNode(AMDGPUISD::WHOLE_WAVE_SETUP, DL, + {MVT::i1, MVT::Other}, Chain); + InVals.push_back(Setup.getValue(0)); + Chains.push_back(Setup.getValue(1)); + } + // FIXME: This is the minimum kernel argument alignment. We should improve // this to the maximum alignment of the arguments. // @@ -3026,7 +3042,8 @@ SDValue SITargetLowering::LowerFormalArguments( // kern arg offset. const Align KernelArgBaseAlign = Align(16); - for (unsigned i = 0, e = Ins.size(), ArgIdx = 0; i != e; ++i) { + for (unsigned i = IsWholeWaveFunc ? 1 : 0, e = Ins.size(), ArgIdx = 0; i != e; + ++i) { const ISD::InputArg &Arg = Ins[i]; if ((Arg.isOrigArg() && Skipped[Arg.getOrigArgIndex()]) || IsError) { InVals.push_back(DAG.getPOISON(Arg.VT)); @@ -3374,7 +3391,9 @@ SITargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv, unsigned Opc = AMDGPUISD::ENDPGM; if (!IsWaveEnd) - Opc = IsShader ? AMDGPUISD::RETURN_TO_EPILOG : AMDGPUISD::RET_GLUE; + Opc = Info->isWholeWaveFunction() ? AMDGPUISD::WHOLE_WAVE_RETURN + : IsShader ? AMDGPUISD::RETURN_TO_EPILOG + : AMDGPUISD::RET_GLUE; return DAG.getNode(Opc, DL, MVT::Other, RetOps); } @@ -3876,7 +3895,8 @@ SDValue SITargetLowering::LowerCall(CallLoweringInfo &CLI, CCState CCInfo(CallConv, IsVarArg, MF, ArgLocs, *DAG.getContext()); CCAssignFn *AssignFn = CCAssignFnForCall(CallConv, IsVarArg); - if (CallConv != CallingConv::AMDGPU_Gfx && !AMDGPU::isChainCC(CallConv)) { + if (CallConv != CallingConv::AMDGPU_Gfx && !AMDGPU::isChainCC(CallConv) && + CallConv != CallingConv::AMDGPU_Gfx_WholeWave) { // With a fixed ABI, allocate fixed registers before user arguments. 
passSpecialInputs(CLI, CCInfo, *Info, RegsToPass, MemOpChains, Chain); } @@ -5890,6 +5910,18 @@ SITargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, MI.eraseFromParent(); return SplitBB; } + case AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN: { + assert(MFI->isWholeWaveFunction()); + + // During ISel, it's difficult to propagate the original EXEC mask to use as + // an input to SI_WHOLE_WAVE_FUNC_RETURN. Set it up here instead. + MachineInstr *Setup = TII->getWholeWaveFunctionSetup(*BB->getParent()); + Register OriginalExec = Setup->getOperand(0).getReg(); + assert(Setup && "Couldn't find SI_SETUP_WHOLE_WAVE_FUNC"); + MF->getRegInfo().clearKillFlags(OriginalExec); + MI.getOperand(0).setReg(OriginalExec); + return BB; + } default: if (TII->isImage(MI) || TII->isMUBUF(MI)) { if (!MI.mayStore()) @@ -11172,7 +11204,7 @@ SDValue SITargetLowering::lowerFastUnsafeFDIV(SDValue Op, // Without !fpmath accuracy information, we can't do more because we don't // know exactly whether rcp is accurate enough to meet !fpmath requirement. // f16 is always accurate enough - if (!AllowInaccurateRcp && VT != MVT::f16) + if (!AllowInaccurateRcp && VT != MVT::f16 && VT != MVT::bf16) return SDValue(); if (CLHS->isExactlyValue(1.0)) { @@ -11199,9 +11231,10 @@ SDValue SITargetLowering::lowerFastUnsafeFDIV(SDValue Op, } } - // For f16 require afn or arcp. + // For f16 and bf16 require afn or arcp. // For f32 require afn. - if (!AllowInaccurateRcp && (VT != MVT::f16 || !Flags.hasAllowReciprocal())) + if (!AllowInaccurateRcp && + ((VT != MVT::f16 && VT != MVT::bf16) || !Flags.hasAllowReciprocal())) return SDValue(); // Turn into multiply by the reciprocal. @@ -11592,7 +11625,7 @@ SDValue SITargetLowering::LowerFDIV(SDValue Op, SelectionDAG &DAG) const { if (VT == MVT::f64) return LowerFDIV64(Op, DAG); - if (VT == MVT::f16) + if (VT == MVT::f16 || VT == MVT::bf16) return LowerFDIV16(Op, DAG); llvm_unreachable("Unexpected type for fdiv"); diff --git a/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp b/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp index 2af0a57..9faf497 100644 --- a/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp +++ b/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp @@ -1812,6 +1812,7 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(MachineInstr &MI, // with knowledge of the called routines. if (MI.getOpcode() == AMDGPU::SI_RETURN_TO_EPILOG || MI.getOpcode() == AMDGPU::SI_RETURN || + MI.getOpcode() == AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN || MI.getOpcode() == AMDGPU::S_SETPC_B64_return || (MI.isReturn() && MI.isCall() && !callWaitsOnFunctionEntry(MI))) { Wait = Wait.combined(WCG->getAllZeroWaitcnt(/*IncludeVSCnt=*/false)); diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp b/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp index c8935f0..e2a2525 100644 --- a/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp +++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp @@ -2472,6 +2472,7 @@ bool SIInstrInfo::expandPostRAPseudo(MachineInstr &MI) const { MI.setDesc(get(ST.isWave32() ? 
AMDGPU::S_MOV_B32 : AMDGPU::S_MOV_B64)); break; } + case AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN: case AMDGPU::SI_RETURN: { const MachineFunction *MF = MBB.getParent(); const GCNSubtarget &ST = MF->getSubtarget<GCNSubtarget>(); @@ -5757,6 +5758,19 @@ void SIInstrInfo::restoreExec(MachineFunction &MF, MachineBasicBlock &MBB, Indexes->insertMachineInstrInMaps(*ExecRestoreMI); } +MachineInstr * +SIInstrInfo::getWholeWaveFunctionSetup(MachineFunction &MF) const { + assert(MF.getInfo<SIMachineFunctionInfo>()->isWholeWaveFunction() && + "Not a whole wave func"); + MachineBasicBlock &MBB = *MF.begin(); + for (MachineInstr &MI : MBB) + if (MI.getOpcode() == AMDGPU::SI_WHOLE_WAVE_FUNC_SETUP || + MI.getOpcode() == AMDGPU::G_AMDGPU_WHOLE_WAVE_FUNC_SETUP) + return &MI; + + llvm_unreachable("Couldn't find SI_SETUP_WHOLE_WAVE_FUNC instruction"); +} + static const TargetRegisterClass * adjustAllocatableRegClass(const GCNSubtarget &ST, const SIRegisterInfo &RI, const MachineRegisterInfo &MRI, diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.h b/llvm/lib/Target/AMDGPU/SIInstrInfo.h index 5e92921..800ea9a 100644 --- a/llvm/lib/Target/AMDGPU/SIInstrInfo.h +++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.h @@ -1215,6 +1215,8 @@ public: MachineBasicBlock::iterator MBBI, const DebugLoc &DL, Register Reg, SlotIndexes *Indexes = nullptr) const; + MachineInstr *getWholeWaveFunctionSetup(MachineFunction &MF) const; + /// Return the correct register class for \p OpNo. For target-specific /// instructions, this will return the register class that has been defined /// in tablegen. For generic instructions, like REG_SEQUENCE it will return diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.td b/llvm/lib/Target/AMDGPU/SIInstrInfo.td index 9e1951e..bd4995b 100644 --- a/llvm/lib/Target/AMDGPU/SIInstrInfo.td +++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.td @@ -1307,6 +1307,9 @@ let PrintMethod = "printBitOp3" in def BitOp3 : NamedIntOperand<"bitop3">; def bitop3_0 : DefaultOperand<BitOp3, 0>; +def MatrixAFMT : CustomOperand<i32, 1, "MatrixAFMT">; +def MatrixBFMT : CustomOperand<i32, 1, "MatrixBFMT">; + def MatrixAReuse : NamedBitOperand<"matrix_a_reuse">; def MatrixBReuse : NamedBitOperand<"matrix_b_reuse">; @@ -1882,6 +1885,7 @@ class getVOP3SrcForVT<ValueType VT, bit IsTrue16 = 0> { !eq(VT, v4bf16) : AVSrc_64, !eq(VT.Size, 1024) : VRegSrc_1024, !eq(VT.Size, 512) : VRegSrc_512, + !eq(VT.Size, 384) : VRegSrc_384, !eq(VT.Size, 256) : VRegSrc_256, !eq(VT.Size, 192) : VRegSrc_192, !eq(VT.Size, 128) : VRegSrc_128, @@ -1894,6 +1898,7 @@ class getVOP3SrcForVT<ValueType VT, bit IsTrue16 = 0> { class getVOP3VRegSrcForVT<ValueType VT> { RegisterOperand ret = !cond(!eq(VT.Size, 1024) : VRegSrc_1024, !eq(VT.Size, 512) : VRegSrc_512, + !eq(VT.Size, 384) : VRegSrc_384, !eq(VT.Size, 256) : VRegSrc_256, !eq(VT.Size, 192) : VRegSrc_192, !eq(VT.Size, 128) : VRegSrc_128, @@ -2666,6 +2671,7 @@ class VOPProfile <list<ValueType> _ArgVT, bit _EnableClamp = 0> { HasOMod); field bit HasNeg = HasModifiers; field bit HasMatrixReuse = 0; + field bit HasMatrixFMT = 0; field bit HasSrc0Mods = HasModifiers; field bit HasSrc1Mods = !if(HasModifiers, !or(HasSrc1FloatMods, HasSrc1IntMods), 0); diff --git a/llvm/lib/Target/AMDGPU/SIInstructions.td b/llvm/lib/Target/AMDGPU/SIInstructions.td index 991d9f8..d05be8f 100644 --- a/llvm/lib/Target/AMDGPU/SIInstructions.td +++ b/llvm/lib/Target/AMDGPU/SIInstructions.td @@ -644,6 +644,32 @@ def SI_INIT_WHOLE_WAVE : SPseudoInstSI < let isConvergent = 1; } +// Sets EXEC to all lanes and returns the previous EXEC. 
+def SI_WHOLE_WAVE_FUNC_SETUP : SPseudoInstSI < + (outs SReg_1:$dst), (ins), [(set i1:$dst, (AMDGPUwhole_wave_setup))]> { + let Defs = [EXEC]; + let Uses = [EXEC]; + + let isConvergent = 1; +} + +// Restores the previous EXEC and otherwise behaves entirely like a SI_RETURN. +def SI_WHOLE_WAVE_FUNC_RETURN : SPseudoInstSI < + (outs), (ins SReg_1:$orig_exec)> { + let isTerminator = 1; + let isBarrier = 1; + let isReturn = 1; + let SchedRW = [WriteBranch]; + + // We're going to use custom handling to set the $orig_exec to the correct value. + let usesCustomInserter = 1; +} + +// Generate a SI_WHOLE_WAVE_FUNC_RETURN pseudo with a placeholder for its +// argument. It will be filled in by the custom inserter. +def : GCNPat< + (AMDGPUwhole_wave_return), (SI_WHOLE_WAVE_FUNC_RETURN (i1 (IMPLICIT_DEF)))>; + // Return for returning shaders to a shader variant epilog. def SI_RETURN_TO_EPILOG : SPseudoInstSI < (outs), (ins variable_ops), [(AMDGPUreturn_to_epilog)]> { @@ -2473,6 +2499,7 @@ def : AMDGPUPat < >; let True16Predicate = NotHasTrue16BitInsts in { +let SubtargetPredicate = isNotGFX9Plus in { def : ROTRPattern <V_ALIGNBIT_B32_e64>; def : GCNPat<(i32 (trunc (srl i64:$src0, (and i32:$src1, (i32 31))))), @@ -2482,6 +2509,35 @@ def : GCNPat<(i32 (trunc (srl i64:$src0, (and i32:$src1, (i32 31))))), def : GCNPat<(i32 (trunc (srl i64:$src0, (i32 ShiftAmt32Imm:$src1)))), (V_ALIGNBIT_B32_e64 (i32 (EXTRACT_SUBREG (i64 $src0), sub1)), (i32 (EXTRACT_SUBREG (i64 $src0), sub0)), $src1)>; +} // isNotGFX9Plus + +let SubtargetPredicate = isGFX9GFX10 in { +def : GCNPat < + (rotr i32:$src0, i32:$src1), + (V_ALIGNBIT_B32_opsel_e64 /* src0_modifiers */ 0, $src0, + /* src1_modifiers */ 0, $src0, + /* src2_modifiers */ 0, + $src1, /* clamp */ 0, /* op_sel */ 0) +>; + +foreach pat = [(i32 (trunc (srl i64:$src0, (and i32:$src1, (i32 31))))), + (i32 (trunc (srl i64:$src0, (i32 ShiftAmt32Imm:$src1))))] in +def : GCNPat<pat, + (V_ALIGNBIT_B32_opsel_e64 0, /* src0_modifiers */ + (i32 (EXTRACT_SUBREG (i64 $src0), sub1)), + 0, /* src1_modifiers */ + (i32 (EXTRACT_SUBREG (i64 $src0), sub0)), + 0, /* src2_modifiers */ + $src1, /* clamp */ 0, /* op_sel */ 0) +>; + +def : GCNPat<(fshr i32:$src0, i32:$src1, i32:$src2), + (V_ALIGNBIT_B32_opsel_e64 /* src0_modifiers */ 0, $src0, + /* src1_modifiers */ 0, $src1, + /* src2_modifiers */ 0, + $src2, /* clamp */ 0, /* op_sel */ 0) +>; +} // isGFX9GFX10 } // end True16Predicate = NotHasTrue16BitInsts let True16Predicate = UseRealTrue16Insts in { @@ -3082,6 +3138,8 @@ def : GCNPat < (i32 (EXTRACT_SUBREG $a, sub0))), (i32 1)) >; +// This pattern for bswap is used for pre-GFX8. For GFX8+, bswap is mapped +// to V_PERM_B32. 
let True16Predicate = NotHasTrue16BitInsts in def : GCNPat < (i32 (bswap i32:$a)), @@ -3559,15 +3617,20 @@ def : GCNPat < // Take the upper 16 bits from V[0] and the lower 16 bits from V[1] // Special case, can use V_ALIGNBIT (always uses encoded literal) -let True16Predicate = NotHasTrue16BitInsts in -def : GCNPat < +let True16Predicate = NotHasTrue16BitInsts in { +defvar BuildVectorToAlignBitPat = (vecTy (DivergentBinFrag<build_vector> (Ty !if(!eq(Ty, i16), (Ty (trunc (srl VGPR_32:$a, (i32 16)))), (Ty (bitconvert (i16 (trunc (srl VGPR_32:$a, (i32 16)))))))), - (Ty VGPR_32:$b))), - (V_ALIGNBIT_B32_e64 VGPR_32:$b, VGPR_32:$a, (i32 16)) ->; + (Ty VGPR_32:$b))); + +let SubtargetPredicate = isNotGFX9Plus in +def : GCNPat<BuildVectorToAlignBitPat, (V_ALIGNBIT_B32_e64 VGPR_32:$b, VGPR_32:$a, (i32 16))>; + +let SubtargetPredicate = isGFX9GFX10 in +def : GCNPat<BuildVectorToAlignBitPat, (V_ALIGNBIT_B32_opsel_e64 0, VGPR_32:$b, 0, VGPR_32:$a, 0, (i32 16), 0, 0)>; +} //True16Predicate = NotHasTrue16BitInsts let True16Predicate = UseFakeTrue16Insts in def : GCNPat < @@ -4300,6 +4363,20 @@ def G_AMDGPU_S_MUL_I64_I32 : AMDGPUGenericInstruction { let hasSideEffects = 0; } +def G_AMDGPU_WHOLE_WAVE_FUNC_SETUP : AMDGPUGenericInstruction { + let OutOperandList = (outs type0:$origExec); + let InOperandList = (ins); + let isConvergent = 1; +} + +def G_AMDGPU_WHOLE_WAVE_FUNC_RETURN : AMDGPUGenericInstruction { + let OutOperandList = (outs); + let InOperandList = (ins type0:$origExec); + let isTerminator = 1; + let isBarrier = 1; + let isReturn = 1; +} + // This is equivalent to the G_INTRINSIC*, but the operands may have // been legalized depending on the subtarget requirements. def G_AMDGPU_INTRIN_IMAGE_LOAD : AMDGPUGenericInstruction { diff --git a/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.cpp b/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.cpp index 8c2e9b62..f0be204 100644 --- a/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.cpp +++ b/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.cpp @@ -51,7 +51,9 @@ SIMachineFunctionInfo::SIMachineFunctionInfo(const Function &F, WorkGroupIDZ(false), WorkGroupInfo(false), LDSKernelId(false), PrivateSegmentWaveByteOffset(false), WorkItemIDX(false), WorkItemIDY(false), WorkItemIDZ(false), ImplicitArgPtr(false), - GITPtrHigh(0xffffffff), HighBitsOf32BitAddress(0) { + GITPtrHigh(0xffffffff), HighBitsOf32BitAddress(0), + IsWholeWaveFunction(F.getCallingConv() == + CallingConv::AMDGPU_Gfx_WholeWave) { const GCNSubtarget &ST = *STI; FlatWorkGroupSizes = ST.getFlatWorkGroupSizes(F); WavesPerEU = ST.getWavesPerEU(F); @@ -99,7 +101,8 @@ SIMachineFunctionInfo::SIMachineFunctionInfo(const Function &F, ImplicitArgPtr = false; } else if (!isEntryFunction()) { - if (CC != CallingConv::AMDGPU_Gfx) + if (CC != CallingConv::AMDGPU_Gfx && + CC != CallingConv::AMDGPU_Gfx_WholeWave) ArgInfo = AMDGPUArgumentUsageInfo::FixedABIFunctionInfo; FrameOffsetReg = AMDGPU::SGPR33; @@ -732,6 +735,7 @@ yaml::SIMachineFunctionInfo::SIMachineFunctionInfo( PSInputAddr(MFI.getPSInputAddr()), PSInputEnable(MFI.getPSInputEnable()), MaxMemoryClusterDWords(MFI.getMaxMemoryClusterDWords()), Mode(MFI.getMode()), HasInitWholeWave(MFI.hasInitWholeWave()), + IsWholeWaveFunction(MFI.isWholeWaveFunction()), DynamicVGPRBlockSize(MFI.getDynamicVGPRBlockSize()), ScratchReservedForDynamicVGPRs(MFI.getScratchReservedForDynamicVGPRs()) { for (Register Reg : MFI.getSGPRSpillPhysVGPRs()) @@ -778,6 +782,7 @@ bool SIMachineFunctionInfo::initializeBaseYamlFields( HasSpilledVGPRs = YamlMFI.HasSpilledVGPRs; BytesInStackArgArea = 
YamlMFI.BytesInStackArgArea; ReturnsVoid = YamlMFI.ReturnsVoid; + IsWholeWaveFunction = YamlMFI.IsWholeWaveFunction; if (YamlMFI.ScavengeFI) { auto FIOrErr = YamlMFI.ScavengeFI->getFI(MF.getFrameInfo()); diff --git a/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.h b/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.h index 274a60ad..08b0206 100644 --- a/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.h +++ b/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.h @@ -298,6 +298,7 @@ struct SIMachineFunctionInfo final : public yaml::MachineFunctionInfo { StringValue LongBranchReservedReg; bool HasInitWholeWave = false; + bool IsWholeWaveFunction = false; unsigned DynamicVGPRBlockSize = 0; unsigned ScratchReservedForDynamicVGPRs = 0; @@ -356,6 +357,7 @@ template <> struct MappingTraits<SIMachineFunctionInfo> { YamlIO.mapOptional("dynamicVGPRBlockSize", MFI.DynamicVGPRBlockSize, false); YamlIO.mapOptional("scratchReservedForDynamicVGPRs", MFI.ScratchReservedForDynamicVGPRs, 0); + YamlIO.mapOptional("isWholeWaveFunction", MFI.IsWholeWaveFunction, false); } }; @@ -565,6 +567,8 @@ private: // the serialization easier. ReservedRegSet WWMReservedRegs; + bool IsWholeWaveFunction = false; + using PrologEpilogSGPRSpill = std::pair<Register, PrologEpilogSGPRSaveRestoreInfo>; // To track the SGPR spill method used for a CSR SGPR register during @@ -670,6 +674,8 @@ public: return WWMReservedRegs.contains(Reg); } + bool isWholeWaveFunction() const { return IsWholeWaveFunction; } + ArrayRef<PrologEpilogSGPRSpill> getPrologEpilogSGPRSpills() const { assert(is_sorted(PrologEpilogSGPRSpills, llvm::less_first())); return PrologEpilogSGPRSpills; diff --git a/llvm/lib/Target/AMDGPU/SIProgramInfo.cpp b/llvm/lib/Target/AMDGPU/SIProgramInfo.cpp index 7093fe6..5940f45 100644 --- a/llvm/lib/Target/AMDGPU/SIProgramInfo.cpp +++ b/llvm/lib/Target/AMDGPU/SIProgramInfo.cpp @@ -85,7 +85,8 @@ static uint64_t getComputePGMRSrc1Reg(const SIProgramInfo &ProgInfo, S_00B848_PRIV(ProgInfo.Priv) | S_00B848_DEBUG_MODE(ProgInfo.DebugMode) | S_00B848_WGP_MODE(ProgInfo.WgpMode) | - S_00B848_MEM_ORDERED(ProgInfo.MemOrdered); + S_00B848_MEM_ORDERED(ProgInfo.MemOrdered) | + S_00B848_FWD_PROGRESS(ProgInfo.FwdProgress); if (ST.hasDX10ClampMode()) Reg |= S_00B848_DX10_CLAMP(ProgInfo.DX10Clamp); @@ -93,10 +94,6 @@ static uint64_t getComputePGMRSrc1Reg(const SIProgramInfo &ProgInfo, if (ST.hasIEEEMode()) Reg |= S_00B848_IEEE_MODE(ProgInfo.IEEEMode); - // TODO: in the long run we will want to enable this unconditionally. - if (ST.getTargetTriple().getOS() == Triple::OSType::AMDHSA) - Reg |= S_00B848_FWD_PROGRESS(ProgInfo.FwdProgress); - if (ST.hasRrWGMode()) Reg |= S_00B848_RR_WG_MODE(ProgInfo.RrWgMode); diff --git a/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp b/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp index fa2b8db..84cfa87 100644 --- a/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp +++ b/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp @@ -407,6 +407,7 @@ const MCPhysReg *SIRegisterInfo::getCalleeSavedRegs( return ST.hasGFX90AInsts() ? CSR_AMDGPU_GFX90AInsts_SaveList : CSR_AMDGPU_SaveList; case CallingConv::AMDGPU_Gfx: + case CallingConv::AMDGPU_Gfx_WholeWave: return ST.hasGFX90AInsts() ? CSR_AMDGPU_SI_Gfx_GFX90AInsts_SaveList : CSR_AMDGPU_SI_Gfx_SaveList; case CallingConv::AMDGPU_CS_ChainPreserve: @@ -433,6 +434,7 @@ const uint32_t *SIRegisterInfo::getCallPreservedMask(const MachineFunction &MF, return ST.hasGFX90AInsts() ? 
CSR_AMDGPU_GFX90AInsts_RegMask : CSR_AMDGPU_RegMask; case CallingConv::AMDGPU_Gfx: + case CallingConv::AMDGPU_Gfx_WholeWave: return ST.hasGFX90AInsts() ? CSR_AMDGPU_SI_Gfx_GFX90AInsts_RegMask : CSR_AMDGPU_SI_Gfx_RegMask; case CallingConv::AMDGPU_CS_Chain: diff --git a/llvm/lib/Target/AMDGPU/SIRegisterInfo.td b/llvm/lib/Target/AMDGPU/SIRegisterInfo.td index c194e5c..0039d2f 100644 --- a/llvm/lib/Target/AMDGPU/SIRegisterInfo.td +++ b/llvm/lib/Target/AMDGPU/SIRegisterInfo.td @@ -1207,6 +1207,7 @@ def VRegSrc_96 : SrcReg9<VReg_96>; def VRegSrc_128: SrcReg9<VReg_128>; def VRegSrc_192: SrcReg9<VReg_192>; def VRegSrc_256: SrcReg9<VReg_256>; +def VRegSrc_384: SrcReg9<VReg_384>; def VRegSrc_512: SrcReg9<VReg_512>; def VRegSrc_1024: SrcReg9<VReg_1024>; def VRegOrLdsSrc_32 : SrcReg9<VRegOrLds_32>; diff --git a/llvm/lib/Target/AMDGPU/SISchedule.td b/llvm/lib/Target/AMDGPU/SISchedule.td index ef8faff..8eecb1c 100644 --- a/llvm/lib/Target/AMDGPU/SISchedule.td +++ b/llvm/lib/Target/AMDGPU/SISchedule.td @@ -464,6 +464,20 @@ def : InstRW<[WriteCopy], (instrs COPY)>; } // End SchedModel = GFX12SpeedModel +// Check if any matrix inputs are interpreted as f8 in an f8f6f4 +// wmma instruction. +def PredIsF8_WMMA_SCALE : SchedPredicate<[{ + TII->getNamedOperand(*MI, AMDGPU::OpName::matrix_a_fmt)->getImm() <= AMDGPU::WMMA::MATRIX_FMT_BF8 || + TII->getNamedOperand(*MI, AMDGPU::OpName::matrix_b_fmt)->getImm() <= AMDGPU::WMMA::MATRIX_FMT_BF8 +}]>; + +// If either matrix format is f8, the instruction takes 2x as many +// cycles. TODO: This isn't reflected in MCA. +def WriteWMMAScale_16X16X128_F8F6F4 : SchedWriteVariant<[ + SchedVar<PredIsF8_WMMA_SCALE, [WriteXDL4PassWMMA]>, + SchedVar<NoSchedPred, [WriteXDL2PassWMMA]> +]>; + multiclass GFX125xCommonWriteRes { let ReleaseAtCycles = [8] in @@ -495,6 +509,7 @@ def : InstRW<[WriteCopy], (instrs COPY)>; def : InstRW<[WriteXDL2PassWMMA], (instregex "^V_[S]*WMMA[C]*_.*_(FP8|BF8|BF16|F16)_w32")>; def : InstRW<[WriteXDL4PassWMMA], (instregex "^V_[S]*WMMA[C]*_.*_(IU8|IU4)_w32")>; +def : InstRW<[WriteWMMAScale_16X16X128_F8F6F4], (instregex "^V_WMMA_.*_16X16X128_F8F6F4.*_w32")>; def : InstRW<[Write4PassWMMA], (instregex "^V_WMMA_F32_16X16X4_F32_w32")>; def : InstRW<[WriteXDL2PassWMMA], (instregex "^V_WMMA.*_F32_32X16X128_F4")>; } // End GFX125xCommonWriteRes diff --git a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp index 7725881..9c6c374 100644 --- a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp +++ b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp @@ -598,6 +598,29 @@ const MFMA_F8F6F4_Info *getMFMA_F8F6F4_WithFormatArgs(unsigned CBSZ, return getMFMA_F8F6F4_InstWithNumRegs(SrcANumRegs, SrcBNumRegs, F8F8Opcode); } +uint8_t wmmaScaleF8F6F4FormatToNumRegs(unsigned Fmt) { + switch (Fmt) { + case WMMA::MATRIX_FMT_FP8: + case WMMA::MATRIX_FMT_BF8: + return 16; + case WMMA::MATRIX_FMT_FP6: + case WMMA::MATRIX_FMT_BF6: + return 12; + case WMMA::MATRIX_FMT_FP4: + return 8; + } + + llvm_unreachable("covered switch over wmma scale formats"); +} + +const MFMA_F8F6F4_Info *getWMMA_F8F6F4_WithFormatArgs(unsigned FmtA, + unsigned FmtB, + unsigned F8F8Opcode) { + uint8_t SrcANumRegs = wmmaScaleF8F6F4FormatToNumRegs(FmtA); + uint8_t SrcBNumRegs = wmmaScaleF8F6F4FormatToNumRegs(FmtB); + return getMFMA_F8F6F4_InstWithNumRegs(SrcANumRegs, SrcBNumRegs, F8F8Opcode); +} + unsigned getVOPDEncodingFamily(const MCSubtargetInfo &ST) { if (ST.hasFeature(AMDGPU::FeatureGFX1250Insts)) return SIEncodingFamily::GFX1250; diff --git 
a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h index c9d2c28..bde951b 100644 --- a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h +++ b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h @@ -627,6 +627,14 @@ const MFMA_F8F6F4_Info *getMFMA_F8F6F4_WithFormatArgs(unsigned CBSZ, unsigned BLGP, unsigned F8F8Opcode); +LLVM_READNONE +uint8_t wmmaScaleF8F6F4FormatToNumRegs(unsigned Fmt); + +LLVM_READONLY +const MFMA_F8F6F4_Info *getWMMA_F8F6F4_WithFormatArgs(unsigned FmtA, + unsigned FmtB, + unsigned F8F8Opcode); + LLVM_READONLY const GcnBufferFormatInfo *getGcnBufferFormatInfo(uint8_t BitsPerComp, uint8_t NumComponents, @@ -1423,7 +1431,8 @@ constexpr bool isShader(CallingConv::ID CC) { LLVM_READNONE constexpr bool isGraphics(CallingConv::ID CC) { - return isShader(CC) || CC == CallingConv::AMDGPU_Gfx; + return isShader(CC) || CC == CallingConv::AMDGPU_Gfx || + CC == CallingConv::AMDGPU_Gfx_WholeWave; } LLVM_READNONE diff --git a/llvm/lib/Target/AMDGPU/Utils/AMDGPUPALMetadata.cpp b/llvm/lib/Target/AMDGPU/Utils/AMDGPUPALMetadata.cpp index e464470..fd6253d 100644 --- a/llvm/lib/Target/AMDGPU/Utils/AMDGPUPALMetadata.cpp +++ b/llvm/lib/Target/AMDGPU/Utils/AMDGPUPALMetadata.cpp @@ -44,6 +44,7 @@ static const char *getStageName(CallingConv::ID CC) { case CallingConv::AMDGPU_LS: return ".ls"; case CallingConv::AMDGPU_Gfx: + case CallingConv::AMDGPU_Gfx_WholeWave: llvm_unreachable("Callable shader has no hardware stage"); default: return ".cs"; diff --git a/llvm/lib/Target/AMDGPU/VOP3Instructions.td b/llvm/lib/Target/AMDGPU/VOP3Instructions.td index 2e7f25b..aee2f2c 100644 --- a/llvm/lib/Target/AMDGPU/VOP3Instructions.td +++ b/llvm/lib/Target/AMDGPU/VOP3Instructions.td @@ -224,6 +224,12 @@ defm V_ALIGNBIT_B32 : VOP3Inst_t16_with_profiles <"v_alignbit_b32", fshr, null_frag>; defm V_ALIGNBYTE_B32 : VOP3Inst <"v_alignbyte_b32", VOP3_Profile<VOP_I32_I32_I32_I32>, int_amdgcn_alignbyte>; + +// In gfx9 and 10, opsel is allowed for V_ALIGNBIT_B32 and V_ALIGNBYTE_B32. +// Hardware uses opsel[1:0] to byte-select src2. Other opsel bits are ignored. 
+defm V_ALIGNBIT_B32_opsel : VOP3Inst <"v_alignbit_b32_opsel", VOP3_Profile<VOP_I32_I32_I32_I32, VOP3_OPSEL>>; +defm V_ALIGNBYTE_B32_opsel : VOP3Inst <"v_alignbyte_b32_opsel", VOP3_Profile<VOP_I32_I32_I32_I32, VOP3_OPSEL>>; + let True16Predicate = UseRealTrue16Insts in defm V_ALIGNBYTE_B32_t16 : VOP3Inst <"v_alignbyte_b32_t16", VOP3_Profile_True16<VOP_I32_I32_I32_I16, VOP3_OPSEL>>; let True16Predicate = UseFakeTrue16Insts in @@ -265,6 +271,16 @@ let SchedRW = [WriteDoubleAdd], FPDPRounding = 1 in { } // End SchedRW = [WriteDoubleAdd], FPDPRounding = 1 } // End isReMaterializable = 1 +let SubtargetPredicate = isGFX9GFX10 in +def : GCNPat < +(i32 (int_amdgcn_alignbyte (i32 (VOP3OpSelMods i32:$src0, i32:$src0_modifiers)), + (i32 (VOP3OpSelMods i32:$src1, i32:$src1_modifiers)), + (i32 (VOP3OpSelMods i32:$src2, i32:$src2_modifiers)))), +(V_ALIGNBYTE_B32_opsel_e64 i32:$src0_modifiers, VSrc_b32:$src0, + i32:$src1_modifiers, VSrc_b32:$src1, + i32:$src2_modifiers, VGPR_32:$src2) +>; + let True16Predicate = UseFakeTrue16Insts in def : GCNPat < (i32 (int_amdgcn_alignbyte (i32 (VOP3OpSelMods i32:$src0, i32:$src0_modifiers)), @@ -1954,6 +1970,9 @@ let AssemblerPredicate = isGFX10Only, DecoderNamespace = "GFX10" in { } } // End AssemblerPredicate = isGFX10Only, DecoderNamespace = "GFX10" +defm V_ALIGNBIT_B32_opsel : VOP3OpSel_Real_gfx10_with_name<0x14e, "V_ALIGNBIT_B32_opsel", "v_alignbit_b32">; +defm V_ALIGNBYTE_B32_opsel : VOP3OpSel_Real_gfx10_with_name<0x14f, "V_ALIGNBYTE_B32_opsel", "v_alignbyte_b32">; + defm V_READLANE_B32 : VOP3_Real_No_Suffix_gfx10<0x360>; let InOperandList = (ins SSrcOrLds_b32:$src0, SCSrc_b32:$src1, VGPR_32:$vdst_in) in { @@ -2104,8 +2123,8 @@ defm V_BFI_B32 : VOP3_Real_gfx6_gfx7_gfx10<0x14a>; defm V_FMA_F32 : VOP3_Real_gfx6_gfx7_gfx10<0x14b>; defm V_FMA_F64 : VOP3_Real_gfx6_gfx7_gfx10<0x14c>; defm V_LERP_U8 : VOP3_Real_gfx6_gfx7_gfx10<0x14d>; -defm V_ALIGNBIT_B32 : VOP3_Real_gfx6_gfx7_gfx10<0x14e>; -defm V_ALIGNBYTE_B32 : VOP3_Real_gfx6_gfx7_gfx10<0x14f>; +defm V_ALIGNBIT_B32 : VOP3_Real_gfx6_gfx7<0x14e>; +defm V_ALIGNBYTE_B32 : VOP3_Real_gfx6_gfx7<0x14f>; defm V_MULLIT_F32 : VOP3_Real_gfx6_gfx7_gfx10<0x150>; defm V_MIN3_F32 : VOP3_Real_gfx6_gfx7_gfx10<0x151>; defm V_MIN3_I32 : VOP3_Real_gfx6_gfx7_gfx10<0x152>; @@ -2248,6 +2267,17 @@ multiclass VOP3_Real_BITOP3_gfx9<bits<10> op, string AsmName, bit isSingle = 0> } } +// Instructions such as v_alignbyte_b32 allow op_sel in gfx9, but not in vi. +// The following is created to support that.
+multiclass VOP3OpSel_Real_gfx9_with_name<bits<10> op, string opName, string AsmName> { + defvar psName = opName#"_e64"; + def _gfx9 : VOP3_Real<!cast<VOP3_Pseudo>(psName), SIEncodingFamily.VI>, // note: encoding family is VI + VOP3OpSel_gfx9 <op, !cast<VOP3_Pseudo>(psName).Pfl> { + VOP3_Pseudo ps = !cast<VOP3_Pseudo>(psName); + let AsmString = AsmName # ps.AsmOperands; + } +} + } // End AssemblerPredicate = isGFX9Only, DecoderNamespace = "GFX9" defm V_MAD_U64_U32 : VOP3be_Real_vi <0x1E8>; @@ -2267,8 +2297,10 @@ defm V_BFI_B32 : VOP3_Real_vi <0x1ca>; defm V_FMA_F32 : VOP3_Real_vi <0x1cb>; defm V_FMA_F64 : VOP3_Real_vi <0x1cc>; defm V_LERP_U8 : VOP3_Real_vi <0x1cd>; +let SubtargetPredicate = isGFX8Only in { defm V_ALIGNBIT_B32 : VOP3_Real_vi <0x1ce>; defm V_ALIGNBYTE_B32 : VOP3_Real_vi <0x1cf>; +} defm V_MIN3_F32 : VOP3_Real_vi <0x1d0>; defm V_MIN3_I32 : VOP3_Real_vi <0x1d1>; defm V_MIN3_U32 : VOP3_Real_vi <0x1d2>; @@ -2313,6 +2345,9 @@ defm V_INTERP_P2_LEGACY_F16 : VOP3Interp_F16_Real_gfx9 <0x276, "V_INTERP_P2_F16" defm V_MAD_LEGACY_U16 : VOP3_F16_Real_gfx9 <0x1eb, "V_MAD_U16", "v_mad_legacy_u16">; defm V_MAD_LEGACY_I16 : VOP3_F16_Real_gfx9 <0x1ec, "V_MAD_I16", "v_mad_legacy_i16">; +defm V_ALIGNBIT_B32_opsel : VOP3OpSel_Real_gfx9_with_name <0x1ce, "V_ALIGNBIT_B32_opsel", "v_alignbit_b32">; +defm V_ALIGNBYTE_B32_opsel : VOP3OpSel_Real_gfx9_with_name <0x1cf, "V_ALIGNBYTE_B32_opsel", "v_alignbyte_b32">; + defm V_MAD_F16_gfx9 : VOP3OpSel_F16_Real_gfx9 <0x203, "v_mad_f16">; defm V_MAD_U16_gfx9 : VOP3OpSel_F16_Real_gfx9 <0x204, "v_mad_u16">; defm V_MAD_I16_gfx9 : VOP3OpSel_F16_Real_gfx9 <0x205, "v_mad_i16">; diff --git a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td index e51e957..9feea36 100644 --- a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td +++ b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td @@ -1318,13 +1318,15 @@ let WaveSizePredicate = isWave64 in { class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType, bit _IsIU, bit _IsFP8BF8XF32, bit _Has_ImodOp = 0, - bit _HasMatrixReuse = 0, bit _IsF4 = 0> + bit _HasMatrixFMT = 0, bit _HasMatrixReuse = 0, + bit _IsF4 = 0> : VOP3P_Profile<VOPProfile<ArgTy>> { bit IsIU = _IsIU; bit NoABMods = !or(_IsFP8BF8XF32, _IsF4); // No IMOD support for A and B bit IsXF32 = !and(_IsFP8BF8XF32, !eq(ArgTy[1], v8f32)); int IndexType = _IndexType; + let HasMatrixFMT = _HasMatrixFMT; let HasMatrixReuse = _HasMatrixReuse; bit HasIModOp = _Has_ImodOp; @@ -1422,7 +1424,8 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType, !eq(IndexType, 8) : (ins IndexKey8bit:$index_key_8bit), !eq(IndexType, 16): (ins IndexKey16bit:$index_key_16bit), !eq(IndexType, 32): (ins IndexKey32bit:$index_key_32bit)); - + dag MatrixFMT = !if(HasMatrixFMT, (ins MatrixAFMT:$matrix_a_fmt, MatrixBFMT:$matrix_b_fmt), + (ins)); dag MatrixReuse = !if(HasMatrixReuse, (ins MatrixAReuse:$matrix_a_reuse, MatrixBReuse:$matrix_b_reuse), (ins)); dag Clamp = !if(HasClamp, (ins Clamp0:$clamp), (ins)); dag Neg = !cond(!and(NegLoAny, NegHiAny) : (ins neg_lo0:$neg_lo, neg_hi0:$neg_hi), @@ -1436,7 +1439,7 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType, (ins VRegSrc_64:$src2), (ins VRegSrc_32:$src2)), IndexKey)), - MatrixReuse, Clamp, Neg); + MatrixFMT, MatrixReuse, Clamp, Neg); // asm @@ -1444,13 +1447,14 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType, !eq(IndexType, 8) : "$index_key_8bit", !eq(IndexType, 16) : "$index_key_16bit", !eq(IndexType, 32) : 
"$index_key_32bit"); + string MatrxFMTAsm = !if(HasMatrixFMT, "$matrix_a_fmt$matrix_b_fmt", ""); string MatrixReuseAsm = !if(HasMatrixReuse, "$matrix_a_reuse$matrix_b_reuse", ""); string ClampAsm = !if(HasClamp, "$clamp", ""); string NegAsm = !cond(!and(NegLoAny, NegHiAny) : "$neg_lo$neg_hi", !and(NegLoAny, !not(NegHiAny)) : "$neg_lo", !and(!not(NegLoAny), !not(NegHiAny)) : ""); - let AsmVOP3P = "$vdst, $src0, $src1, $src2"#IndexKeyAsm#MatrixReuseAsm#NegAsm#ClampAsm; + let AsmVOP3P = "$vdst, $src0, $src1, $src2"#IndexKeyAsm#MatrxFMTAsm#MatrixReuseAsm#NegAsm#ClampAsm; // isel patterns bit IsAB_BF16_IMod0 = !and(IsAB_BF16, !not(HasIModOp)); @@ -1462,6 +1466,7 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType, IsAB_F16_IMod0 : (ins (Src0VT (WMMAModsF16Neg Src0VT:$src0, i32:$src0_modifiers))), IsAB_BF16_IMod0 : (ins Src0VT:$src0), IsIU : (ins (VOP3PModsNeg i32:$src0_modifiers), Src0VT:$src0), + HasMatrixFMT : (ins timm:$matrix_a_fmt, Src0VT:$src0), NoABMods : (ins Src0VT:$src0)); dag Src0OutPat = !cond(IsAB_F32F64_IMod1 : (ins i32:$src0_modifiers, Src0VT:$src0), IsAB_F16BF16_IMod1 : (ins i32:$src0_modifiers, Src0VT:$src0), @@ -1474,6 +1479,7 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType, IsAB_F16_IMod0 : (ins (Src1VT (WMMAModsF16Neg Src1VT:$src1, i32:$src1_modifiers))), IsAB_BF16_IMod0 : (ins Src1VT:$src1), IsIU : (ins (VOP3PModsNeg i32:$src1_modifiers), Src1VT:$src1), + HasMatrixFMT : (ins timm:$matrix_b_fmt, Src1VT:$src1), NoABMods : (ins Src1VT:$src1)); dag Src1OutPat = !cond(IsAB_F32F64_IMod1 : (ins i32:$src1_modifiers, Src1VT:$src1), IsAB_F16BF16_IMod1 : (ins i32:$src1_modifiers, Src1VT:$src1), @@ -1499,7 +1505,6 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType, IsIUXF32 : (ins Src2VT:$src2), IsSWMMAC : (ins)); dag ClampPat = !if(HasClamp, (ins i1:$clamp), (ins)); - dag IndexInPat = !cond(!eq(IndexType, 0) : (ins i32:$src2), !eq(IndexType, 8) : (ins (i32 (SWMMACIndex8 i32:$src2, i32:$index_key_8bit))), !eq(IndexType, 16): (ins (i32 (SWMMACIndex16 i32:$src2, i32:$index_key_16bit))), @@ -1508,6 +1513,7 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType, !eq(IndexType, 8) : (ins i32:$src2, i32:$index_key_8bit), !eq(IndexType, 16): (ins i32:$src2, i32:$index_key_16bit), !eq(IndexType, 32): (ins i64:$src2, i32:$index_key_32bit)); + dag MatrixFMTOutPat = !if(HasMatrixFMT, (ins i32:$matrix_a_fmt, i32:$matrix_b_fmt), (ins)); dag Src2InlineInPat = !con(!if(IsC_IMod1, (ins (VOP3PModsNegAbs i32:$src2_modifiers)), (ins)), (ins (Src2VT (WMMAVISrc Src2VT:$src2)))); dag Src2InlineOutPat = !con(!if(IsIUXF32, (ins), !if(IsC_IMod1, (ins i32:$src2_modifiers), (ins (i32 8)))), (ins Src2VT:$src2)); @@ -1515,7 +1521,7 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType, dag MatrixReuseOutModPat = !if(HasMatrixReuse, (ins i1:$matrix_a_reuse, i1:$matrix_b_reuse), (ins)); dag WmmaInPat = !con(Src0InPat, Src1InPat, Src2InPatWmma, MatrixReuseInPat, ClampPat); - dag WmmaOutPat = !con(Src0OutPat, Src1OutPat, Src2OutPatWmma, MatrixReuseOutModPat, ClampPat); + dag WmmaOutPat = !con(Src0OutPat, Src1OutPat, Src2OutPatWmma, MatrixFMTOutPat, MatrixReuseOutModPat, ClampPat); dag SwmmacInPat = !con(Src0InPat, Src1InPat, (ins Src2VT:$srcTiedDef), IndexInPat, MatrixReuseInPat, ClampPat); dag SwmmacOutPat = !con(Src0OutPat, Src1OutPat, (ins Src2VT:$srcTiedDef), IndexOutPat, MatrixReuseOutModPat, ClampPat); @@ -1523,7 +1529,7 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit 
_IsSWMMAC, int _IndexType, // wmma pattern where src2 is inline imm uses _threeaddr pseudo, // can't use _twoaddr since it would violate src2 tied to vdst constraint. dag WmmaInlineInPat = !con(Src0InPat, Src1InPat, Src2InlineInPat, MatrixReuseInPat, ClampPat); - dag WmmaInlineOutPat = !con(Src0OutPat, Src1OutPat, Src2InlineOutPat, MatrixReuseOutModPat, ClampPat); + dag WmmaInlineOutPat = !con(Src0OutPat, Src1OutPat, Src2InlineOutPat, MatrixFMTOutPat, MatrixReuseOutModPat, ClampPat); } def WMMAInstInfoTable : GenericTable { @@ -1632,26 +1638,45 @@ def F32_FP8BF8_SWMMAC_w64 : VOP3PWMMA_Profile<[v4f32, i32, v2i32, v4f32], 1, // *** IU4X32_SWMMAC_w64 lanes 0-31 will have 8xi4 remaining lanes are ignored // for matrix A, index is i16; Matrix B uses all lanes -def F64_F64X4_WMMA_w32 : VOP3PWMMA_Profile<[v8f64, v2f64, v2f64, v8f64], 0, 0, 0, 0, 1>; -def F32_F32_WMMA_w32 : VOP3PWMMA_Profile<[v8f32, v2f32, v2f32, v8f32], 0, 0, 0, 0, 1, 1>; -def F32_BF16X32_WMMA_w32 : VOP3PWMMA_Profile<[v8f32, v16bf16, v16bf16, v8f32], 0, 0, 0, 0, 1, 1>; -def F32_F16X32_WMMA_w32 : VOP3PWMMA_Profile<[v8f32, v16f16, v16f16, v8f32], 0, 0, 0, 0, 1, 1>; -def F16_F16X32_WMMA_w32 : VOP3PWMMA_Profile<[v8f16, v16f16, v16f16, v8f16], 0, 0, 0, 0, 1, 1>; -def BF16_BF16X32_WMMA_w32 : VOP3PWMMA_Profile<[v8bf16, v16bf16, v16bf16, v8bf16], 0, 0, 0, 0, 1, 1>; -def BF16F32_BF16_WMMA_w32 : VOP3PWMMA_Profile<[v8bf16, v16bf16, v16bf16, v8f32], 0, 0, 0, 0, 1, 1>; -def F32_FP8BF8X64_WMMA_w32 : VOP3PWMMA_Profile<[v8f32, v8i32, v8i32, v8f32], 0, 0, 0, 1, 1, 1>; -def F32_FP8BF8X128_WMMA_w32 : VOP3PWMMA_Profile<[v8f32, v16i32, v16i32, v8f32], 0, 0, 0, 1, 1, 1>; -def F16_FP8BF8X64_WMMA_w32 : VOP3PWMMA_Profile<[v8f16, v8i32, v8i32, v8f16], 0, 0, 0, 1, 1, 1>; -def F16_FP8BF8X128_WMMA_w32 : VOP3PWMMA_Profile<[v8f16, v16i32, v16i32, v8f16], 0, 0, 0, 1, 1, 1>; -def F32_32X16X128_F4_WMMA_w32 : VOP3PWMMA_Profile<[v16f32, v16i32, v8i32, v16f32], 0, 0, 0, 0, 1, 0, 1>; -def I32_IU8X64_WMMA_w32 : VOP3PWMMA_Profile<[v8i32, v8i32, v8i32, v8i32], 0, 0, 1, 0, 1, 1>; -def F32_F16X64_SWMMAC_w32 : VOP3PWMMA_Profile<[v8f32, v16f16, v32f16, v8f32], 1, 16, 0, 0, 1, 1>; -def F32_BF16X64_SWMMAC_w32 : VOP3PWMMA_Profile<[v8f32, v16bf16, v32bf16, v8f32], 1, 16, 0, 0, 1, 1>; -def F16_F16X64_SWMMAC_w32 : VOP3PWMMA_Profile<[v8f16, v16f16, v32f16, v8f16], 1, 16, 0, 0, 1, 1>; -def BF16_BF16X64_SWMMAC_w32 : VOP3PWMMA_Profile<[v8bf16, v16bf16, v32bf16, v8bf16], 1, 16, 0, 0, 1, 1>; -def F32_FP8BF8X128_SWMMAC_w32 : VOP3PWMMA_Profile<[v8f32, v8i32, v16i32, v8f32], 1, 32, 0, 1, 1, 1>; -def F16_FP8BF8X128_SWMMAC_w32 : VOP3PWMMA_Profile<[v8f16, v8i32, v16i32, v8f16], 1, 32, 0, 1, 1, 1>; -def I32_IU8X128_SWMMAC_w32 : VOP3PWMMA_Profile<[v8i32, v8i32, v16i32, v8i32], 1, 32, 1, 0, 1, 1>; +def F32_F32_WMMA_w32 : VOP3PWMMA_Profile<[v8f32, v2f32, v2f32, v8f32], 0, 0, 0, 0, 1, 0, 1>; +def F32_BF16X32_WMMA_w32 : VOP3PWMMA_Profile<[v8f32, v16bf16, v16bf16, v8f32], 0, 0, 0, 0, 1, 0, 1>; +def F32_F16X32_WMMA_w32 : VOP3PWMMA_Profile<[v8f32, v16f16, v16f16, v8f32], 0, 0, 0, 0, 1, 0, 1>; +def F16_F16X32_WMMA_w32 : VOP3PWMMA_Profile<[v8f16, v16f16, v16f16, v8f16], 0, 0, 0, 0, 1, 0, 1>; +def BF16_BF16X32_WMMA_w32 : VOP3PWMMA_Profile<[v8bf16, v16bf16, v16bf16, v8bf16], 0, 0, 0, 0, 1, 0, 1>; +def BF16F32_BF16_WMMA_w32 : VOP3PWMMA_Profile<[v8bf16, v16bf16, v16bf16, v8f32], 0, 0, 0, 0, 1, 0, 1>; +def F32_FP8BF8X64_WMMA_w32 : VOP3PWMMA_Profile<[v8f32, v8i32, v8i32, v8f32], 0, 0, 0, 1, 1, 0, 1>; +def F32_FP8BF8X128_WMMA_w32 : VOP3PWMMA_Profile<[v8f32, v16i32, v16i32, v8f32], 0, 0, 0, 1, 1, 0, 1>; +def 
F16_FP8BF8X64_WMMA_w32 : VOP3PWMMA_Profile<[v8f16, v8i32, v8i32, v8f16], 0, 0, 0, 1, 1, 0, 1>; +def F16_FP8BF8X128_WMMA_w32 : VOP3PWMMA_Profile<[v8f16, v16i32, v16i32, v8f16], 0, 0, 0, 1, 1, 0, 1>; +def F32_32X16X128_F4_WMMA_w32 : VOP3PWMMA_Profile<[v16f32, v16i32, v8i32, v16f32], 0, 0, 0, 0, 1, 0, 0, 1>; +def I32_IU8X64_WMMA_w32 : VOP3PWMMA_Profile<[v8i32, v8i32, v8i32, v8i32], 0, 0, 1, 0, 1, 0, 1>; +def F32_F16X64_SWMMAC_w32 : VOP3PWMMA_Profile<[v8f32, v16f16, v32f16, v8f32], 1, 16, 0, 0, 1, 0, 1>; +def F32_BF16X64_SWMMAC_w32 : VOP3PWMMA_Profile<[v8f32, v16bf16, v32bf16, v8f32], 1, 16, 0, 0, 1, 0, 1>; +def F16_F16X64_SWMMAC_w32 : VOP3PWMMA_Profile<[v8f16, v16f16, v32f16, v8f16], 1, 16, 0, 0, 1, 0, 1>; +def BF16_BF16X64_SWMMAC_w32 : VOP3PWMMA_Profile<[v8bf16, v16bf16, v32bf16, v8bf16], 1, 16, 0, 0, 1, 0, 1>; +def F32_FP8BF8X128_SWMMAC_w32 : VOP3PWMMA_Profile<[v8f32, v8i32, v16i32, v8f32], 1, 32, 0, 1, 1, 0, 1>; +def F16_FP8BF8X128_SWMMAC_w32 : VOP3PWMMA_Profile<[v8f16, v8i32, v16i32, v8f16], 1, 32, 0, 1, 1, 0, 1>; +def I32_IU8X128_SWMMAC_w32 : VOP3PWMMA_Profile<[v8i32, v8i32, v16i32, v8i32], 1, 32, 1, 0, 1, 0, 1>; + +multiclass WMMA_F8F6F4_Profiles<bit HasMatrixReuse> { + def _f8_f8_w32 : VOP3PWMMA_Profile<[v8f32, v16i32, v16i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixReuse>; + def _f8_f6_w32 : VOP3PWMMA_Profile<[v8f32, v16i32, v12i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixReuse>; + def _f8_f4_w32 : VOP3PWMMA_Profile<[v8f32, v16i32, v8i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixReuse>; + def _f6_f8_w32 : VOP3PWMMA_Profile<[v8f32, v12i32, v16i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixReuse>; + def _f6_f6_w32 : VOP3PWMMA_Profile<[v8f32, v12i32, v12i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixReuse>; + def _f6_f4_w32 : VOP3PWMMA_Profile<[v8f32, v12i32, v8i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixReuse>; + def _f4_f8_w32 : VOP3PWMMA_Profile<[v8f32, v8i32, v16i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixReuse>; + def _f4_f6_w32 : VOP3PWMMA_Profile<[v8f32, v8i32, v12i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixReuse>; + def _f4_f4_w32 : VOP3PWMMA_Profile<[v8f32, v8i32, v8i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixReuse>; +} + +defm F32_16X16X128_F8F6F4 : WMMA_F8F6F4_Profiles<0>; + +multiclass WMMAInst_SrcFormats_mc<string OpName, string Profile> { + foreach I = ["f8_f8", "f8_f6", "f8_f4", "f6_f8", "f6_f6", "f6_f4", "f4_f8", "f4_f6", "f4_f4"] in { + defm _#I#_w32 : WMMAInstGFX12<OpName # "_" # I # "_w32", !cast<VOP3PWMMA_Profile>(Profile # "_" # I # "_w32"), "_w32">; + } +} let WaveSizePredicate = isWave32 in { let SubtargetPredicate = isGFX125xOnly in { @@ -1697,6 +1722,8 @@ defm V_SWMMAC_I32_16X16X128_IU8_w32 : SWMMACInstGFX12<"v_swmmac_i32_16x16x12 defm V_SWMMAC_F32_16X16X64_F16_w32 : SWMMACInstGFX12<"v_swmmac_f32_16x16x64_f16", F32_F16X64_SWMMAC_w32, "_w32">; defm V_SWMMAC_F16_16X16X64_F16_w32 : SWMMACInstGFX12<"v_swmmac_f16_16x16x64_f16", F16_F16X64_SWMMAC_w32, "_w32">; +defm V_WMMA_F32_16X16X128_F8F6F4 : WMMAInst_SrcFormats_mc<"v_wmma_f32_16x16x128_f8f6f4", "F32_16X16X128_F8F6F4">; + } // End is_wmma_xdl = 1. 
} // End SubtargetPredicate = isGFX125xOnly @@ -1854,6 +1881,10 @@ let SubtargetPredicate = isGFX125xOnly in { defm : WMMAPat<"V_WMMA_F32_16X16X128_BF8_BF8_w32", int_amdgcn_wmma_f32_16x16x128_bf8_bf8, F32_FP8BF8X128_WMMA_w32>; defm : WMMAPat<"V_WMMA_F32_32X16X128_F4_w32", int_amdgcn_wmma_f32_32x16x128_f4, F32_32X16X128_F4_WMMA_w32>; + foreach I = ["f8_f8", "f8_f6", "f8_f4", "f6_f8", "f6_f6", "f6_f4", "f4_f8", "f4_f6", "f4_f4"] in { + defm : WMMAPat<"V_WMMA_F32_16X16X128_F8F6F4_" # I # "_w32", int_amdgcn_wmma_f32_16x16x128_f8f6f4, !cast<VOP3PWMMA_Profile>("F32_16X16X128_F8F6F4_" # I # "_w32")>; + } + def : SWMMACPat<V_SWMMAC_F32_16X16X64_BF16_w32_twoaddr, int_amdgcn_swmmac_f32_16x16x64_bf16, F32_BF16X64_SWMMAC_w32>; def : SWMMACPat<V_SWMMAC_BF16_16X16X64_BF16_w32_twoaddr, int_amdgcn_swmmac_bf16_16x16x64_bf16, BF16_BF16X64_SWMMAC_w32>; def : SWMMACPat<V_SWMMAC_BF16F32_16X16X64_BF16_w32_twoaddr, int_amdgcn_swmmac_bf16f32_16x16x64_bf16, F32_BF16X64_SWMMAC_w32>; @@ -1912,17 +1943,22 @@ multiclass VOP3P_Real_Base<GFXGen Gen, bits<8> op, string backing_ps_name = NAME class VOP3PeWmma<bits<8> op, VOPProfile P, VOP3PWMMA_Profile WMMAP> : VOP3Pe_gfx11_gfx12<op, P>{ + // opsel - let Inst{11} = !cond(!eq(WMMAP.IndexType, 0) : 0, + let Inst{11} = !cond(WMMAP.HasMatrixFMT : matrix_a_fmt{0}, + !eq(WMMAP.IndexType, 0) : 0, !eq(WMMAP.IndexType, 8) : index_key_8bit{0}, !eq(WMMAP.IndexType, 16) : index_key_16bit{0}, !eq(WMMAP.IndexType, 32) : index_key_32bit{0}); - let Inst{12} = !if(!eq(WMMAP.IndexType, 8), index_key_8bit{1}, 0); - let Inst{13} = !if(WMMAP.HasMatrixReuse, matrix_a_reuse, 0); + let Inst{12} = !if(WMMAP.HasMatrixFMT, matrix_a_fmt{1}, + !if(!eq(WMMAP.IndexType, 8), index_key_8bit{1}, 0)); + let Inst{13} = !if (WMMAP.HasMatrixFMT, matrix_a_fmt{2}, + !if(WMMAP.HasMatrixReuse, matrix_a_reuse, 0)); // opsel_hi - let Inst{59} = 1; - let Inst{60} = 1; - let Inst{14} = !if(WMMAP.HasMatrixReuse, matrix_b_reuse, 1); + let Inst{59} = !if (WMMAP.HasMatrixFMT, matrix_b_fmt{0}, 1); + let Inst{60} = !if (WMMAP.HasMatrixFMT, matrix_b_fmt{1}, 1); + let Inst{14} = !if (WMMAP.HasMatrixFMT, matrix_b_fmt{2}, + !if(WMMAP.HasMatrixReuse, matrix_b_reuse, 1)); // neg_lo let Inst{61} = !if(WMMAP.NegLo01, src0_modifiers{0}, 0); let Inst{62} = !if(WMMAP.NegLo01, src1_modifiers{0}, 0); @@ -1961,6 +1997,24 @@ multiclass VOP3P_Real_WMMA_gfx1250 <bits<8> op, VOP3PWMMA_Profile WMMAP> { } } +multiclass VOP3P_Real_WMMA_F8F6F4_gfx1250<bits<8> op, VOP3PWMMA_Profile WMMAP> { + defvar PS = !cast<VOP3P_Pseudo>(NAME # "_twoaddr"); + defvar asmName = !substr(PS.Mnemonic, 0, !sub(!size(PS.Mnemonic), !size("_f8_f8_w32"))); + defvar psName = !substr(NAME, 0, !sub(!size(PS.Mnemonic), !size("_f8_f8_w32"))); + let AsmString = asmName # PS.AsmOperands in + defm NAME : VOP3P_Real_WMMA_gfx1250<op, WMMAP>, + MFMA_F8F6F4_WithSizeTable_Helper<PS, psName # "_f8_f8_w32_twoaddr_gfx1250">; +} + +multiclass VOP3P_Real_WMMA_gfx1250_SrcFormats<bits<8> op, string WMMAP> { + defm _f8_f8_w32 : VOP3P_Real_WMMA_F8F6F4_gfx1250<op, !cast<VOP3PWMMA_Profile>(WMMAP # "_f8_f8_w32")>; + foreach I = ["f8_f6", "f8_f4", "f6_f8", "f6_f6", "f6_f4", "f4_f8", "f4_f6", "f4_f4"] in { + let isAsmParserOnly = true in { // Disable ambiguous disassembly. 
+ defm _#I#_w32 : VOP3P_Real_WMMA_F8F6F4_gfx1250<op, !cast<VOP3PWMMA_Profile>(WMMAP # "_" # I # "_w32")>; + } + } +} + defm V_WMMA_F32_16X16X16_F16_w32 : VOP3P_Real_WMMA_gfx12 <0x040, F32_F16_WMMA_w32>; defm V_WMMA_F32_16X16X16_BF16_w32 : VOP3P_Real_WMMA_gfx12 <0x041, F32_BF16_WMMA_w32>; defm V_WMMA_F16_16X16X16_F16_w32 : VOP3P_Real_WMMA_gfx12 <0x042, F16_F16_WMMA_w32>; @@ -2035,6 +2089,8 @@ defm V_WMMA_F16_16X16X128_BF8_FP8_w32 : VOP3P_Real_WMMA_gfx1250 <0x086, F16_FP8B defm V_WMMA_F16_16X16X128_BF8_BF8_w32 : VOP3P_Real_WMMA_gfx1250 <0x087, F16_FP8BF8X128_WMMA_w32>; defm V_WMMA_F32_32X16X128_F4_w32 : VOP3P_Real_WMMA_gfx1250 <0x088, F32_32X16X128_F4_WMMA_w32>; +defm V_WMMA_F32_16X16X128_F8F6F4 : VOP3P_Real_WMMA_gfx1250_SrcFormats<0x033, "F32_16X16X128_F8F6F4">; + defm V_SWMMAC_F32_16X16X64_F16_w32 : VOP3P_Real_WMMA_gfx1250 <0x065, F32_F16X64_SWMMAC_w32>; defm V_SWMMAC_F32_16X16X64_BF16_w32 : VOP3P_Real_WMMA_gfx1250 <0x066, F32_BF16X64_SWMMAC_w32>; defm V_SWMMAC_F16_16X16X64_F16_w32 : VOP3P_Real_WMMA_gfx1250 <0x067, F16_F16X64_SWMMAC_w32>; diff --git a/llvm/lib/Target/AMDGPU/VOPInstructions.td b/llvm/lib/Target/AMDGPU/VOPInstructions.td index a25ebdf..c21e2d3 100644 --- a/llvm/lib/Target/AMDGPU/VOPInstructions.td +++ b/llvm/lib/Target/AMDGPU/VOPInstructions.td @@ -453,6 +453,8 @@ class VOP3Pe_Base { bits<2> index_key_8bit; bits<1> index_key_16bit; bits<1> index_key_32bit; + bits<3> matrix_a_fmt; + bits<3> matrix_b_fmt; bits<1> matrix_a_reuse; bits<1> matrix_b_reuse; } diff --git a/llvm/lib/Target/ARM/MCTargetDesc/ARMELFStreamer.cpp b/llvm/lib/Target/ARM/MCTargetDesc/ARMELFStreamer.cpp index eaba6fe..a7a9911 100644 --- a/llvm/lib/Target/ARM/MCTargetDesc/ARMELFStreamer.cpp +++ b/llvm/lib/Target/ARM/MCTargetDesc/ARMELFStreamer.cpp @@ -593,7 +593,7 @@ public: getContext().reportError(Loc, "relocated expression must be 32-bit"); return; } - getOrCreateDataFragment(); + getCurrentFragment(); } emitDataMappingSymbol(); @@ -1207,7 +1207,7 @@ inline void ARMELFStreamer::SwitchToExIdxSection(const MCSymbol &FnStart) { } void ARMELFStreamer::EmitFixup(const MCExpr *Expr, MCFixupKind Kind) { - MCFragment *Frag = getOrCreateDataFragment(); + MCFragment *Frag = getCurrentFragment(); Frag->addFixup(MCFixup::create(Frag->getContents().size(), Expr, Kind)); } @@ -1295,7 +1295,7 @@ void ARMELFStreamer::EmitPersonalityFixup(StringRef Name) { MCSymbolRefExpr::create(PersonalitySym, ARM::S_ARM_NONE, getContext()); visitUsedExpr(*PersonalityRef); - MCFragment *DF = getOrCreateDataFragment(); + MCFragment *DF = getCurrentFragment(); DF->addFixup( MCFixup::create(DF->getContents().size(), PersonalityRef, FK_Data_4)); } diff --git a/llvm/lib/Target/AVR/MCTargetDesc/AVRAsmBackend.cpp b/llvm/lib/Target/AVR/MCTargetDesc/AVRAsmBackend.cpp index db09738..128cc0b 100644 --- a/llvm/lib/Target/AVR/MCTargetDesc/AVRAsmBackend.cpp +++ b/llvm/lib/Target/AVR/MCTargetDesc/AVRAsmBackend.cpp @@ -514,19 +514,7 @@ bool AVRAsmBackend::forceRelocation(const MCFragment &F, const MCFixup &Fixup, return false; case AVR::fixup_7_pcrel: - case AVR::fixup_13_pcrel: { - uint64_t Offset = Target.getConstant(); - uint64_t Size = AVRAsmBackend::getFixupKindInfo(Fixup.getKind()).TargetSize; - - // If the jump is too large to encode it, fall back to a relocation. - // - // Note that trying to actually link that relocation *would* fail, but the - // hopes are that the module we're currently compiling won't be actually - // linked to the final binary. 
- return !adjust::adjustRelativeBranch(Size, Fixup, Offset, - getContext().getSubtargetInfo()); - } - + case AVR::fixup_13_pcrel: case AVR::fixup_call: return true; } diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp index 2378664..a31fa57 100644 --- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp +++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp @@ -4560,6 +4560,80 @@ static SDValue signExtendBitcastSrcVector(SelectionDAG &DAG, EVT SExtVT, llvm_unreachable("Unexpected node type for vXi1 sign extension"); } +static SDValue +performSETCC_BITCASTCombine(SDNode *N, SelectionDAG &DAG, + TargetLowering::DAGCombinerInfo &DCI, + const LoongArchSubtarget &Subtarget) { + SDLoc DL(N); + EVT VT = N->getValueType(0); + SDValue Src = N->getOperand(0); + EVT SrcVT = Src.getValueType(); + + if (Src.getOpcode() != ISD::SETCC || !Src.hasOneUse()) + return SDValue(); + + bool UseLASX; + unsigned Opc = ISD::DELETED_NODE; + EVT CmpVT = Src.getOperand(0).getValueType(); + EVT EltVT = CmpVT.getVectorElementType(); + + if (Subtarget.hasExtLSX() && CmpVT.getSizeInBits() == 128) + UseLASX = false; + else if (Subtarget.has32S() && Subtarget.hasExtLASX() && + CmpVT.getSizeInBits() == 256) + UseLASX = true; + else + return SDValue(); + + SDValue SrcN1 = Src.getOperand(1); + switch (cast<CondCodeSDNode>(Src.getOperand(2))->get()) { + default: + break; + case ISD::SETEQ: + // x == 0 => not (vmsknez.b x) + if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8) + Opc = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ; + break; + case ISD::SETGT: + // x > -1 => vmskgez.b x + if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) && EltVT == MVT::i8) + Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ; + break; + case ISD::SETGE: + // x >= 0 => vmskgez.b x + if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8) + Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ; + break; + case ISD::SETLT: + // x < 0 => vmskltz.{b,h,w,d} x + if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && + (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 || + EltVT == MVT::i64)) + Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ; + break; + case ISD::SETLE: + // x <= -1 => vmskltz.{b,h,w,d} x + if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) && + (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 || + EltVT == MVT::i64)) + Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ; + break; + case ISD::SETNE: + // x != 0 => vmsknez.b x + if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8) + Opc = UseLASX ? 
LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ; + break; + } + + if (Opc == ISD::DELETED_NODE) + return SDValue(); + + SDValue V = DAG.getNode(Opc, DL, MVT::i64, Src.getOperand(0)); + EVT T = EVT::getIntegerVT(*DAG.getContext(), SrcVT.getVectorNumElements()); + V = DAG.getZExtOrTrunc(V, DL, T); + return DAG.getBitcast(VT, V); +} + static SDValue performBITCASTCombine(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, const LoongArchSubtarget &Subtarget) { @@ -4574,110 +4648,63 @@ static SDValue performBITCASTCombine(SDNode *N, SelectionDAG &DAG, if (!SrcVT.isSimple() || SrcVT.getScalarType() != MVT::i1) return SDValue(); - unsigned Opc = ISD::DELETED_NODE; // Combine SETCC and BITCAST into [X]VMSK{LT,GE,NE} when possible + SDValue Res = performSETCC_BITCASTCombine(N, DAG, DCI, Subtarget); + if (Res) + return Res; + + // Generate vXi1 using [X]VMSKLTZ + MVT SExtVT; + unsigned Opc; + bool UseLASX = false; + bool PropagateSExt = false; + if (Src.getOpcode() == ISD::SETCC && Src.hasOneUse()) { - bool UseLASX; EVT CmpVT = Src.getOperand(0).getValueType(); - EVT EltVT = CmpVT.getVectorElementType(); - - if (Subtarget.hasExtLSX() && CmpVT.getSizeInBits() <= 128) - UseLASX = false; - else if (Subtarget.has32S() && Subtarget.hasExtLASX() && - CmpVT.getSizeInBits() <= 256) - UseLASX = true; - else + if (CmpVT.getSizeInBits() > 256) return SDValue(); - - SDValue SrcN1 = Src.getOperand(1); - switch (cast<CondCodeSDNode>(Src.getOperand(2))->get()) { - default: - break; - case ISD::SETEQ: - // x == 0 => not (vmsknez.b x) - if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8) - Opc = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ; - break; - case ISD::SETGT: - // x > -1 => vmskgez.b x - if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) && EltVT == MVT::i8) - Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ; - break; - case ISD::SETGE: - // x >= 0 => vmskgez.b x - if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8) - Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ; - break; - case ISD::SETLT: - // x < 0 => vmskltz.{b,h,w,d} x - if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && - (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 || - EltVT == MVT::i64)) - Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ; - break; - case ISD::SETLE: - // x <= -1 => vmskltz.{b,h,w,d} x - if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) && - (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 || - EltVT == MVT::i64)) - Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ; - break; - case ISD::SETNE: - // x != 0 => vmsknez.b x - if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8) - Opc = UseLASX ? 
LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ; - break; - } } - // Generate vXi1 using [X]VMSKLTZ - if (Opc == ISD::DELETED_NODE) { - MVT SExtVT; - bool UseLASX = false; - bool PropagateSExt = false; - switch (SrcVT.getSimpleVT().SimpleTy) { - default: - return SDValue(); - case MVT::v2i1: - SExtVT = MVT::v2i64; - break; - case MVT::v4i1: - SExtVT = MVT::v4i32; - if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) { - SExtVT = MVT::v4i64; - UseLASX = true; - PropagateSExt = true; - } - break; - case MVT::v8i1: - SExtVT = MVT::v8i16; - if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) { - SExtVT = MVT::v8i32; - UseLASX = true; - PropagateSExt = true; - } - break; - case MVT::v16i1: - SExtVT = MVT::v16i8; - if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) { - SExtVT = MVT::v16i16; - UseLASX = true; - PropagateSExt = true; - } - break; - case MVT::v32i1: - SExtVT = MVT::v32i8; + switch (SrcVT.getSimpleVT().SimpleTy) { + default: + return SDValue(); + case MVT::v2i1: + SExtVT = MVT::v2i64; + break; + case MVT::v4i1: + SExtVT = MVT::v4i32; + if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) { + SExtVT = MVT::v4i64; UseLASX = true; - break; - }; - if (UseLASX && !Subtarget.has32S() && !Subtarget.hasExtLASX()) - return SDValue(); - Src = PropagateSExt ? signExtendBitcastSrcVector(DAG, SExtVT, Src, DL) - : DAG.getNode(ISD::SIGN_EXTEND, DL, SExtVT, Src); - Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ; - } else { - Src = Src.getOperand(0); - } + PropagateSExt = true; + } + break; + case MVT::v8i1: + SExtVT = MVT::v8i16; + if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) { + SExtVT = MVT::v8i32; + UseLASX = true; + PropagateSExt = true; + } + break; + case MVT::v16i1: + SExtVT = MVT::v16i8; + if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) { + SExtVT = MVT::v16i16; + UseLASX = true; + PropagateSExt = true; + } + break; + case MVT::v32i1: + SExtVT = MVT::v32i8; + UseLASX = true; + break; + }; + if (UseLASX && !(Subtarget.has32S() && Subtarget.hasExtLASX())) + return SDValue(); + Src = PropagateSExt ? signExtendBitcastSrcVector(DAG, SExtVT, Src, DL) + : DAG.getNode(ISD::SIGN_EXTEND, DL, SExtVT, Src); + Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ; SDValue V = DAG.getNode(Opc, DL, MVT::i64, Src); EVT T = EVT::getIntegerVT(*DAG.getContext(), SrcVT.getVectorNumElements()); diff --git a/llvm/lib/Target/LoongArch/MCTargetDesc/LoongArchAsmBackend.cpp b/llvm/lib/Target/LoongArch/MCTargetDesc/LoongArchAsmBackend.cpp index 7b9f115..8fa72bc 100644 --- a/llvm/lib/Target/LoongArch/MCTargetDesc/LoongArchAsmBackend.cpp +++ b/llvm/lib/Target/LoongArch/MCTargetDesc/LoongArchAsmBackend.cpp @@ -177,74 +177,6 @@ void LoongArchAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, } } -// Linker relaxation may change code size. We have to insert Nops -// for .align directive when linker relaxation enabled. So then Linker -// could satisfy alignment by removing Nops. -// The function returns the total Nops Size we need to insert. -bool LoongArchAsmBackend::shouldInsertExtraNopBytesForCodeAlign( - const MCAlignFragment &AF, unsigned &Size) { - // Calculate Nops Size only when linker relaxation enabled. - if (!AF.getSubtargetInfo()->hasFeature(LoongArch::FeatureRelax)) - return false; - - // Ignore alignment if MaxBytesToEmit is less than the minimum Nop size. 
- const unsigned MinNopLen = 4; - if (AF.getMaxBytesToEmit() < MinNopLen) - return false; - Size = AF.getAlignment().value() - MinNopLen; - return AF.getAlignment() > MinNopLen; -} - -// We need to insert R_LARCH_ALIGN relocation type to indicate the -// position of Nops and the total bytes of the Nops have been inserted -// when linker relaxation enabled. -// The function inserts fixup_loongarch_align fixup which eventually will -// transfer to R_LARCH_ALIGN relocation type. -// The improved R_LARCH_ALIGN requires symbol index. The lowest 8 bits of -// addend represent alignment and the other bits of addend represent the -// maximum number of bytes to emit. The maximum number of bytes is zero -// means ignore the emit limit. -bool LoongArchAsmBackend::shouldInsertFixupForCodeAlign(MCAssembler &Asm, - MCAlignFragment &AF) { - // Insert the fixup only when linker relaxation enabled. - if (!AF.getSubtargetInfo()->hasFeature(LoongArch::FeatureRelax)) - return false; - - // Calculate total Nops we need to insert. If there are none to insert - // then simply return. - unsigned InsertedNopBytes; - if (!shouldInsertExtraNopBytesForCodeAlign(AF, InsertedNopBytes)) - return false; - - MCSection *Sec = AF.getParent(); - MCContext &Ctx = getContext(); - const MCExpr *Dummy = MCConstantExpr::create(0, Ctx); - MCFixup Fixup = MCFixup::create(0, Dummy, ELF::R_LARCH_ALIGN); - unsigned MaxBytesToEmit = AF.getMaxBytesToEmit(); - - auto createExtendedValue = [&]() { - const MCSymbolRefExpr *MCSym = getSecToAlignSym()[Sec]; - if (MCSym == nullptr) { - // Define a marker symbol at the section with an offset of 0. - MCSymbol *Sym = Ctx.createNamedTempSymbol("la-relax-align"); - Sym->setFragment(&*Sec->getBeginSymbol()->getFragment()); - Asm.registerSymbol(*Sym); - MCSym = MCSymbolRefExpr::create(Sym, Ctx); - getSecToAlignSym()[Sec] = MCSym; - } - return MCValue::get(&MCSym->getSymbol(), nullptr, - MaxBytesToEmit << 8 | Log2(AF.getAlignment())); - }; - - uint64_t FixedValue = 0; - MCValue Value = MaxBytesToEmit >= InsertedNopBytes - ? MCValue::get(InsertedNopBytes) - : createExtendedValue(); - Asm.getWriter().recordRelocation(AF, Fixup, Value, FixedValue); - - return true; -} - bool LoongArchAsmBackend::shouldForceRelocation(const MCFixup &Fixup, const MCValue &Target) { switch (Fixup.getKind()) { @@ -279,6 +211,53 @@ getRelocPairForSize(unsigned Size) { } } +// Check if an R_LARCH_ALIGN relocation is needed for an alignment directive. +// If conditions are met, compute the padding size and create a fixup encoding +// the padding size in the addend. If MaxBytesToEmit is smaller than the padding +// size, the fixup encodes MaxBytesToEmit in the higher bits and references a +// per-section marker symbol. +bool LoongArchAsmBackend::relaxAlign(MCFragment &F, unsigned &Size) { + // Use default handling unless linker relaxation is enabled and the + // MaxBytesToEmit >= the nop size. 
+ if (!F.getSubtargetInfo()->hasFeature(LoongArch::FeatureRelax)) + return false; + const unsigned MinNopLen = 4; + unsigned MaxBytesToEmit = F.getAlignMaxBytesToEmit(); + if (MaxBytesToEmit < MinNopLen) + return false; + + Size = F.getAlignment().value() - MinNopLen; + if (F.getAlignment() <= MinNopLen) + return false; + + MCContext &Ctx = getContext(); + const MCExpr *Expr = nullptr; + if (MaxBytesToEmit >= Size) { + Expr = MCConstantExpr::create(Size, getContext()); + } else { + MCSection *Sec = F.getParent(); + const MCSymbolRefExpr *SymRef = getSecToAlignSym()[Sec]; + if (SymRef == nullptr) { + // Define a marker symbol at the section with an offset of 0. + MCSymbol *Sym = Ctx.createNamedTempSymbol("la-relax-align"); + Sym->setFragment(&*Sec->getBeginSymbol()->getFragment()); + Asm->registerSymbol(*Sym); + SymRef = MCSymbolRefExpr::create(Sym, Ctx); + getSecToAlignSym()[Sec] = SymRef; + } + Expr = MCBinaryExpr::createAdd( + SymRef, + MCConstantExpr::create((MaxBytesToEmit << 8) | Log2(F.getAlignment()), + Ctx), + Ctx); + } + MCFixup Fixup = + MCFixup::create(0, Expr, FirstLiteralRelocationKind + ELF::R_LARCH_ALIGN); + F.setVarFixups({Fixup}); + F.getParent()->setLinkerRelaxable(); + return true; +} + std::pair<bool, bool> LoongArchAsmBackend::relaxLEB128(MCFragment &F, int64_t &Value) const { const MCExpr &Expr = F.getLEBValue(); @@ -434,7 +413,7 @@ bool LoongArchAsmBackend::isPCRelFixupResolved(const MCSymbol *SymA, // Otherwise, check if the offset between the symbol and fragment is fully // resolved, unaffected by linker-relaxable fragments (e.g. instructions or - // offset-affected MCAlignFragment). Complements the generic + // offset-affected FT_Align fragments). Complements the generic // isSymbolRefDifferenceFullyResolvedImpl. if (!PCRelTemp) PCRelTemp = getContext().createTempSymbol(); diff --git a/llvm/lib/Target/LoongArch/MCTargetDesc/LoongArchAsmBackend.h b/llvm/lib/Target/LoongArch/MCTargetDesc/LoongArchAsmBackend.h index b32ba06..3d929fc 100644 --- a/llvm/lib/Target/LoongArch/MCTargetDesc/LoongArchAsmBackend.h +++ b/llvm/lib/Target/LoongArch/MCTargetDesc/LoongArchAsmBackend.h @@ -45,20 +45,13 @@ public: MutableArrayRef<char> Data, uint64_t Value, bool IsResolved) override; - // Return Size with extra Nop Bytes for alignment directive in code section. - bool shouldInsertExtraNopBytesForCodeAlign(const MCAlignFragment &AF, - unsigned &Size) override; - - // Insert target specific fixup type for alignment directive in code section. 
- bool shouldInsertFixupForCodeAlign(MCAssembler &Asm, - MCAlignFragment &AF) override; - bool shouldForceRelocation(const MCFixup &Fixup, const MCValue &Target); std::optional<MCFixupKind> getFixupKind(StringRef Name) const override; MCFixupKindInfo getFixupKindInfo(MCFixupKind Kind) const override; + bool relaxAlign(MCFragment &F, unsigned &Size) override; bool relaxDwarfLineAddr(MCFragment &F, bool &WasRelaxed) const override; bool relaxDwarfCFA(MCFragment &F, bool &WasRelaxed) const override; std::pair<bool, bool> relaxLEB128(MCFragment &F, diff --git a/llvm/lib/Target/Mips/MCTargetDesc/MipsTargetStreamer.cpp b/llvm/lib/Target/Mips/MCTargetDesc/MipsTargetStreamer.cpp index b89d689..feb4eb3 100644 --- a/llvm/lib/Target/Mips/MCTargetDesc/MipsTargetStreamer.cpp +++ b/llvm/lib/Target/Mips/MCTargetDesc/MipsTargetStreamer.cpp @@ -1033,45 +1033,40 @@ MCELFStreamer &MipsTargetELFStreamer::getStreamer() { } void MipsTargetELFStreamer::emitGPRel32Value(const MCExpr *Value) { - MCFragment *DF = getStreamer().getOrCreateDataFragment(); - DF->addFixup(MCFixup::create(DF->getContents().size(), Value, - Mips::fixup_Mips_GPREL32)); - DF->appendContents(4, 0); + auto &S = getStreamer(); + S.addFixup(Value, Mips::fixup_Mips_GPREL32); + S.appendContents(4, 0); } void MipsTargetELFStreamer::emitGPRel64Value(const MCExpr *Value) { - MCFragment *DF = getStreamer().getOrCreateDataFragment(); - DF->addFixup(MCFixup::create(DF->getContents().size(), Value, - Mips::fixup_Mips_GPREL32)); - DF->appendContents(8, 0); + auto &S = getStreamer(); + // fixup_Mips_GPREL32 designates R_MIPS_GPREL32+R_MIPS_64 on MIPS64. + S.addFixup(Value, Mips::fixup_Mips_GPREL32); + S.appendContents(8, 0); } void MipsTargetELFStreamer::emitDTPRel32Value(const MCExpr *Value) { - MCFragment *DF = getStreamer().getOrCreateDataFragment(); - DF->addFixup(MCFixup::create(DF->getContents().size(), Value, - Mips::fixup_Mips_DTPREL32)); - DF->appendContents(4, 0); + auto &S = getStreamer(); + S.addFixup(Value, Mips::fixup_Mips_DTPREL32); + S.appendContents(4, 0); } void MipsTargetELFStreamer::emitDTPRel64Value(const MCExpr *Value) { - MCFragment *DF = getStreamer().getOrCreateDataFragment(); - DF->addFixup(MCFixup::create(DF->getContents().size(), Value, - Mips::fixup_Mips_DTPREL64)); - DF->appendContents(8, 0); + auto &S = getStreamer(); + S.addFixup(Value, Mips::fixup_Mips_DTPREL64); + S.appendContents(8, 0); } void MipsTargetELFStreamer::emitTPRel32Value(const MCExpr *Value) { - MCFragment *DF = getStreamer().getOrCreateDataFragment(); - DF->addFixup(MCFixup::create(DF->getContents().size(), Value, - Mips::fixup_Mips_TPREL32)); - DF->appendContents(4, 0); + auto &S = getStreamer(); + S.addFixup(Value, Mips::fixup_Mips_TPREL32); + S.appendContents(4, 0); } void MipsTargetELFStreamer::emitTPRel64Value(const MCExpr *Value) { - MCFragment *DF = getStreamer().getOrCreateDataFragment(); - DF->addFixup(MCFixup::create(DF->getContents().size(), Value, - Mips::fixup_Mips_TPREL64)); - DF->appendContents(8, 0); + auto &S = getStreamer(); + S.addFixup(Value, Mips::fixup_Mips_TPREL64); + S.appendContents(8, 0); } void MipsTargetELFStreamer::emitDirectiveSetMicroMips() { diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index 77784be..7883acc 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -4006,7 +4006,10 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic( case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col: case
Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col_stride: case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row: - case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride: { + case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride: + case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16: + case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16: + case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8: { Info.opc = ISD::INTRINSIC_VOID; Info.memVT = MVT::v2i32; Info.ptrVal = I.getArgOperand(0); @@ -4029,6 +4032,30 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic( return true; } + case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16: + case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16: + case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8: { + Info.opc = ISD::INTRINSIC_VOID; + Info.memVT = MVT::i32; + Info.ptrVal = I.getArgOperand(0); + Info.offset = 0; + Info.flags = MachineMemOperand::MOStore; + Info.align = Align(4); + return true; + } + + case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16: + case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16: + case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8: { + Info.opc = ISD::INTRINSIC_VOID; + Info.memVT = MVT::v4i32; + Info.ptrVal = I.getArgOperand(0); + Info.offset = 0; + Info.flags = MachineMemOperand::MOStore; + Info.align = Align(16); + return true; + } + case Intrinsic::nvvm_atomic_add_gen_f_cta: case Intrinsic::nvvm_atomic_add_gen_f_sys: case Intrinsic::nvvm_atomic_add_gen_i_cta: diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index f329f48..0a00220 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -4758,7 +4758,14 @@ class WMMA_REGINFO<WMMA_REGS r, string op> !and(!eq(op, "ldmatrix"), !eq(ptx_elt_type, "b8x16.b4x16_p64"), - !eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]); + !eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>], + + !and(!eq(op, "stmatrix"),!eq(ptx_elt_type, "b16"), + !eq(geom, "m8n8")) : [hasSM<90>, hasPTX<78>], + + !and(!eq(op, "stmatrix"), + !eq(ptx_elt_type, "b8"), + !eq(geom, "m16n8")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]); // template DAGs for instruction inputs/output. dag Outs = !dag(outs, ptx_regs, reg_names); @@ -5039,6 +5046,42 @@ defset list<WMMA_INSTR> LDMATRIXs = { } // transposed } // defset +// +// stmatrix.sync.aligned.m8n8[|.trans][|.shared].b16 +// +class STMATRIX<WMMA_REGINFO Frag, bit Transposed, string Space> + : WMMA_INSTR<STMATRIX_NAME<Frag, Transposed>.record, [!con((ins ADDR:$dst), Frag.Ins)]>, + Requires<Frag.Predicates> { + // Build PatFrag that only matches particular address space. + dag PFOperands = !con((ops node:$dst), + !dag(ops, !listsplat(node, !size(Frag.regs)), Frag.reg_names)); + PatFrag IntrFrag = PatFrag<PFOperands, + !foreach(tmp, PFOperands, !subst(ops, Intr, tmp)), + !cond(!eq(Space, ".shared"): AS_match.shared, + true: AS_match.generic)>; + // Build AS-constrained pattern. + let IntrinsicPattern = BuildPatternPF<IntrFrag, Args>.ret; + let OutOperandList = (outs); + let InOperandList = !con(Args, (ins MmaCode:$ptx)); + let AsmString = "stmatrix.sync.aligned." + # Frag.geom + # "." # Frag.frag + # !if(Transposed, ".trans", "") + # Space + # "." 
# Frag.ptx_elt_type + # " [$dst], " # Frag.regstring # ";"; +} + +// Create all stmatrix variants +defset list<WMMA_INSTR> STMATRIXs = { + foreach transposed = [false, true] in {foreach space = [".shared", ""] in { + foreach frag = NVVM_MMA_OPS.all_stmatrix_ops in + if NVVM_STMATRIX_SUPPORTED<frag, transposed>.ret then + def : STMATRIX<WMMA_REGINFO<frag, "stmatrix">, transposed, space>; + } // space + } // transposed +} // defset + // Constructing non-flat DAGs is still a pain. I can't !subst a dag node with a // dag, so the ptx.version must be appended *after* foreach replaces 'ins' with // the instruction record. @@ -5049,7 +5092,7 @@ class MMA_PAT<WMMA_INSTR wi> Requires<wi.Predicates>; // Build intrinsic->instruction patterns for all MMA instructions. -foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs) in +foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs) in def : MMA_PAT<mma>; multiclass MAPA<string suffix, Intrinsic Intr> { diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVAsmBackend.cpp b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVAsmBackend.cpp index f76f8b3..2c37c3b 100644 --- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVAsmBackend.cpp +++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVAsmBackend.cpp @@ -302,6 +302,28 @@ void RISCVAsmBackend::relaxInstruction(MCInst &Inst, Inst = std::move(Res); } +// Check if an R_RISCV_ALIGN relocation is needed for an alignment directive. +// If conditions are met, compute the padding size and create a fixup encoding +// the padding size in the addend. +bool RISCVAsmBackend::relaxAlign(MCFragment &F, unsigned &Size) { + // Use default handling unless linker relaxation is enabled and the alignment + // is larger than the nop size. + const MCSubtargetInfo *STI = F.getSubtargetInfo(); + if (!STI->hasFeature(RISCV::FeatureRelax)) + return false; + unsigned MinNopLen = STI->hasFeature(RISCV::FeatureStdExtZca) ? 2 : 4; + if (F.getAlignment() <= MinNopLen) + return false; + + Size = F.getAlignment().value() - MinNopLen; + auto *Expr = MCConstantExpr::create(Size, getContext()); + MCFixup Fixup = + MCFixup::create(0, Expr, FirstLiteralRelocationKind + ELF::R_RISCV_ALIGN); + F.setVarFixups({Fixup}); + F.getParent()->setLinkerRelaxable(); + return true; +} + bool RISCVAsmBackend::relaxDwarfLineAddr(MCFragment &F, bool &WasRelaxed) const { MCContext &C = getContext(); @@ -637,7 +659,7 @@ bool RISCVAsmBackend::isPCRelFixupResolved(const MCSymbol *SymA, // Otherwise, check if the offset between the symbol and fragment is fully // resolved, unaffected by linker-relaxable fragments (e.g. instructions or - // offset-affected MCAlignFragment). Complements the generic + // offset-affected FT_Align fragments). Complements the generic // isSymbolRefDifferenceFullyResolvedImpl. if (!PCRelTemp) PCRelTemp = getContext().createTempSymbol(); @@ -887,55 +909,6 @@ void RISCVAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, } } -// Linker relaxation may change code size. We have to insert Nops -// for .align directive when linker relaxation enabled. So then Linker -// could satisfy alignment by removing Nops. -// The function return the total Nops Size we need to insert. -bool RISCVAsmBackend::shouldInsertExtraNopBytesForCodeAlign( - const MCAlignFragment &AF, unsigned &Size) { - // Calculate Nops Size only when linker relaxation enabled. - const MCSubtargetInfo *STI = AF.getSubtargetInfo(); - if (!STI->hasFeature(RISCV::FeatureRelax)) - return false; - - unsigned MinNopLen = STI->hasFeature(RISCV::FeatureStdExtZca) ? 
2 : 4; - - if (AF.getAlignment() <= MinNopLen) { - return false; - } else { - Size = AF.getAlignment().value() - MinNopLen; - return true; - } -} - -// We need to insert R_RISCV_ALIGN relocation type to indicate the -// position of Nops and the total bytes of the Nops have been inserted -// when linker relaxation enabled. -// The function insert fixup_riscv_align fixup which eventually will -// transfer to R_RISCV_ALIGN relocation type. -bool RISCVAsmBackend::shouldInsertFixupForCodeAlign(MCAssembler &Asm, - MCAlignFragment &AF) { - // Insert the fixup only when linker relaxation enabled. - const MCSubtargetInfo *STI = AF.getSubtargetInfo(); - if (!STI->hasFeature(RISCV::FeatureRelax)) - return false; - - // Calculate total Nops we need to insert. If there are none to insert - // then simply return. - unsigned Count; - if (!shouldInsertExtraNopBytesForCodeAlign(AF, Count) || (Count == 0)) - return false; - - MCContext &Ctx = getContext(); - const MCExpr *Dummy = MCConstantExpr::create(0, Ctx); - MCFixup Fixup = MCFixup::create(0, Dummy, ELF::R_RISCV_ALIGN); - - uint64_t FixedValue = 0; - MCValue NopBytes = MCValue::get(Count); - Asm.getWriter().recordRelocation(AF, Fixup, NopBytes, FixedValue); - return true; -} - std::unique_ptr<MCObjectTargetWriter> RISCVAsmBackend::createObjectTargetWriter() const { return createRISCVELFObjectWriter(OSABI, Is64Bit); diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVAsmBackend.h b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVAsmBackend.h index 8c10fbe..d97d632 100644 --- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVAsmBackend.h +++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVAsmBackend.h @@ -38,14 +38,6 @@ public: const MCTargetOptions &Options); ~RISCVAsmBackend() override = default; - // Return Size with extra Nop Bytes for alignment directive in code section. - bool shouldInsertExtraNopBytesForCodeAlign(const MCAlignFragment &AF, - unsigned &Size) override; - - // Insert target specific fixup type for alignment directive in code section. 
- bool shouldInsertFixupForCodeAlign(MCAssembler &Asm, - MCAlignFragment &AF) override; - std::optional<bool> evaluateFixup(const MCFragment &, MCFixup &, MCValue &, uint64_t &) override; bool addReloc(const MCFragment &, const MCFixup &, const MCValue &, @@ -73,6 +65,7 @@ public: void relaxInstruction(MCInst &Inst, const MCSubtargetInfo &STI) const override; + bool relaxAlign(MCFragment &F, unsigned &Size) override; bool relaxDwarfLineAddr(MCFragment &F, bool &WasRelaxed) const override; bool relaxDwarfCFA(MCFragment &F, bool &WasRelaxed) const override; std::pair<bool, bool> relaxLEB128(MCFragment &LF, diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp index cfec46d2..a541c2f 100644 --- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp @@ -3106,6 +3106,25 @@ bool RISCVDAGToDAGISel::SelectAddrRegRegScale(SDValue Addr, return true; } +bool RISCVDAGToDAGISel::SelectAddrRegZextRegScale(SDValue Addr, + unsigned MaxShiftAmount, + unsigned Bits, SDValue &Base, + SDValue &Index, + SDValue &Scale) { + if (!SelectAddrRegRegScale(Addr, MaxShiftAmount, Base, Index, Scale)) + return false; + + if (Index.getOpcode() == ISD::AND) { + auto *C = dyn_cast<ConstantSDNode>(Index.getOperand(1)); + if (C && C->getZExtValue() == maskTrailingOnes<uint64_t>(Bits)) { + Index = Index.getOperand(0); + return true; + } + } + + return false; +} + bool RISCVDAGToDAGISel::SelectAddrRegReg(SDValue Addr, SDValue &Base, SDValue &Offset) { if (Addr.getOpcode() != ISD::ADD) diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h index 72e2f96..ee3a86e 100644 --- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h +++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h @@ -59,19 +59,14 @@ public: return SelectAddrRegRegScale(Addr, MaxShift, Base, Index, Scale); } + bool SelectAddrRegZextRegScale(SDValue Addr, unsigned MaxShiftAmount, + unsigned Bits, SDValue &Base, SDValue &Index, + SDValue &Scale); + template <unsigned MaxShift, unsigned Bits> bool SelectAddrRegZextRegScale(SDValue Addr, SDValue &Base, SDValue &Index, SDValue &Scale) { - if (SelectAddrRegRegScale(Addr, MaxShift, Base, Index, Scale)) { - if (Index.getOpcode() == ISD::AND) { - auto *C = dyn_cast<ConstantSDNode>(Index.getOperand(1)); - if (C && C->getZExtValue() == maskTrailingOnes<uint64_t>(Bits)) { - Index = Index.getOperand(0); - return true; - } - } - } - return false; + return SelectAddrRegZextRegScale(Addr, MaxShift, Bits, Base, Index, Scale); } bool SelectAddrRegReg(SDValue Addr, SDValue &Base, SDValue &Offset); diff --git a/llvm/lib/Target/RISCV/RISCVInstrFormats.td b/llvm/lib/Target/RISCV/RISCVInstrFormats.td index e23001a..d9c6101 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrFormats.td +++ b/llvm/lib/Target/RISCV/RISCVInstrFormats.td @@ -174,6 +174,7 @@ class EltDeps<bit vl, bit mask> { def EltDepsNone : EltDeps<vl=0, mask=0>; def EltDepsVL : EltDeps<vl=1, mask=0>; +def EltDepsMask : EltDeps<vl=0, mask=1>; def EltDepsVLMask : EltDeps<vl=1, mask=1>; class EEW <bits<2> val> { diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoV.td b/llvm/lib/Target/RISCV/RISCVInstrInfoV.td index 5d13a87..33c7138 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoV.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoV.td @@ -1642,7 +1642,7 @@ def VFIRST_M : RVInstV<0b010000, 0b10001, OPMVV, (outs GPR:$vd), def : MnemonicAlias<"vpopc.m", "vcpop.m">; -let Constraints = "@earlyclobber $vd", RVVConstraint = Iota, ElementsDependOn = EltDepsVLMask in { +let 
Constraints = "@earlyclobber $vd", RVVConstraint = Iota, ElementsDependOn = EltDepsMask in { let DestEEW = EEW1 in { // vmsbf.m set-before-first mask bit @@ -1655,7 +1655,7 @@ defm VMSOF_M : VMSFS_MV_V<"vmsof.m", 0b010100, 0b00010>; // Vector Iota Instruction defm VIOTA_M : VIOTA_MV_V<"viota.m", 0b010100, 0b10000>; -} // Constraints = "@earlyclobber $vd", RVVConstraint = Iota, ElementsDependOn = EltDepsVLMask +} // Constraints = "@earlyclobber $vd", RVVConstraint = Iota, ElementsDependOn = EltDepsMask // Vector Element Index Instruction let hasSideEffects = 0, mayLoad = 0, mayStore = 0 in { diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoXqci.td b/llvm/lib/Target/RISCV/RISCVInstrInfoXqci.td index c7cb6e2..f391300 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoXqci.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoXqci.td @@ -1377,9 +1377,9 @@ let Predicates = [HasVendorXqciac, IsRV32] in { def : Pat<(i32 (add GPRNoX0:$rd, (mul GPRNoX0:$rs1, simm12:$imm12))), (QC_MULIADD GPRNoX0:$rd, GPRNoX0:$rs1, simm12:$imm12)>; def : Pat<(i32 (add_like_non_imm12 (shl GPRNoX0:$rs1, uimm5gt3:$imm), GPRNoX0:$rs2)), - (QC_SHLADD GPRNoX0:$rs2, GPRNoX0:$rs1, uimm5gt3:$imm)>; + (QC_SHLADD GPRNoX0:$rs1, GPRNoX0:$rs2, uimm5gt3:$imm)>; def : Pat<(i32 (riscv_shl_add GPRNoX0:$rs1, uimm5gt3:$imm, GPRNoX0:$rs2)), - (QC_SHLADD GPRNoX0:$rs2, GPRNoX0:$rs1, uimm5gt3:$imm)>; + (QC_SHLADD GPRNoX0:$rs1, GPRNoX0:$rs2, uimm5gt3:$imm)>; } // Predicates = [HasVendorXqciac, IsRV32] /// Simple arithmetic operations diff --git a/llvm/lib/Target/RISCV/RISCVInterleavedAccess.cpp b/llvm/lib/Target/RISCV/RISCVInterleavedAccess.cpp index dd68a55..6de870c 100644 --- a/llvm/lib/Target/RISCV/RISCVInterleavedAccess.cpp +++ b/llvm/lib/Target/RISCV/RISCVInterleavedAccess.cpp @@ -131,24 +131,40 @@ static bool getMemOperands(unsigned Factor, VectorType *VTy, Type *XLenTy, : Constant::getAllOnesValue(XLenTy); return true; } - auto *VPLdSt = cast<VPIntrinsic>(I); - assert((VPLdSt->getIntrinsicID() == Intrinsic::vp_load || - VPLdSt->getIntrinsicID() == Intrinsic::vp_store) && - "Unexpected intrinsic"); - Ptr = VPLdSt->getMemoryPointerParam(); - Alignment = VPLdSt->getPointerAlignment().value_or( - DL.getABITypeAlign(VTy->getElementType())); + if (auto *VPLdSt = dyn_cast<VPIntrinsic>(I)) { + assert((VPLdSt->getIntrinsicID() == Intrinsic::vp_load || + VPLdSt->getIntrinsicID() == Intrinsic::vp_store) && + "Unexpected intrinsic"); + Ptr = VPLdSt->getMemoryPointerParam(); + Alignment = VPLdSt->getPointerAlignment().value_or( + DL.getABITypeAlign(VTy->getElementType())); + + assert(Mask && "vp.load and vp.store needs a mask!"); + + Value *WideEVL = VPLdSt->getVectorLengthParam(); + // Conservatively check if EVL is a multiple of factor, otherwise some + // (trailing) elements might be lost after the transformation. + if (!isMultipleOfN(WideEVL, I->getDataLayout(), Factor)) + return false; - assert(Mask && "vp.load and vp.store needs a mask!"); + auto *FactorC = ConstantInt::get(WideEVL->getType(), Factor); + VL = Builder.CreateZExt(Builder.CreateExactUDiv(WideEVL, FactorC), XLenTy); + return true; + } + auto *II = cast<IntrinsicInst>(I); + assert(II->getIntrinsicID() == Intrinsic::masked_load && + "Unexpected intrinsic"); + Ptr = II->getOperand(0); + Alignment = cast<ConstantInt>(II->getArgOperand(1))->getAlignValue(); - Value *WideEVL = VPLdSt->getVectorLengthParam(); - // Conservatively check if EVL is a multiple of factor, otherwise some - // (trailing) elements might be lost after the transformation. 
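A side note on the isMultipleOfN guard kept in the vp.load/vp.store path above: the transformation divides the wide explicit vector length by the interleave factor, so any remainder would silently drop trailing elements. The arithmetic, in plain C++ rather than the LLVM helpers (names are illustrative):

#include <cassert>
#include <cstdint>
#include <optional>

// Per-segment VL for an interleaved access, or nothing if the wide EVL is
// not an exact multiple of the interleave factor (in which case the
// transformation must be skipped to avoid losing trailing elements).
std::optional<uint64_t> segmentVL(uint64_t wideEVL, uint64_t factor) {
  if (factor == 0 || wideEVL % factor != 0)
    return std::nullopt;
  return wideEVL / factor; // exact division, mirrors CreateExactUDiv above
}

int main() {
  assert(segmentVL(16, 4) == 4u);        // 16 interleaved elements, factor 4
  assert(!segmentVL(17, 4).has_value()); // remainder -> bail out
}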
- if (!isMultipleOfN(WideEVL, I->getDataLayout(), Factor)) + if (!isa<UndefValue>(II->getOperand(3))) return false; - auto *FactorC = ConstantInt::get(WideEVL->getType(), Factor); - VL = Builder.CreateZExt(Builder.CreateExactUDiv(WideEVL, FactorC), XLenTy); + assert(Mask && "masked.load needs a mask!"); + + VL = isa<FixedVectorType>(VTy) + ? Builder.CreateElementCount(XLenTy, VTy->getElementCount()) + : Constant::getAllOnesValue(XLenTy); return true; } diff --git a/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp b/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp index 28d6403..3b19c34 100644 --- a/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp +++ b/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp @@ -48,6 +48,8 @@ using namespace llvm; STATISTIC(NumRemovedSExtW, "Number of removed sign-extensions"); STATISTIC(NumTransformedToWInstrs, "Number of instructions transformed to W-ops"); +STATISTIC(NumTransformedToNonWInstrs, + "Number of instructions transformed to non-W-ops"); static cl::opt<bool> DisableSExtWRemoval("riscv-disable-sextw-removal", cl::desc("Disable removal of sext.w"), @@ -67,10 +69,9 @@ public: bool runOnMachineFunction(MachineFunction &MF) override; bool removeSExtWInstrs(MachineFunction &MF, const RISCVInstrInfo &TII, const RISCVSubtarget &ST, MachineRegisterInfo &MRI); - bool stripWSuffixes(MachineFunction &MF, const RISCVInstrInfo &TII, - const RISCVSubtarget &ST, MachineRegisterInfo &MRI); - bool appendWSuffixes(MachineFunction &MF, const RISCVInstrInfo &TII, - const RISCVSubtarget &ST, MachineRegisterInfo &MRI); + bool canonicalizeWSuffixes(MachineFunction &MF, const RISCVInstrInfo &TII, + const RISCVSubtarget &ST, + MachineRegisterInfo &MRI); void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); @@ -721,45 +722,39 @@ bool RISCVOptWInstrs::removeSExtWInstrs(MachineFunction &MF, return MadeChange; } -bool RISCVOptWInstrs::stripWSuffixes(MachineFunction &MF, - const RISCVInstrInfo &TII, - const RISCVSubtarget &ST, - MachineRegisterInfo &MRI) { +// Strips or adds W suffixes to eligible instructions depending on the +// subtarget preferences. +bool RISCVOptWInstrs::canonicalizeWSuffixes(MachineFunction &MF, + const RISCVInstrInfo &TII, + const RISCVSubtarget &ST, + MachineRegisterInfo &MRI) { + bool ShouldStripW = !(DisableStripWSuffix || ST.preferWInst()); + bool ShouldPreferW = ST.preferWInst(); bool MadeChange = false; - for (MachineBasicBlock &MBB : MF) { - for (MachineInstr &MI : MBB) { - unsigned Opc; - switch (MI.getOpcode()) { - default: - continue; - case RISCV::ADDW: Opc = RISCV::ADD; break; - case RISCV::ADDIW: Opc = RISCV::ADDI; break; - case RISCV::MULW: Opc = RISCV::MUL; break; - case RISCV::SLLIW: Opc = RISCV::SLLI; break; - } - if (hasAllWUsers(MI, ST, MRI)) { - MI.setDesc(TII.get(Opc)); - MadeChange = true; - } - } - } - - return MadeChange; -} - -bool RISCVOptWInstrs::appendWSuffixes(MachineFunction &MF, - const RISCVInstrInfo &TII, - const RISCVSubtarget &ST, - MachineRegisterInfo &MRI) { - bool MadeChange = false; for (MachineBasicBlock &MBB : MF) { for (MachineInstr &MI : MBB) { - unsigned WOpc; - // TODO: Add more? 
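The canonicalizeWSuffixes rewrite above folds the old strip/append passes into one decision per instruction. A simplified standalone model of that decision, with a toy opcode enum standing in for the RISCV:: opcodes and ignoring the -riscv-disable-strip-w-suffix escape hatch (illustrative only, not the real pass):

#include <cassert>
#include <optional>

enum class Opc { ADD, ADDW, ADDI, ADDIW, MUL, MULW, SLLI, SLLIW, SUB, SUBW, LWU, LW };

// If every user only reads the low 32 bits (hasAllWUsers in the real pass),
// pick the canonical form: strip the W suffix unless the subtarget prefers
// W instructions, in which case append it instead.
std::optional<Opc> canonicalize(Opc op, bool allUsersAreW, bool preferW) {
  if (!allUsersAreW)
    return std::nullopt;
  if (!preferW) {
    switch (op) { // strip the W suffix
    case Opc::ADDW:  return Opc::ADD;
    case Opc::ADDIW: return Opc::ADDI;
    case Opc::MULW:  return Opc::MUL;
    case Opc::SLLIW: return Opc::SLLI;
    case Opc::SUBW:  return Opc::SUB;
    default: break;
    }
  }
  if (preferW || op == Opc::LWU) {
    switch (op) { // append the W suffix (LWU -> LW is done regardless: LW is compressible)
    case Opc::ADD:  return Opc::ADDW;
    case Opc::ADDI: return Opc::ADDIW;
    case Opc::MUL:  return Opc::MULW;
    case Opc::SLLI: return Opc::SLLIW; // the real pass also requires shamt < 32
    case Opc::LWU:  return Opc::LW;
    default: break;
    }
  }
  return std::nullopt; // already canonical
}

int main() {
  assert(canonicalize(Opc::ADDW, true, false) == Opc::ADD);
  assert(canonicalize(Opc::ADD, true, true) == Opc::ADDW);
  assert(canonicalize(Opc::LWU, true, false) == Opc::LW);
  assert(!canonicalize(Opc::MULW, false, false).has_value());
}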
- switch (MI.getOpcode()) { + std::optional<unsigned> WOpc; + std::optional<unsigned> NonWOpc; + unsigned OrigOpc = MI.getOpcode(); + switch (OrigOpc) { default: continue; + case RISCV::ADDW: + NonWOpc = RISCV::ADD; + break; + case RISCV::ADDIW: + NonWOpc = RISCV::ADDI; + break; + case RISCV::MULW: + NonWOpc = RISCV::MUL; + break; + case RISCV::SLLIW: + NonWOpc = RISCV::SLLI; + break; + case RISCV::SUBW: + NonWOpc = RISCV::SUB; + break; case RISCV::ADD: WOpc = RISCV::ADDW; break; @@ -773,7 +768,7 @@ bool RISCVOptWInstrs::appendWSuffixes(MachineFunction &MF, WOpc = RISCV::MULW; break; case RISCV::SLLI: - // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits + // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits. if (MI.getOperand(2).getImm() >= 32) continue; WOpc = RISCV::SLLIW; @@ -784,19 +779,30 @@ bool RISCVOptWInstrs::appendWSuffixes(MachineFunction &MF, break; } - if (hasAllWUsers(MI, ST, MRI)) { + if (ShouldStripW && NonWOpc.has_value() && hasAllWUsers(MI, ST, MRI)) { + LLVM_DEBUG(dbgs() << "Replacing " << MI); + MI.setDesc(TII.get(NonWOpc.value())); + LLVM_DEBUG(dbgs() << " with " << MI); + ++NumTransformedToNonWInstrs; + MadeChange = true; + continue; + } + // LWU is always converted to LW when possible as 1) LW is compressible + // and 2) it helps minimise differences vs RV32. + if ((ShouldPreferW || OrigOpc == RISCV::LWU) && WOpc.has_value() && + hasAllWUsers(MI, ST, MRI)) { LLVM_DEBUG(dbgs() << "Replacing " << MI); - MI.setDesc(TII.get(WOpc)); + MI.setDesc(TII.get(WOpc.value())); MI.clearFlag(MachineInstr::MIFlag::NoSWrap); MI.clearFlag(MachineInstr::MIFlag::NoUWrap); MI.clearFlag(MachineInstr::MIFlag::IsExact); LLVM_DEBUG(dbgs() << " with " << MI); ++NumTransformedToWInstrs; MadeChange = true; + continue; } } } - return MadeChange; } @@ -813,12 +819,6 @@ bool RISCVOptWInstrs::runOnMachineFunction(MachineFunction &MF) { bool MadeChange = false; MadeChange |= removeSExtWInstrs(MF, TII, ST, MRI); - - if (!(DisableStripWSuffix || ST.preferWInst())) - MadeChange |= stripWSuffixes(MF, TII, ST, MRI); - - if (ST.preferWInst()) - MadeChange |= appendWSuffixes(MF, TII, ST, MRI); - + MadeChange |= canonicalizeWSuffixes(MF, TII, ST, MRI); return MadeChange; } diff --git a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp index e656e8b..15bd346 100644 --- a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp +++ b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp @@ -33,6 +33,7 @@ namespace { class RISCVVLOptimizer : public MachineFunctionPass { const MachineRegisterInfo *MRI; const MachineDominatorTree *MDT; + const TargetInstrInfo *TII; public: static char ID; @@ -1291,7 +1292,8 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const { return false; } - assert(!RISCVII::elementsDependOnVL(RISCV::getRVVMCOpcode(MI.getOpcode())) && + assert(!RISCVII::elementsDependOnVL( + TII->get(RISCV::getRVVMCOpcode(MI.getOpcode())).TSFlags) && "Instruction shouldn't be supported if elements depend on VL"); assert(MI.getOperand(0).isReg() && @@ -1495,6 +1497,8 @@ bool RISCVVLOptimizer::runOnMachineFunction(MachineFunction &MF) { if (!ST.hasVInstructions()) return false; + TII = ST.getInstrInfo(); + // For each instruction that defines a vector, compute what VL its // downstream users demand. 
for (MachineBasicBlock *MBB : post_order(&MF)) { diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 6608b3f..d4fa62a 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -296,6 +296,8 @@ private: bool selectImageWriteIntrinsic(MachineInstr &I) const; bool selectResourceGetPointer(Register &ResVReg, const SPIRVType *ResType, MachineInstr &I) const; + bool selectModf(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; // Utilities std::pair<Register, bool> @@ -3235,6 +3237,9 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, case Intrinsic::spv_discard: { return selectDiscard(ResVReg, ResType, I); } + case Intrinsic::modf: { + return selectModf(ResVReg, ResType, I); + } default: { std::string DiagMsg; raw_string_ostream OS(DiagMsg); @@ -4018,6 +4023,83 @@ bool SPIRVInstructionSelector::selectLog10(Register ResVReg, .constrainAllUses(TII, TRI, RBI); } +bool SPIRVInstructionSelector::selectModf(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + // llvm.modf has a single arg --the number to be decomposed-- and returns a + // struct { restype, restype }, while OpenCLLIB::modf has two args --the + // number to be decomposed and a pointer--, returns the fractional part and + // the integral part is stored in the pointer argument. Therefore, we can't + // use directly the OpenCLLIB::modf intrinsic. However, we can do some + // scaffolding to make it work. The idea is to create an alloca instruction + // to get a ptr, pass this ptr to OpenCL::modf, and then load the value + // from this ptr to place it in the struct. llvm.modf returns the fractional + // part as the first element of the result, and the integral part as the + // second element of the result. + + // At this point, the return type is not a struct anymore, but rather two + // independent elements of SPIRVResType. We can get each independent element + // from I.getDefs() or I.getOperands(). + if (STI.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) { + MachineIRBuilder MIRBuilder(I); + // Get pointer type for alloca variable. + const SPIRVType *PtrType = GR.getOrCreateSPIRVPointerType( + ResType, MIRBuilder, SPIRV::StorageClass::Function); + // Create new register for the pointer type of alloca variable. + Register PtrTyReg = + MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::iIDRegClass); + MIRBuilder.getMRI()->setType( + PtrTyReg, + LLT::pointer(storageClassToAddressSpace(SPIRV::StorageClass::Function), + GR.getPointerSize())); + // Assign SPIR-V type of the pointer type of the alloca variable to the + // new register. + GR.assignSPIRVTypeToVReg(PtrType, PtrTyReg, MIRBuilder.getMF()); + MachineBasicBlock &EntryBB = I.getMF()->front(); + MachineBasicBlock::iterator VarPos = + getFirstValidInstructionInsertPoint(EntryBB); + auto AllocaMIB = + BuildMI(EntryBB, VarPos, I.getDebugLoc(), TII.get(SPIRV::OpVariable)) + .addDef(PtrTyReg) + .addUse(GR.getSPIRVTypeID(PtrType)) + .addImm(static_cast<uint32_t>(SPIRV::StorageClass::Function)); + Register Variable = AllocaMIB->getOperand(0).getReg(); + // Modf must have 4 operands, the first two are the 2 parts of the result, + // the third is the operand, and the last one is the floating point value. + assert(I.getNumOperands() == 4 && + "Expected 4 operands for modf instruction"); + MachineBasicBlock &BB = *I.getParent(); + // Create the OpenCLLIB::modf instruction. 
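The comment above contrasts llvm.modf, which returns both parts as a struct, with OpenCLLIB::modf, which returns the fractional part and writes the integral part through a pointer; that mismatch is why the lowering allocates a Function-storage variable and loads from it afterwards. The C library modf has the same pointer-based shape, so a few lines of plain C++ illustrate the data flow being bridged (this is only an analogy, not the SPIR-V lowering itself):

#include <cmath>
#include <cstdio>

int main() {
  double x = 3.75;
  double integral = 0.0;
  // Pointer-style API (the shape of OpenCLLIB::modf): the return value is the
  // fractional part, the integral part comes back through the pointer.
  double fractional = std::modf(x, &integral);
  // Struct-style result (the shape of llvm.modf): { fractional, integral }.
  std::printf("fractional = %g, integral = %g\n", fractional, integral);
  // Prints: fractional = 0.75, integral = 3
}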
+ auto MIB = + BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::OpenCL_std)) + .addImm(CL::modf) + .setMIFlags(I.getFlags()) + .add(I.getOperand(3)) // Floating point value. + .addUse(Variable); // Pointer to integral part. + // Assign the integral part stored in the ptr to the second element of the + // result. + Register IntegralPartReg = I.getOperand(1).getReg(); + if (IntegralPartReg.isValid()) { + // Load the value from the pointer to integral part. + auto LoadMIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpLoad)) + .addDef(IntegralPartReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(Variable); + return LoadMIB.constrainAllUses(TII, TRI, RBI); + } + + return MIB.constrainAllUses(TII, TRI, RBI); + } else if (STI.canUseExtInstSet(SPIRV::InstructionSet::GLSL_std_450)) { + assert(false && "GLSL::Modf is deprecated."); + // FIXME: GL::Modf is deprecated, use Modfstruct instead. + return false; + } + return false; +} + // Generate the instructions to load 3-element vector builtin input // IDs/Indices. // Like: GlobalInvocationId, LocalInvocationId, etc.... diff --git a/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp b/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp index 2bffbf7..6766bd8 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp @@ -380,7 +380,7 @@ bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) { bool Changed = false; const SPIRVSubtarget &STI = TM.getSubtarget<SPIRVSubtarget>(*F); for (BasicBlock &BB : *F) { - for (Instruction &I : BB) { + for (Instruction &I : make_early_inc_range(BB)) { auto Call = dyn_cast<CallInst>(&I); if (!Call) continue; @@ -408,12 +408,16 @@ bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) { if (!STI.isShader()) { Changed |= toSpvOverloadedIntrinsic( II, Intrinsic::SPVIntrinsics::spv_lifetime_start, {1}); + } else { + II->eraseFromParent(); } break; case Intrinsic::lifetime_end: if (!STI.isShader()) { Changed |= toSpvOverloadedIntrinsic( II, Intrinsic::SPVIntrinsics::spv_lifetime_end, {1}); + } else { + II->eraseFromParent(); } break; case Intrinsic::ptr_annotation: diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp index 768efb9..416d811 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp @@ -995,4 +995,27 @@ unsigned getArrayComponentCount(const MachineRegisterInfo *MRI, return foldImm(ResType->getOperand(2), MRI); } +MachineBasicBlock::iterator +getFirstValidInstructionInsertPoint(MachineBasicBlock &BB) { + // Find the position to insert the OpVariable instruction. + // We will insert it after the last OpFunctionParameter, if any, or + // after OpFunction otherwise. + MachineBasicBlock::iterator VarPos = BB.begin(); + while (VarPos != BB.end() && VarPos->getOpcode() != SPIRV::OpFunction) { + ++VarPos; + } + // Advance VarPos to the next instruction after OpFunction, it will either + // be an OpFunctionParameter, so that we can start the next loop, or the + // position to insert the OpVariable instruction. + ++VarPos; + while (VarPos != BB.end() && + VarPos->getOpcode() == SPIRV::OpFunctionParameter) { + ++VarPos; + } + // VarPos is now pointing at after the last OpFunctionParameter, if any, + // or after OpFunction, if no parameters. + return VarPos != BB.end() && VarPos->getOpcode() == SPIRV::OpLabel ? 
++VarPos + : VarPos; +} + } // namespace llvm diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h index d732188..45c520a 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.h +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h @@ -506,6 +506,8 @@ MachineInstr *getImm(const MachineOperand &MO, const MachineRegisterInfo *MRI); int64_t foldImm(const MachineOperand &MO, const MachineRegisterInfo *MRI); unsigned getArrayComponentCount(const MachineRegisterInfo *MRI, const MachineInstr *ResType); +MachineBasicBlock::iterator +getFirstValidInstructionInsertPoint(MachineBasicBlock &BB); } // namespace llvm #endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp index bf2e04c..09b8864 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp @@ -46,6 +46,10 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering( : TargetLowering(TM), Subtarget(&STI) { auto MVTPtr = Subtarget->hasAddr64() ? MVT::i64 : MVT::i32; + // Set the load count for memcmp expand optimization + MaxLoadsPerMemcmp = 8; + MaxLoadsPerMemcmpOptSize = 4; + // Booleans always contain 0 or 1. setBooleanContents(ZeroOrOneBooleanContent); // Except in SIMD vectors @@ -2935,6 +2939,25 @@ performVectorExtendToFPCombine(SDNode *N, } static SDValue +performVectorNonNegToFPCombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI) { + auto &DAG = DCI.DAG; + + SDNodeFlags Flags = N->getFlags(); + SDValue Op0 = N->getOperand(0); + EVT VT = N->getValueType(0); + + // Optimize uitofp to sitofp when the sign bit is known to be zero. + // Depending on the target (runtime) backend, this might be performance + // neutral (e.g. AArch64) or a significant improvement (e.g. x86_64). + if (VT.isVector() && (Flags.hasNonNeg() || DAG.SignBitIsZero(Op0))) { + return DAG.getNode(ISD::SINT_TO_FP, SDLoc(N), VT, Op0); + } + + return SDValue(); +} + +static SDValue performVectorExtendCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { auto &DAG = DCI.DAG; assert(N->getOpcode() == ISD::SIGN_EXTEND || @@ -3515,6 +3538,9 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N, case ISD::ZERO_EXTEND: return performVectorExtendCombine(N, DCI); case ISD::UINT_TO_FP: + if (auto ExtCombine = performVectorExtendToFPCombine(N, DCI)) + return ExtCombine; + return performVectorNonNegToFPCombine(N, DCI); case ISD::SINT_TO_FP: return performVectorExtendToFPCombine(N, DCI); case ISD::FP_TO_SINT_SAT: diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp index 4f15999..52e7065 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp @@ -141,6 +141,21 @@ InstructionCost WebAssemblyTTIImpl::getCastInstrCost( return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I); } +WebAssemblyTTIImpl::TTI::MemCmpExpansionOptions +WebAssemblyTTIImpl::enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const { + TTI::MemCmpExpansionOptions Options; + + Options.AllowOverlappingLoads = true; + + // TODO: Teach WebAssembly backend about load v128. 
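With the load widths appended just below ({8, 4, 2, 1}), overlapping loads allowed, and the new MaxLoadsPerMemcmp limits from WebAssemblyISelLowering above (8, or 4 when optimising for size), a fixed-size memcmp can be expanded into a short run of wide loads. A very simplified model of the generic MemCmpExpansion idea, not the exact pass:

#include <cassert>
#include <cstdint>

// For size >= width and overlapping loads permitted, the bytes can be covered
// by ceil(size / width) loads: the last load is simply shifted back so it ends
// exactly at `size`, e.g. memcmp(p, q, 15) -> 8-byte loads at offsets 0 and 7.
// The expansion is only performed while the load count stays within the
// MaxLoadsPerMemcmp limit; otherwise the call to memcmp is kept.
uint64_t numOverlappingLoads(uint64_t size, uint64_t width) {
  assert(size >= width);
  return (size + width - 1) / width; // ceil(size / width)
}

int main() {
  assert(numOverlappingLoads(16, 8) == 2);
  assert(numOverlappingLoads(15, 8) == 2); // second load starts at offset 7
  assert(numOverlappingLoads(64, 8) == 8); // exactly at the limit of 8 loads
  assert(numOverlappingLoads(65, 8) == 9); // over the limit: no expansion
}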
+ + Options.LoadSizes.append({8, 4, 2, 1}); + Options.MaxNumLoads = TLI->getMaxExpandSizeMemcmp(OptSize); + Options.NumLoadsPerBlock = Options.MaxNumLoads; + + return Options; +} + InstructionCost WebAssemblyTTIImpl::getMemoryOpCost( unsigned Opcode, Type *Ty, Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind, TTI::OperandValueInfo OpInfo, diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h index d83b8d1..c915eeb0 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h +++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h @@ -73,6 +73,10 @@ public: getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, TTI::CastContextHint CCH, TTI::TargetCostKind CostKind, const Instruction *I = nullptr) const override; + + TTI::MemCmpExpansionOptions + enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const override; + InstructionCost getMemoryOpCost( unsigned Opcode, Type *Src, Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind, diff --git a/llvm/lib/Target/X86/MCTargetDesc/X86AsmBackend.cpp b/llvm/lib/Target/X86/MCTargetDesc/X86AsmBackend.cpp index 3d060c6..e213923 100644 --- a/llvm/lib/Target/X86/MCTargetDesc/X86AsmBackend.cpp +++ b/llvm/lib/Target/X86/MCTargetDesc/X86AsmBackend.cpp @@ -127,7 +127,6 @@ class X86AsmBackend : public MCAsmBackend { unsigned PrevInstOpcode = 0; MCBoundaryAlignFragment *PendingBA = nullptr; std::pair<MCFragment *, size_t> PrevInstPosition; - bool IsRightAfterData = false; uint8_t determinePaddingPrefix(const MCInst &Inst) const; bool isMacroFused(const MCInst &Cmp, const MCInst &Jcc) const; @@ -156,10 +155,13 @@ public: AlignBranchType = X86AlignBranchKindLoc; if (X86PadMaxPrefixSize.getNumOccurrences()) TargetPrefixMax = X86PadMaxPrefixSize; + + AllowAutoPadding = + AlignBoundary != Align(1) && AlignBranchType != X86::AlignBranchNone; + AllowEnhancedRelaxation = + AllowAutoPadding && TargetPrefixMax != 0 && X86PadForBranchAlign; } - bool allowAutoPadding() const override; - bool allowEnhancedRelaxation() const override; void emitInstructionBegin(MCObjectStreamer &OS, const MCInst &Inst, const MCSubtargetInfo &STI); void emitInstructionEnd(MCObjectStreamer &OS, const MCInst &Inst); @@ -365,14 +367,6 @@ static bool hasVariantSymbol(const MCInst &MI) { return false; } -bool X86AsmBackend::allowAutoPadding() const { - return (AlignBoundary != Align(1) && AlignBranchType != X86::AlignBranchNone); -} - -bool X86AsmBackend::allowEnhancedRelaxation() const { - return allowAutoPadding() && TargetPrefixMax != 0 && X86PadForBranchAlign; -} - /// X86 has certain instructions which enable interrupts exactly one /// instruction *after* the instruction which stores to SS. Return true if the /// given instruction may have such an interrupt delay slot. @@ -447,7 +441,7 @@ bool X86AsmBackend::canPadInst(const MCInst &Inst, MCObjectStreamer &OS) const { // semantic. return false; - if (IsRightAfterData) + if (isRightAfterData(OS.getCurrentFragment(), PrevInstPosition)) // If this instruction follows any data, there is no clear // instruction boundary, inserting a nop/prefix would change semantic. 
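For readers following the canPadInst logic here: the padding machinery exists to keep instructions (branches in particular) from straddling an AlignBoundary window, typically 32 bytes, by growing earlier instructions with redundant prefixes or emitting nops, which is only legal when the instruction does not follow raw data and has no variant symbols. A small self-contained sketch of the boundary test, with illustrative names rather than the MC API:

#include <cassert>
#include <cstdint>

// True when an instruction starting at `offset` with `size` encoded bytes has
// its first and last byte in different boundary-sized windows.
bool crossesBoundary(uint64_t offset, uint64_t size, uint64_t boundary = 32) {
  return offset / boundary != (offset + size - 1) / boundary;
}

// Padding needed to push the instruction's start to the next window.
uint64_t paddingToNextWindow(uint64_t offset, uint64_t boundary = 32) {
  return (boundary - offset % boundary) % boundary;
}

int main() {
  assert(!crossesBoundary(0, 6));       // fits entirely in the first window
  assert(crossesBoundary(30, 6));       // bytes 30..35 straddle offset 32
  assert(paddingToNextWindow(30) == 2); // 2 padding bytes move it to 32
}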
return false; @@ -484,13 +478,26 @@ bool X86AsmBackend::needAlign(const MCInst &Inst) const { (AlignBranchType & X86::AlignBranchIndirect)); } +void X86_MC::emitInstruction(MCObjectStreamer &S, const MCInst &Inst, + const MCSubtargetInfo &STI) { + bool AutoPadding = S.getAllowAutoPadding(); + if (LLVM_LIKELY(!AutoPadding && !X86PadForAlign)) { + S.MCObjectStreamer::emitInstruction(Inst, STI); + return; + } + + auto &Backend = static_cast<X86AsmBackend &>(S.getAssembler().getBackend()); + Backend.emitInstructionBegin(S, Inst, STI); + S.MCObjectStreamer::emitInstruction(Inst, STI); + Backend.emitInstructionEnd(S, Inst); +} + /// Insert BoundaryAlignFragment before instructions to align branches. void X86AsmBackend::emitInstructionBegin(MCObjectStreamer &OS, const MCInst &Inst, const MCSubtargetInfo &STI) { - // Used by canPadInst. Done here, because in emitInstructionEnd, the current - // fragment will have changed. - IsRightAfterData = - isRightAfterData(OS.getCurrentFragment(), PrevInstPosition); + bool CanPadInst = canPadInst(Inst, OS); + if (CanPadInst) + OS.getCurrentFragment()->setAllowAutoPadding(true); if (!canPadBranches(OS)) return; @@ -504,7 +511,7 @@ void X86AsmBackend::emitInstructionBegin(MCObjectStreamer &OS, // we call canPadInst (not cheap) twice. However, in the common case, we can // avoid unnecessary calls to that, as this is otherwise only used for // relaxable fragments. - if (!canPadInst(Inst, OS)) + if (!CanPadInst) return; if (PendingBA && PendingBA->getNext() == OS.getCurrentFragment()) { @@ -542,11 +549,8 @@ void X86AsmBackend::emitInstructionBegin(MCObjectStreamer &OS, /// Set the last fragment to be aligned for the BoundaryAlignFragment. void X86AsmBackend::emitInstructionEnd(MCObjectStreamer &OS, const MCInst &Inst) { - MCFragment *CF = OS.getCurrentFragment(); - if (CF->getKind() == MCFragment::FT_Relaxable) - CF->setAllowAutoPadding(canPadInst(Inst, OS)); - // Update PrevInstOpcode here, canPadInst() reads that. + MCFragment *CF = OS.getCurrentFragment(); PrevInstOpcode = Inst.getOpcode(); PrevInstPosition = std::make_pair(CF, getSizeForInstFragment(CF)); @@ -567,11 +571,10 @@ void X86AsmBackend::emitInstructionEnd(MCObjectStreamer &OS, // DataFragment, so that we can get the size of instructions later in // MCAssembler::relaxBoundaryAlign. The easiest way is to insert a new empty // DataFragment. - OS.insert(OS.getContext().allocFragment<MCFragment>()); + OS.newFragment(); // Update the maximum alignment on the current section if necessary. - MCSection *Sec = OS.getCurrentSectionOnly(); - Sec->ensureMinAlignment(AlignBoundary); + CF->getParent()->ensureMinAlignment(AlignBoundary); } std::optional<MCFixupKind> X86AsmBackend::getFixupKind(StringRef Name) const { @@ -923,13 +926,11 @@ bool X86AsmBackend::finishLayout(const MCAssembler &Asm) const { continue; } - const uint64_t OrigSize = Asm.computeFragmentSize(F); - // To keep the effects local, prefer to relax instructions closest to // the align directive. This is purely about human understandability // of the resulting code. If we later find a reason to expand // particular instructions over others, we can adjust. 
- unsigned RemainingSize = OrigSize; + unsigned RemainingSize = Asm.computeFragmentSize(F) - F.getFixedSize(); while (!Relaxable.empty() && RemainingSize != 0) { auto &RF = *Relaxable.pop_back_val(); // Give the backend a chance to play any tricks it wishes to increase @@ -1542,14 +1543,6 @@ public: }; } // end anonymous namespace -void X86_MC::emitInstruction(MCObjectStreamer &S, const MCInst &Inst, - const MCSubtargetInfo &STI) { - auto &Backend = static_cast<X86AsmBackend &>(S.getAssembler().getBackend()); - Backend.emitInstructionBegin(S, Inst, STI); - S.MCObjectStreamer::emitInstruction(Inst, STI); - Backend.emitInstructionEnd(S, Inst); -} - void X86ELFStreamer::emitInstruction(const MCInst &Inst, const MCSubtargetInfo &STI) { X86_MC::emitInstruction(*this, Inst, STI); |
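Returning to the WebAssembly UINT_TO_FP combine earlier in this diff: when the sign bit of the integer input is known to be zero (the nneg flag, or a SignBitIsZero proof), signed and unsigned conversion produce the same value, so the combine can emit the signed form, which the comment notes is no worse and sometimes much cheaper for the engine's host backend. A quick self-contained check of that equivalence:

#include <cassert>
#include <cstdint>

int main() {
  // For any 32-bit value with the sign bit clear, interpreting it as signed
  // or unsigned before converting to floating point gives the same result.
  for (uint32_t v : {0u, 1u, 5u, 0x7FFFFFFFu})
    assert((double)v == (double)(int32_t)v);

  // With the sign bit set, the two conversions diverge, which is why the
  // combine requires nneg or a sign-bit-is-zero proof.
  uint32_t neg = 0x80000000u;
  assert((double)neg != (double)(int32_t)neg);
}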