Diffstat (limited to 'llvm/lib/Target/AMDGPU')
32 files changed, 952 insertions, 321 deletions
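A recurring theme in this diff is the new "globally addressable scratch" mode (FeatureGloballyAddressableScratch), where flat <-> private address-space casts go through SRC_FLAT_SCRATCH_BASE instead of the src_private_base aperture. As a reading aid only (not part of the commit), here is a minimal scalar C++ sketch of the address arithmetic the new lowering emits; the formulas are taken from the comments in legalizeAddrSpaceCast / lowerADDRSPACECAST / legalizeIsAddrSpace below, while the helper names and the plain-integer modeling of the hardware registers (src_flat_scratch_base_lo/hi, the mbcnt-derived lane id) are illustrative assumptions.

    #include <cstdint>

    // private -> flat (formulas quoted from the diff's comments):
    //   wave32: Addr = (TID[4:0] << 52) + FLAT_SCRATCH_BASE + privateAddr
    //   wave64: Addr = (TID[5:0] << 51) + FLAT_SCRATCH_BASE + privateAddr
    uint64_t privateToFlat(uint32_t PrivateAddr, uint32_t ThreadId,
                           uint64_t FlatScratchBase, bool IsWave64) {
      unsigned ShAmt = 57 - (IsWave64 ? 6 : 5); // 52 for wave32, 51 for wave64
      uint64_t CvtPtr = (uint64_t(ThreadId) << ShAmt) | PrivateAddr;
      return CvtPtr + FlatScratchBase;          // base read as the full 64-bit hi:lo value
    }

    // flat -> private: subtract the low half of the flat-scratch base.
    uint32_t flatToPrivate(uint64_t FlatAddr, uint32_t FlatScratchBaseLo) {
      return uint32_t(FlatAddr) - FlatScratchBaseLo;
    }

    // amdgcn.is.private: compare bits 63..58 against the base's high half
    // (XOR, then unsigned compare with 1 << 26, as in legalizeIsAddrSpace).
    bool isPrivateAddr(uint64_t FlatAddr, uint32_t FlatScratchBaseHi) {
      uint32_t Hi = uint32_t(FlatAddr >> 32);
      return (Hi ^ FlatScratchBaseHi) < (1u << 26);
    }

The same arithmetic shows up twice, once in the GlobalISel path (AMDGPULegalizerInfo.cpp) and once in the SelectionDAG path (SIISelLowering.cpp).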
diff --git a/llvm/lib/Target/AMDGPU/AMDGPU.td b/llvm/lib/Target/AMDGPU/AMDGPU.td index 8a0c4ac..d84f512 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPU.td +++ b/llvm/lib/Target/AMDGPU/AMDGPU.td @@ -1160,6 +1160,12 @@ def FeatureTanhInsts : SubtargetFeature<"tanh-insts", "Has v_tanh_f32/f16 instructions" >; +def FeatureTensorCvtLutInsts : SubtargetFeature<"tensor-cvt-lut-insts", + "HasTensorCvtLutInsts", + "true", + "Has v_perm_pk16* instructions" +>; + def FeatureTransposeLoadF4F6Insts : SubtargetFeature<"transpose-load-f4f6-insts", "HasTransposeLoadF4F6Insts", "true", @@ -1359,6 +1365,13 @@ def FeatureXF32Insts : SubtargetFeature<"xf32-insts", "v_mfma_f32_16x16x8_xf32 and v_mfma_f32_32x32x4_xf32" >; +def FeatureGloballyAddressableScratch : SubtargetFeature< + "globally-addressable-scratch", + "HasGloballyAddressableScratch", + "true", + "FLAT instructions can access scratch memory for any thread in any wave" +>; + // FIXME: Remove after all users are migrated to attribute. def FeatureDynamicVGPR : SubtargetFeature <"dynamic-vgpr", "DynamicVGPR", @@ -2030,6 +2043,7 @@ def FeatureISAVersion12_50 : FeatureSet< FeatureDPPSrc1SGPR, FeatureBitOp3Insts, FeatureTanhInsts, + FeatureTensorCvtLutInsts, FeatureTransposeLoadF4F6Insts, FeatureBF16TransInsts, FeatureBF16ConversionInsts, @@ -2048,6 +2062,7 @@ def FeatureISAVersion12_50 : FeatureSet< FeatureAtomicFMinFMaxF64FlatInsts, FeatureFlatBufferGlobalAtomicFaddF64Inst, FeatureMemoryAtomicFAddF32DenormalSupport, + FeatureGloballyAddressableScratch, FeatureKernargPreload, FeatureVmemPrefInsts, FeatureLshlAddU64Inst, @@ -2785,6 +2800,9 @@ def HasBitOp3Insts : Predicate<"Subtarget->hasBitOp3Insts()">, def HasTanhInsts : Predicate<"Subtarget->hasTanhInsts()">, AssemblerPredicate<(all_of FeatureTanhInsts)>; +def HasTensorCvtLutInsts : Predicate<"Subtarget->hasTensorCvtLutInsts()">, + AssemblerPredicate<(all_of FeatureTensorCvtLutInsts)>; + def HasTransposeLoadF4F6Insts : Predicate<"Subtarget->hasTransposeLoadF4F6Insts()">, AssemblerPredicate<(all_of FeatureTransposeLoadF4F6Insts)>; diff --git a/llvm/lib/Target/AMDGPU/AMDGPUGISel.td b/llvm/lib/Target/AMDGPU/AMDGPUGISel.td index 992572f..394a143 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUGISel.td +++ b/llvm/lib/Target/AMDGPU/AMDGPUGISel.td @@ -51,18 +51,6 @@ def gi_vop3pmodsdot : GIComplexOperandMatcher<s32, "selectVOP3PModsDOT">, GIComplexPatternEquiv<VOP3PModsDOT>; -def gi_vop3pmodsneg : - GIComplexOperandMatcher<s32, "selectVOP3PModsNeg">, - GIComplexPatternEquiv<VOP3PModsNeg>; - -def gi_vop3pmodsnegs : - GIComplexOperandMatcher<s32, "selectVOP3PModsNegs">, - GIComplexPatternEquiv<VOP3PModsNegs>; - -def gi_dotiuvop3pmodsnegabs : - GIComplexOperandMatcher<s32, "selectVOP3PModsNegAbs">, - GIComplexPatternEquiv<VOP3PModsNegAbs>; - def gi_wmmaopselvop3pmods : GIComplexOperandMatcher<s32, "selectWMMAOpSelVOP3PMods">, GIComplexPatternEquiv<WMMAOpSelVOP3PMods>; @@ -452,6 +440,13 @@ def gi_fp_pow2_to_exponent : GICustomOperandRenderer<"renderFPPow2ToExponent">, def gi_as_hw_round_mode : GICustomOperandRenderer<"renderRoundMode">, GISDNodeXFormEquiv<as_hw_round_mode>; +def gi_VOP3PModsNeg : GICustomOperandRenderer<"renderVOP3PModsNeg">, + GISDNodeXFormEquiv<VOP3PModsNeg>; +def gi_VOP3PModsNegs : GICustomOperandRenderer<"renderVOP3PModsNegs">, + GISDNodeXFormEquiv<VOP3PModsNegs>; +def gi_VOP3PModsNegAbs : GICustomOperandRenderer<"renderVOP3PModsNegAbs">, + GISDNodeXFormEquiv<VOP3PModsNegAbs>; + def gi_prefetch_loc : GICustomOperandRenderer<"renderPrefetchLoc">, GISDNodeXFormEquiv<PrefetchLoc>; diff --git 
a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp index 39b4200..fb83388 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp @@ -3449,63 +3449,6 @@ bool AMDGPUDAGToDAGISel::SelectVOP3PModsDOT(SDValue In, SDValue &Src, return SelectVOP3PMods(In, Src, SrcMods, true); } -// Select neg_lo from the i1 immediate operand. -bool AMDGPUDAGToDAGISel::SelectVOP3PModsNeg(SDValue In, SDValue &Src) const { - const ConstantSDNode *C = cast<ConstantSDNode>(In); - // Literal i1 value set in intrinsic, represents SrcMods for the next operand. - // 1 promotes packed values to signed, 0 treats them as unsigned. - assert(C->getAPIntValue().getBitWidth() == 1 && "expected i1 value"); - - unsigned Mods = SISrcMods::OP_SEL_1; - unsigned SrcSign = C->getZExtValue(); - if (SrcSign == 1) - Mods ^= SISrcMods::NEG; - - Src = CurDAG->getTargetConstant(Mods, SDLoc(In), MVT::i32); - return true; -} - -// Select both neg_lo and neg_hi from the i1 immediate operand. This is -// specifically for F16/BF16 operands in WMMA instructions, where neg_lo applies -// to matrix's even k elements, and neg_hi applies to matrix's odd k elements. -bool AMDGPUDAGToDAGISel::SelectVOP3PModsNegs(SDValue In, SDValue &Src) const { - const ConstantSDNode *C = cast<ConstantSDNode>(In); - // Literal i1 value set in intrinsic, represents SrcMods for the next operand. - // 1 promotes packed values to signed, 0 treats them as unsigned. - assert(C->getAPIntValue().getBitWidth() == 1 && "expected i1 value"); - - unsigned Mods = SISrcMods::OP_SEL_1; - unsigned SrcSign = C->getZExtValue(); - if (SrcSign == 1) - Mods ^= (SISrcMods::NEG | SISrcMods::NEG_HI); - - Src = CurDAG->getTargetConstant(Mods, SDLoc(In), MVT::i32); - return true; -} - -// Select neg, abs, or both neg and abs from the i16 immediate operans. -bool AMDGPUDAGToDAGISel::SelectVOP3PModsNegAbs(SDValue In, SDValue &Src) const { - const ConstantSDNode *C = cast<ConstantSDNode>(In); - unsigned Mods = SISrcMods::OP_SEL_1; - unsigned SrcMod = C->getZExtValue(); - switch (SrcMod) { - default: // Any other value will be silently ignored (considered as 0). 
- break; - case 1: - Mods ^= SISrcMods::NEG; - break; - case 2: - Mods ^= SISrcMods::ABS; - break; - case 3: - Mods ^= (SISrcMods::NEG | SISrcMods::ABS); - break; - } - - Src = CurDAG->getTargetConstant(Mods, SDLoc(In), MVT::i32); - return true; -} - bool AMDGPUDAGToDAGISel::SelectWMMAOpSelVOP3PMods(SDValue In, SDValue &Src) const { const ConstantSDNode *C = cast<ConstantSDNode>(In); diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h index 983f1aa..16388e7 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h +++ b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h @@ -241,9 +241,6 @@ private: bool IsDOT = false) const; bool SelectVOP3PModsDOT(SDValue In, SDValue &Src, SDValue &SrcMods) const; - bool SelectVOP3PModsNeg(SDValue In, SDValue &Src) const; - bool SelectVOP3PModsNegs(SDValue In, SDValue &Src) const; - bool SelectVOP3PModsNegAbs(SDValue In, SDValue &Src) const; bool SelectWMMAOpSelVOP3PMods(SDValue In, SDValue &Src) const; bool SelectWMMAModsF32NegAbs(SDValue In, SDValue &Src, diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp index 31c4f62..64e68ab 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp @@ -589,14 +589,6 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM, setSchedulingPreference(Sched::RegPressure); setJumpIsExpensive(true); - // FIXME: This is only partially true. If we have to do vector compares, any - // SGPR pair can be a condition register. If we have a uniform condition, we - // are better off doing SALU operations, where there is only one SCC. For now, - // we don't have a way of knowing during instruction selection if a condition - // will be uniform and we always use vector compares. Assume we are using - // vector compares until that is fixed. - setHasMultipleConditionRegisters(true); - setMinCmpXchgSizeInBits(32); setSupportsUnalignedAtomics(false); diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h index 39bb0ad..fd5d5b8 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h +++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h @@ -388,6 +388,16 @@ public: MVT getFenceOperandTy(const DataLayout &DL) const override { return MVT::i32; } + + bool hasMultipleConditionRegisters(EVT VT) const override { + // FIXME: This is only partially true. If we have to do vector compares, any + // SGPR pair can be a condition register. If we have a uniform condition, we + // are better off doing SALU operations, where there is only one SCC. For + // now, we don't have a way of knowing during instruction selection if a + // condition will be uniform and we always use vector compares. Assume we + // are using vector compares until that is fixed. 
+ return true; + } }; namespace AMDGPUISD { diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp index f2207ff..4fe5d00 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp @@ -1694,7 +1694,9 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const { NewII->takeName(&II); return IC.replaceInstUsesWith(II, NewII); } - case Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4: { + case Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4: + case Intrinsic::amdgcn_wmma_scale_f32_16x16x128_f8f6f4: + case Intrinsic::amdgcn_wmma_scale16_f32_16x16x128_f8f6f4: { Value *Src0 = II.getArgOperand(1); Value *Src1 = II.getArgOperand(3); unsigned FmtA = cast<ConstantInt>(II.getArgOperand(0))->getZExtValue(); diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp index b0d3b12..b7fd131 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp @@ -4988,66 +4988,6 @@ AMDGPUInstructionSelector::selectVOP3PModsDOT(MachineOperand &Root) const { return selectVOP3PRetHelper(Root, true); } -// Select neg_lo from the i1 immediate operand. -InstructionSelector::ComplexRendererFns -AMDGPUInstructionSelector::selectVOP3PModsNeg(MachineOperand &Root) const { - // Literal i1 value set in intrinsic, represents SrcMods for the next operand. - // Value is in Imm operand as i1 sign extended to int64_t. - // 1(-1) promotes packed values to signed, 0 treats them as unsigned. - assert((Root.isImm() && (Root.getImm() == -1 || Root.getImm() == 0)) && - "expected i1 value"); - unsigned Mods = SISrcMods::OP_SEL_1; - if (Root.getImm() == -1) - Mods ^= SISrcMods::NEG; - return {{ - [=](MachineInstrBuilder &MIB) { MIB.addImm(Mods); } // src_mods - }}; -} - -// Select both neg_lo and neg_hi from the i1 immediate operand. This is -// specifically for F16/BF16 operands in WMMA instructions, where neg_lo applies -// to matrix's even k elements, and neg_hi applies to matrix's odd k elements. -InstructionSelector::ComplexRendererFns -AMDGPUInstructionSelector::selectVOP3PModsNegs(MachineOperand &Root) const { - // Literal i1 value set in intrinsic, represents SrcMods for the next operand. - // Value is in Imm operand as i1 sign extended to int64_t. - // 1(-1) promotes packed values to signed, 0 treats them as unsigned. - assert((Root.isImm() && (Root.getImm() == -1 || Root.getImm() == 0)) && - "expected i1 value"); - unsigned Mods = SISrcMods::OP_SEL_1; - if (Root.getImm() == -1) - Mods ^= (SISrcMods::NEG | SISrcMods::NEG_HI); - return {{ - [=](MachineInstrBuilder &MIB) { MIB.addImm(Mods); } // src_mods - }}; -} - -// Select neg, abs, or both neg and abs from the i16 immediate operans. -InstructionSelector::ComplexRendererFns -AMDGPUInstructionSelector::selectVOP3PModsNegAbs(MachineOperand &Root) const { - - assert(Root.isImm() && "Modifier for C must be an immediate"); - - unsigned Mods = SISrcMods::OP_SEL_1; - switch (Root.getImm()) { - default: // Any other value will be silently ignored (considered as 0). 
- break; - case 1: - Mods ^= SISrcMods::NEG; - break; - case 2: - Mods ^= SISrcMods::ABS; - break; - case 3: - Mods ^= (SISrcMods::NEG | SISrcMods::ABS); - break; - } - - return {{ - [=](MachineInstrBuilder &MIB) { MIB.addImm(Mods); } // src_mods - }}; -} - InstructionSelector::ComplexRendererFns AMDGPUInstructionSelector::selectWMMAOpSelVOP3PMods( MachineOperand &Root) const { @@ -7102,6 +7042,38 @@ void AMDGPUInstructionSelector::renderRoundMode(MachineInstrBuilder &MIB, MIB.addImm((MI.getOperand(OpIdx).getImm() + 3) % 4); } +void AMDGPUInstructionSelector::renderVOP3PModsNeg(MachineInstrBuilder &MIB, + const MachineInstr &MI, + int OpIdx) const { + unsigned Mods = SISrcMods::OP_SEL_1; + if (MI.getOperand(OpIdx).getImm()) + Mods ^= SISrcMods::NEG; + MIB.addImm((int64_t)Mods); +} + +void AMDGPUInstructionSelector::renderVOP3PModsNegs(MachineInstrBuilder &MIB, + const MachineInstr &MI, + int OpIdx) const { + unsigned Mods = SISrcMods::OP_SEL_1; + if (MI.getOperand(OpIdx).getImm()) + Mods ^= (SISrcMods::NEG | SISrcMods::NEG_HI); + MIB.addImm((int64_t)Mods); +} + +void AMDGPUInstructionSelector::renderVOP3PModsNegAbs(MachineInstrBuilder &MIB, + const MachineInstr &MI, + int OpIdx) const { + unsigned Val = MI.getOperand(OpIdx).getImm(); + unsigned Mods = SISrcMods::OP_SEL_1; // default: none + if (Val == 1) // neg + Mods ^= SISrcMods::NEG; + if (Val == 2) // abs + Mods ^= SISrcMods::ABS; + if (Val == 3) // neg and abs + Mods ^= (SISrcMods::NEG | SISrcMods::ABS); + MIB.addImm((int64_t)Mods); +} + void AMDGPUInstructionSelector::renderPrefetchLoc(MachineInstrBuilder &MIB, const MachineInstr &MI, int OpIdx) const { diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h index 140e753..c9da419 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h +++ b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h @@ -200,13 +200,6 @@ private: selectVOP3PModsDOT(MachineOperand &Root) const; InstructionSelector::ComplexRendererFns - selectVOP3PModsNeg(MachineOperand &Root) const; - InstructionSelector::ComplexRendererFns - selectVOP3PModsNegs(MachineOperand &Root) const; - InstructionSelector::ComplexRendererFns - selectVOP3PModsNegAbs(MachineOperand &Root) const; - - InstructionSelector::ComplexRendererFns selectWMMAOpSelVOP3PMods(MachineOperand &Root) const; InstructionSelector::ComplexRendererFns @@ -419,6 +412,13 @@ private: void renderRoundMode(MachineInstrBuilder &MIB, const MachineInstr &MI, int OpIdx) const; + void renderVOP3PModsNeg(MachineInstrBuilder &MIB, const MachineInstr &MI, + int OpIdx) const; + void renderVOP3PModsNegs(MachineInstrBuilder &MIB, const MachineInstr &MI, + int OpIdx) const; + void renderVOP3PModsNegAbs(MachineInstrBuilder &MIB, const MachineInstr &MI, + int OpIdx) const; + void renderPrefetchLoc(MachineInstrBuilder &MIB, const MachineInstr &MI, int OpIdx) const; diff --git a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp index 1fdf272..a6e4a63 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp @@ -2271,6 +2271,9 @@ Register AMDGPULegalizerInfo::getSegmentAperture( const unsigned ApertureRegNo = (AS == AMDGPUAS::LOCAL_ADDRESS) ? 
AMDGPU::SRC_SHARED_BASE : AMDGPU::SRC_PRIVATE_BASE; + assert((ApertureRegNo != AMDGPU::SRC_PRIVATE_BASE || + !ST.hasGloballyAddressableScratch()) && + "Cannot use src_private_base with globally addressable scratch!"); // FIXME: It would be more natural to emit a COPY here, but then copy // coalescing would kick in and it would think it's okay to use the "HI" // subregister (instead of extracting the HI 32 bits) which is an artificial @@ -2396,11 +2399,30 @@ bool AMDGPULegalizerInfo::legalizeAddrSpaceCast( if (SrcAS == AMDGPUAS::FLAT_ADDRESS && (DestAS == AMDGPUAS::LOCAL_ADDRESS || DestAS == AMDGPUAS::PRIVATE_ADDRESS)) { + auto castFlatToLocalOrPrivate = [&](const DstOp &Dst) -> Register { + if (DestAS == AMDGPUAS::PRIVATE_ADDRESS && + ST.hasGloballyAddressableScratch()) { + // flat -> private with globally addressable scratch: subtract + // src_flat_scratch_base_lo. + const LLT S32 = LLT::scalar(32); + Register SrcLo = B.buildExtract(S32, Src, 0).getReg(0); + Register FlatScratchBaseLo = + B.buildInstr(AMDGPU::S_MOV_B32, {S32}, + {Register(AMDGPU::SRC_FLAT_SCRATCH_BASE_LO)}) + .getReg(0); + MRI.setRegClass(FlatScratchBaseLo, &AMDGPU::SReg_32RegClass); + Register Sub = B.buildSub(S32, SrcLo, FlatScratchBaseLo).getReg(0); + return B.buildIntToPtr(Dst, Sub).getReg(0); + } + + // Extract low 32-bits of the pointer. + return B.buildExtract(Dst, Src, 0).getReg(0); + }; + // For llvm.amdgcn.addrspacecast.nonnull we can always assume non-null, for // G_ADDRSPACE_CAST we need to guess. if (isa<GIntrinsic>(MI) || isKnownNonNull(Src, MRI, TM, SrcAS)) { - // Extract low 32-bits of the pointer. - B.buildExtract(Dst, Src, 0); + castFlatToLocalOrPrivate(Dst); MI.eraseFromParent(); return true; } @@ -2411,7 +2433,7 @@ bool AMDGPULegalizerInfo::legalizeAddrSpaceCast( auto FlatNull = B.buildConstant(SrcTy, 0); // Extract low 32-bits of the pointer. - auto PtrLo32 = B.buildExtract(DstTy, Src, 0); + auto PtrLo32 = castFlatToLocalOrPrivate(DstTy); auto CmpRes = B.buildICmp(CmpInst::ICMP_NE, LLT::scalar(1), Src, FlatNull.getReg(0)); @@ -2425,14 +2447,45 @@ bool AMDGPULegalizerInfo::legalizeAddrSpaceCast( (SrcAS == AMDGPUAS::LOCAL_ADDRESS || SrcAS == AMDGPUAS::PRIVATE_ADDRESS)) { auto castLocalOrPrivateToFlat = [&](const DstOp &Dst) -> Register { - Register ApertureReg = getSegmentAperture(SrcAS, MRI, B); - if (!ApertureReg.isValid()) - return false; - // Coerce the type of the low half of the result so we can use // merge_values. Register SrcAsInt = B.buildPtrToInt(S32, Src).getReg(0); + if (SrcAS == AMDGPUAS::PRIVATE_ADDRESS && + ST.hasGloballyAddressableScratch()) { + // For wave32: Addr = (TID[4:0] << 52) + FLAT_SCRATCH_BASE + privateAddr + // For wave64: Addr = (TID[5:0] << 51) + FLAT_SCRATCH_BASE + privateAddr + Register AllOnes = B.buildConstant(S32, -1).getReg(0); + Register ThreadID = B.buildConstant(S32, 0).getReg(0); + ThreadID = B.buildIntrinsic(Intrinsic::amdgcn_mbcnt_lo, {S32}) + .addUse(AllOnes) + .addUse(ThreadID) + .getReg(0); + if (ST.isWave64()) { + ThreadID = B.buildIntrinsic(Intrinsic::amdgcn_mbcnt_hi, {S32}) + .addUse(AllOnes) + .addUse(ThreadID) + .getReg(0); + } + Register ShAmt = + B.buildConstant(S32, 57 - 32 - ST.getWavefrontSizeLog2()).getReg(0); + Register SrcHi = B.buildShl(S32, ThreadID, ShAmt).getReg(0); + Register CvtPtr = + B.buildMergeLikeInstr(DstTy, {SrcAsInt, SrcHi}).getReg(0); + // Accessing src_flat_scratch_base_lo as a 64-bit operand gives the full + // 64-bit hi:lo value. 
+ Register FlatScratchBase = + B.buildInstr(AMDGPU::S_MOV_B64, {S64}, + {Register(AMDGPU::SRC_FLAT_SCRATCH_BASE)}) + .getReg(0); + MRI.setRegClass(FlatScratchBase, &AMDGPU::SReg_64RegClass); + return B.buildPtrAdd(Dst, CvtPtr, FlatScratchBase).getReg(0); + } + + Register ApertureReg = getSegmentAperture(SrcAS, MRI, B); + if (!ApertureReg.isValid()) + return false; + // TODO: Should we allow mismatched types but matching sizes in merges to // avoid the ptrtoint? return B.buildMergeLikeInstr(Dst, {SrcAsInt, ApertureReg}).getReg(0); @@ -5788,11 +5841,25 @@ bool AMDGPULegalizerInfo::legalizeIsAddrSpace(MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B, unsigned AddrSpace) const { - Register ApertureReg = getSegmentAperture(AddrSpace, MRI, B); - auto Unmerge = B.buildUnmerge(LLT::scalar(32), MI.getOperand(2).getReg()); + const LLT S32 = LLT::scalar(32); + auto Unmerge = B.buildUnmerge(S32, MI.getOperand(2).getReg()); Register Hi32 = Unmerge.getReg(1); - B.buildICmp(ICmpInst::ICMP_EQ, MI.getOperand(0), Hi32, ApertureReg); + if (AddrSpace == AMDGPUAS::PRIVATE_ADDRESS && + ST.hasGloballyAddressableScratch()) { + Register FlatScratchBaseHi = + B.buildInstr(AMDGPU::S_MOV_B32, {S32}, + {Register(AMDGPU::SRC_FLAT_SCRATCH_BASE_HI)}) + .getReg(0); + MRI.setRegClass(FlatScratchBaseHi, &AMDGPU::SReg_32RegClass); + // Test bits 63..58 against the aperture address. + Register XOR = B.buildXor(S32, Hi32, FlatScratchBaseHi).getReg(0); + B.buildICmp(ICmpInst::ICMP_ULT, MI.getOperand(0), XOR, + B.buildConstant(S32, 1u << 26)); + } else { + Register ApertureReg = getSegmentAperture(AddrSpace, MRI, B); + B.buildICmp(ICmpInst::ICMP_EQ, MI.getOperand(0), Hi32, ApertureReg); + } MI.eraseFromParent(); return true; } diff --git a/llvm/lib/Target/AMDGPU/AMDGPULowerModuleLDSPass.cpp b/llvm/lib/Target/AMDGPU/AMDGPULowerModuleLDSPass.cpp index d443f4e..2d8f259 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPULowerModuleLDSPass.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPULowerModuleLDSPass.cpp @@ -236,7 +236,7 @@ cl::opt<LoweringKind> LoweringKindLoc( "Lower via mixture of above strategies"))); template <typename T> std::vector<T> sortByName(std::vector<T> &&V) { - llvm::sort(V.begin(), V.end(), [](const auto *L, const auto *R) { + llvm::sort(V, [](const auto *L, const auto *R) { return L->getName() < R->getName(); }); return {std::move(V)}; diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp index 5aa0ebf..868b1a2 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp @@ -4603,6 +4603,42 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { case Intrinsic::amdgcn_cvt_scale_pk8_f32_fp8: case Intrinsic::amdgcn_cvt_scale_pk8_f32_bf8: case Intrinsic::amdgcn_cvt_scale_pk8_f32_fp4: + case Intrinsic::amdgcn_cvt_scale_pk16_f16_fp6: + case Intrinsic::amdgcn_cvt_scale_pk16_bf16_fp6: + case Intrinsic::amdgcn_cvt_scale_pk16_f16_bf6: + case Intrinsic::amdgcn_cvt_scale_pk16_bf16_bf6: + case Intrinsic::amdgcn_cvt_scale_pk16_f32_fp6: + case Intrinsic::amdgcn_cvt_scale_pk16_f32_bf6: + case Intrinsic::amdgcn_cvt_scalef32_pk8_fp8_bf16: + case Intrinsic::amdgcn_cvt_scalef32_pk8_bf8_bf16: + case Intrinsic::amdgcn_cvt_scalef32_pk8_fp8_f16: + case Intrinsic::amdgcn_cvt_scalef32_pk8_bf8_f16: + case Intrinsic::amdgcn_cvt_scalef32_pk8_fp8_f32: + case Intrinsic::amdgcn_cvt_scalef32_pk8_bf8_f32: + case Intrinsic::amdgcn_cvt_scalef32_pk8_fp4_f32: + case Intrinsic::amdgcn_cvt_scalef32_pk8_fp4_f16: + case 
Intrinsic::amdgcn_cvt_scalef32_pk8_fp4_bf16: + case Intrinsic::amdgcn_cvt_scalef32_pk16_fp6_f32: + case Intrinsic::amdgcn_cvt_scalef32_pk16_bf6_f32: + case Intrinsic::amdgcn_cvt_scalef32_pk16_fp6_f16: + case Intrinsic::amdgcn_cvt_scalef32_pk16_bf6_f16: + case Intrinsic::amdgcn_cvt_scalef32_pk16_fp6_bf16: + case Intrinsic::amdgcn_cvt_scalef32_pk16_bf6_bf16: + case Intrinsic::amdgcn_cvt_scalef32_sr_pk8_fp8_bf16: + case Intrinsic::amdgcn_cvt_scalef32_sr_pk8_bf8_bf16: + case Intrinsic::amdgcn_cvt_scalef32_sr_pk8_fp8_f16: + case Intrinsic::amdgcn_cvt_scalef32_sr_pk8_bf8_f16: + case Intrinsic::amdgcn_cvt_scalef32_sr_pk8_fp8_f32: + case Intrinsic::amdgcn_cvt_scalef32_sr_pk8_bf8_f32: + case Intrinsic::amdgcn_cvt_scalef32_sr_pk8_fp4_f32: + case Intrinsic::amdgcn_cvt_scalef32_sr_pk8_fp4_f16: + case Intrinsic::amdgcn_cvt_scalef32_sr_pk8_fp4_bf16: + case Intrinsic::amdgcn_cvt_scalef32_sr_pk16_fp6_f32: + case Intrinsic::amdgcn_cvt_scalef32_sr_pk16_bf6_f32: + case Intrinsic::amdgcn_cvt_scalef32_sr_pk16_fp6_f16: + case Intrinsic::amdgcn_cvt_scalef32_sr_pk16_bf6_f16: + case Intrinsic::amdgcn_cvt_scalef32_sr_pk16_fp6_bf16: + case Intrinsic::amdgcn_cvt_scalef32_sr_pk16_bf6_bf16: case Intrinsic::amdgcn_sat_pk4_i4_i8: case Intrinsic::amdgcn_sat_pk4_u4_u8: case Intrinsic::amdgcn_fmed3: @@ -4762,7 +4798,11 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { case Intrinsic::amdgcn_wmma_f32_16x16x128_bf8_bf8: case Intrinsic::amdgcn_wmma_i32_16x16x64_iu8: case Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4: + case Intrinsic::amdgcn_wmma_scale_f32_16x16x128_f8f6f4: + case Intrinsic::amdgcn_wmma_scale16_f32_16x16x128_f8f6f4: case Intrinsic::amdgcn_wmma_f32_32x16x128_f4: + case Intrinsic::amdgcn_wmma_scale_f32_32x16x128_f4: + case Intrinsic::amdgcn_wmma_scale16_f32_32x16x128_f4: case Intrinsic::amdgcn_swmmac_f16_16x16x64_f16: case Intrinsic::amdgcn_swmmac_bf16_16x16x64_bf16: case Intrinsic::amdgcn_swmmac_f32_16x16x64_bf16: @@ -4777,6 +4817,9 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { case Intrinsic::amdgcn_swmmac_f16_16x16x128_bf8_fp8: case Intrinsic::amdgcn_swmmac_f16_16x16x128_bf8_bf8: case Intrinsic::amdgcn_swmmac_i32_16x16x128_iu8: + case Intrinsic::amdgcn_perm_pk16_b4_u4: + case Intrinsic::amdgcn_perm_pk16_b6_u4: + case Intrinsic::amdgcn_perm_pk16_b8_u4: return getDefaultMappingVOP(MI); case Intrinsic::amdgcn_log: case Intrinsic::amdgcn_exp2: diff --git a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp index a83caa0..ff8efd2 100644 --- a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp +++ b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp @@ -178,6 +178,10 @@ public: ImmTyBitOp3, ImmTyMatrixAFMT, ImmTyMatrixBFMT, + ImmTyMatrixAScale, + ImmTyMatrixBScale, + ImmTyMatrixAScaleFmt, + ImmTyMatrixBScaleFmt, ImmTyMatrixAReuse, ImmTyMatrixBReuse, ImmTyScaleSel, @@ -428,6 +432,10 @@ public: bool isIndexKey32bit() const { return isImmTy(ImmTyIndexKey32bit); } bool isMatrixAFMT() const { return isImmTy(ImmTyMatrixAFMT); } bool isMatrixBFMT() const { return isImmTy(ImmTyMatrixBFMT); } + bool isMatrixAScale() const { return isImmTy(ImmTyMatrixAScale); } + bool isMatrixBScale() const { return isImmTy(ImmTyMatrixBScale); } + bool isMatrixAScaleFmt() const { return isImmTy(ImmTyMatrixAScaleFmt); } + bool isMatrixBScaleFmt() const { return isImmTy(ImmTyMatrixBScaleFmt); } bool isMatrixAReuse() const { return isImmTy(ImmTyMatrixAReuse); } bool isMatrixBReuse() const { return isImmTy(ImmTyMatrixBReuse); } bool 
isTFE() const { return isImmTy(ImmTyTFE); } @@ -1183,6 +1191,10 @@ public: case ImmTyBitOp3: OS << "BitOp3"; break; case ImmTyMatrixAFMT: OS << "ImmTyMatrixAFMT"; break; case ImmTyMatrixBFMT: OS << "ImmTyMatrixBFMT"; break; + case ImmTyMatrixAScale: OS << "ImmTyMatrixAScale"; break; + case ImmTyMatrixBScale: OS << "ImmTyMatrixBScale"; break; + case ImmTyMatrixAScaleFmt: OS << "ImmTyMatrixAScaleFmt"; break; + case ImmTyMatrixBScaleFmt: OS << "ImmTyMatrixBScaleFmt"; break; case ImmTyMatrixAReuse: OS << "ImmTyMatrixAReuse"; break; case ImmTyMatrixBReuse: OS << "ImmTyMatrixBReuse"; break; case ImmTyScaleSel: OS << "ScaleSel" ; break; @@ -1608,6 +1620,10 @@ public: return getFeatureBits()[AMDGPU::FeaturePartialNSAEncoding]; } + bool hasGloballyAddressableScratch() const { + return getFeatureBits()[AMDGPU::FeatureGloballyAddressableScratch]; + } + unsigned getNSAMaxSize(bool HasSampler = false) const { return AMDGPU::getNSAMaxSize(getSTI(), HasSampler); } @@ -1728,6 +1744,14 @@ public: AMDGPUOperand::ImmTy Type); ParseStatus parseMatrixAFMT(OperandVector &Operands); ParseStatus parseMatrixBFMT(OperandVector &Operands); + ParseStatus tryParseMatrixScale(OperandVector &Operands, StringRef Name, + AMDGPUOperand::ImmTy Type); + ParseStatus parseMatrixAScale(OperandVector &Operands); + ParseStatus parseMatrixBScale(OperandVector &Operands); + ParseStatus tryParseMatrixScaleFmt(OperandVector &Operands, StringRef Name, + AMDGPUOperand::ImmTy Type); + ParseStatus parseMatrixAScaleFmt(OperandVector &Operands); + ParseStatus parseMatrixBScaleFmt(OperandVector &Operands); ParseStatus parseDfmtNfmt(int64_t &Format); ParseStatus parseUfmt(int64_t &Format); @@ -2739,46 +2763,48 @@ static int getRegClass(RegisterKind Is, unsigned RegWidth) { static MCRegister getSpecialRegForName(StringRef RegName) { return StringSwitch<unsigned>(RegName) - .Case("exec", AMDGPU::EXEC) - .Case("vcc", AMDGPU::VCC) - .Case("flat_scratch", AMDGPU::FLAT_SCR) - .Case("xnack_mask", AMDGPU::XNACK_MASK) - .Case("shared_base", AMDGPU::SRC_SHARED_BASE) - .Case("src_shared_base", AMDGPU::SRC_SHARED_BASE) - .Case("shared_limit", AMDGPU::SRC_SHARED_LIMIT) - .Case("src_shared_limit", AMDGPU::SRC_SHARED_LIMIT) - .Case("private_base", AMDGPU::SRC_PRIVATE_BASE) - .Case("src_private_base", AMDGPU::SRC_PRIVATE_BASE) - .Case("private_limit", AMDGPU::SRC_PRIVATE_LIMIT) - .Case("src_private_limit", AMDGPU::SRC_PRIVATE_LIMIT) - .Case("pops_exiting_wave_id", AMDGPU::SRC_POPS_EXITING_WAVE_ID) - .Case("src_pops_exiting_wave_id", AMDGPU::SRC_POPS_EXITING_WAVE_ID) - .Case("lds_direct", AMDGPU::LDS_DIRECT) - .Case("src_lds_direct", AMDGPU::LDS_DIRECT) - .Case("m0", AMDGPU::M0) - .Case("vccz", AMDGPU::SRC_VCCZ) - .Case("src_vccz", AMDGPU::SRC_VCCZ) - .Case("execz", AMDGPU::SRC_EXECZ) - .Case("src_execz", AMDGPU::SRC_EXECZ) - .Case("scc", AMDGPU::SRC_SCC) - .Case("src_scc", AMDGPU::SRC_SCC) - .Case("tba", AMDGPU::TBA) - .Case("tma", AMDGPU::TMA) - .Case("flat_scratch_lo", AMDGPU::FLAT_SCR_LO) - .Case("flat_scratch_hi", AMDGPU::FLAT_SCR_HI) - .Case("xnack_mask_lo", AMDGPU::XNACK_MASK_LO) - .Case("xnack_mask_hi", AMDGPU::XNACK_MASK_HI) - .Case("vcc_lo", AMDGPU::VCC_LO) - .Case("vcc_hi", AMDGPU::VCC_HI) - .Case("exec_lo", AMDGPU::EXEC_LO) - .Case("exec_hi", AMDGPU::EXEC_HI) - .Case("tma_lo", AMDGPU::TMA_LO) - .Case("tma_hi", AMDGPU::TMA_HI) - .Case("tba_lo", AMDGPU::TBA_LO) - .Case("tba_hi", AMDGPU::TBA_HI) - .Case("pc", AMDGPU::PC_REG) - .Case("null", AMDGPU::SGPR_NULL) - .Default(AMDGPU::NoRegister); + .Case("exec", AMDGPU::EXEC) + .Case("vcc", AMDGPU::VCC) + 
.Case("flat_scratch", AMDGPU::FLAT_SCR) + .Case("xnack_mask", AMDGPU::XNACK_MASK) + .Case("shared_base", AMDGPU::SRC_SHARED_BASE) + .Case("src_shared_base", AMDGPU::SRC_SHARED_BASE) + .Case("shared_limit", AMDGPU::SRC_SHARED_LIMIT) + .Case("src_shared_limit", AMDGPU::SRC_SHARED_LIMIT) + .Case("private_base", AMDGPU::SRC_PRIVATE_BASE) + .Case("src_private_base", AMDGPU::SRC_PRIVATE_BASE) + .Case("private_limit", AMDGPU::SRC_PRIVATE_LIMIT) + .Case("src_private_limit", AMDGPU::SRC_PRIVATE_LIMIT) + .Case("src_flat_scratch_base_lo", AMDGPU::SRC_FLAT_SCRATCH_BASE_LO) + .Case("src_flat_scratch_base_hi", AMDGPU::SRC_FLAT_SCRATCH_BASE_HI) + .Case("pops_exiting_wave_id", AMDGPU::SRC_POPS_EXITING_WAVE_ID) + .Case("src_pops_exiting_wave_id", AMDGPU::SRC_POPS_EXITING_WAVE_ID) + .Case("lds_direct", AMDGPU::LDS_DIRECT) + .Case("src_lds_direct", AMDGPU::LDS_DIRECT) + .Case("m0", AMDGPU::M0) + .Case("vccz", AMDGPU::SRC_VCCZ) + .Case("src_vccz", AMDGPU::SRC_VCCZ) + .Case("execz", AMDGPU::SRC_EXECZ) + .Case("src_execz", AMDGPU::SRC_EXECZ) + .Case("scc", AMDGPU::SRC_SCC) + .Case("src_scc", AMDGPU::SRC_SCC) + .Case("tba", AMDGPU::TBA) + .Case("tma", AMDGPU::TMA) + .Case("flat_scratch_lo", AMDGPU::FLAT_SCR_LO) + .Case("flat_scratch_hi", AMDGPU::FLAT_SCR_HI) + .Case("xnack_mask_lo", AMDGPU::XNACK_MASK_LO) + .Case("xnack_mask_hi", AMDGPU::XNACK_MASK_HI) + .Case("vcc_lo", AMDGPU::VCC_LO) + .Case("vcc_hi", AMDGPU::VCC_HI) + .Case("exec_lo", AMDGPU::EXEC_LO) + .Case("exec_hi", AMDGPU::EXEC_HI) + .Case("tma_lo", AMDGPU::TMA_LO) + .Case("tma_hi", AMDGPU::TMA_HI) + .Case("tba_lo", AMDGPU::TBA_LO) + .Case("tba_hi", AMDGPU::TBA_HI) + .Case("pc", AMDGPU::PC_REG) + .Case("null", AMDGPU::SGPR_NULL) + .Default(AMDGPU::NoRegister); } bool AMDGPUAsmParser::ParseRegister(MCRegister &RegNo, SMLoc &StartLoc, @@ -6724,6 +6750,9 @@ bool AMDGPUAsmParser::subtargetHasRegister(const MCRegisterInfo &MRI, case SRC_PRIVATE_LIMIT_LO: case SRC_PRIVATE_LIMIT: return isGFX9Plus(); + case SRC_FLAT_SCRATCH_BASE_LO: + case SRC_FLAT_SCRATCH_BASE_HI: + return hasGloballyAddressableScratch(); case SRC_POPS_EXITING_WAVE_ID: return isGFX9Plus() && !isGFX11Plus(); case TBA: @@ -7356,6 +7385,42 @@ ParseStatus AMDGPUAsmParser::parseMatrixBFMT(OperandVector &Operands) { AMDGPUOperand::ImmTyMatrixBFMT); } +ParseStatus AMDGPUAsmParser::tryParseMatrixScale(OperandVector &Operands, + StringRef Name, + AMDGPUOperand::ImmTy Type) { + return parseStringOrIntWithPrefix( + Operands, Name, {"MATRIX_SCALE_ROW0", "MATRIX_SCALE_ROW1"}, Type); +} + +ParseStatus AMDGPUAsmParser::parseMatrixAScale(OperandVector &Operands) { + return tryParseMatrixScale(Operands, "matrix_a_scale", + AMDGPUOperand::ImmTyMatrixAScale); +} + +ParseStatus AMDGPUAsmParser::parseMatrixBScale(OperandVector &Operands) { + return tryParseMatrixScale(Operands, "matrix_b_scale", + AMDGPUOperand::ImmTyMatrixBScale); +} + +ParseStatus AMDGPUAsmParser::tryParseMatrixScaleFmt(OperandVector &Operands, + StringRef Name, + AMDGPUOperand::ImmTy Type) { + return parseStringOrIntWithPrefix( + Operands, Name, + {"MATRIX_SCALE_FMT_E8", "MATRIX_SCALE_FMT_E5M3", "MATRIX_SCALE_FMT_E4M3"}, + Type); +} + +ParseStatus AMDGPUAsmParser::parseMatrixAScaleFmt(OperandVector &Operands) { + return tryParseMatrixScaleFmt(Operands, "matrix_a_scale_fmt", + AMDGPUOperand::ImmTyMatrixAScaleFmt); +} + +ParseStatus AMDGPUAsmParser::parseMatrixBScaleFmt(OperandVector &Operands) { + return tryParseMatrixScaleFmt(Operands, "matrix_b_scale_fmt", + AMDGPUOperand::ImmTyMatrixBScaleFmt); +} + // dfmt and nfmt (in a tbuffer instruction) 
are parsed as one to allow their // values to live in a joint format operand in the MCInst encoding. ParseStatus AMDGPUAsmParser::parseDfmtNfmt(int64_t &Format) { @@ -9489,6 +9554,34 @@ void AMDGPUAsmParser::cvtVOP3P(MCInst &Inst, const OperandVector &Operands, AMDGPUOperand::ImmTyMatrixBFMT, 0); } + int MatrixAScaleIdx = + AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_a_scale); + if (MatrixAScaleIdx != -1) { + addOptionalImmOperand(Inst, Operands, OptIdx, + AMDGPUOperand::ImmTyMatrixAScale, 0); + } + + int MatrixBScaleIdx = + AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_b_scale); + if (MatrixBScaleIdx != -1) { + addOptionalImmOperand(Inst, Operands, OptIdx, + AMDGPUOperand::ImmTyMatrixBScale, 0); + } + + int MatrixAScaleFmtIdx = + AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_a_scale_fmt); + if (MatrixAScaleFmtIdx != -1) { + addOptionalImmOperand(Inst, Operands, OptIdx, + AMDGPUOperand::ImmTyMatrixAScaleFmt, 0); + } + + int MatrixBScaleFmtIdx = + AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_b_scale_fmt); + if (MatrixBScaleFmtIdx != -1) { + addOptionalImmOperand(Inst, Operands, OptIdx, + AMDGPUOperand::ImmTyMatrixBScaleFmt, 0); + } + if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::matrix_a_reuse)) addOptionalImmOperand(Inst, Operands, OptIdx, AMDGPUOperand::ImmTyMatrixAReuse, 0); diff --git a/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp b/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp index ffe6b06..fb7d634 100644 --- a/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp +++ b/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp @@ -598,6 +598,13 @@ DecodeStatus AMDGPUDisassembler::getInstruction(MCInst &MI, uint64_t &Size, // Try to decode DPP and SDWA first to solve conflict with VOP1 and VOP2 // encodings + if (isGFX1250() && Bytes.size() >= 16) { + DecoderUInt128 DecW = eat16Bytes(Bytes); + if (tryDecodeInst(DecoderTableGFX1250128, MI, DecW, Address, CS)) + break; + Bytes = Bytes_.slice(0, MaxInstBytesNum); + } + if (isGFX11Plus() && Bytes.size() >= 12 ) { DecoderUInt128 DecW = eat12Bytes(Bytes); @@ -1907,6 +1914,8 @@ MCOperand AMDGPUDisassembler::decodeSpecialReg32(unsigned Val) const { return isGFX11Plus() ? 
createRegOperand(M0) : createRegOperand(SGPR_NULL); case 126: return createRegOperand(EXEC_LO); case 127: return createRegOperand(EXEC_HI); + case 230: return createRegOperand(SRC_FLAT_SCRATCH_BASE_LO); + case 231: return createRegOperand(SRC_FLAT_SCRATCH_BASE_HI); case 235: return createRegOperand(SRC_SHARED_BASE_LO); case 236: return createRegOperand(SRC_SHARED_LIMIT_LO); case 237: return createRegOperand(SRC_PRIVATE_BASE_LO); @@ -1940,6 +1949,7 @@ MCOperand AMDGPUDisassembler::decodeSpecialReg64(unsigned Val) const { return createRegOperand(SGPR_NULL); break; case 126: return createRegOperand(EXEC); + case 230: return createRegOperand(SRC_FLAT_SCRATCH_BASE_LO); case 235: return createRegOperand(SRC_SHARED_BASE); case 236: return createRegOperand(SRC_SHARED_LIMIT); case 237: return createRegOperand(SRC_PRIVATE_BASE); diff --git a/llvm/lib/Target/AMDGPU/GCNSubtarget.h b/llvm/lib/Target/AMDGPU/GCNSubtarget.h index 6fe3abc..5530886 100644 --- a/llvm/lib/Target/AMDGPU/GCNSubtarget.h +++ b/llvm/lib/Target/AMDGPU/GCNSubtarget.h @@ -236,6 +236,7 @@ protected: bool Has64BitLiterals = false; bool HasBitOp3Insts = false; bool HasTanhInsts = false; + bool HasTensorCvtLutInsts = false; bool HasTransposeLoadF4F6Insts = false; bool HasPrngInst = false; bool HasBVHDualAndBVH8Insts = false; @@ -280,6 +281,7 @@ protected: bool RequiresCOV6 = false; bool UseBlockVGPROpsForCSR = false; + bool HasGloballyAddressableScratch = false; // Dummy feature to use for assembler in tablegen. bool FeatureDisable = false; @@ -1324,6 +1326,10 @@ public: bool useVGPRBlockOpsForCSR() const { return UseBlockVGPROpsForCSR; } + bool hasGloballyAddressableScratch() const { + return HasGloballyAddressableScratch; + } + bool hasVALUMaskWriteHazard() const { return getGeneration() == GFX11; } bool hasVALUReadSGPRHazard() const { return GFX12Insts && !GFX1250Insts; } @@ -1411,6 +1417,8 @@ public: bool hasTanhInsts() const { return HasTanhInsts; } + bool hasTensorCvtLutInsts() const { return HasTensorCvtLutInsts; } + bool hasAddPC64Inst() const { return GFX1250Insts; } bool hasMinimum3Maximum3PKF16() const { diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUAsmBackend.cpp b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUAsmBackend.cpp index 86d56855..4e4660c 100644 --- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUAsmBackend.cpp +++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUAsmBackend.cpp @@ -33,8 +33,7 @@ public: AMDGPUAsmBackend(const Target &T) : MCAsmBackend(llvm::endianness::little) {} void applyFixup(const MCFragment &, const MCFixup &, const MCValue &Target, - MutableArrayRef<char> Data, uint64_t Value, - bool IsResolved) override; + uint8_t *Data, uint64_t Value, bool IsResolved) override; bool fixupNeedsRelaxation(const MCFixup &Fixup, uint64_t Value) const override; @@ -129,9 +128,8 @@ static uint64_t adjustFixupValue(const MCFixup &Fixup, uint64_t Value, } void AMDGPUAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, - const MCValue &Target, - MutableArrayRef<char> Data, uint64_t Value, - bool IsResolved) { + const MCValue &Target, uint8_t *Data, + uint64_t Value, bool IsResolved) { if (Target.getSpecifier()) IsResolved = false; maybeAddReloc(F, Fixup, Target, Value, IsResolved); @@ -148,13 +146,13 @@ void AMDGPUAsmBackend::applyFixup(const MCFragment &F, const MCFixup &Fixup, Value <<= Info.TargetOffset; unsigned NumBytes = getFixupKindNumBytes(Fixup.getKind()); - uint32_t Offset = Fixup.getOffset(); - assert(Offset + NumBytes <= F.getSize() && "Invalid fixup offset!"); + assert(Fixup.getOffset() + 
NumBytes <= F.getSize() && + "Invalid fixup offset!"); // For each byte of the fragment that the fixup touches, mask in the bits from // the fixup value. for (unsigned i = 0; i != NumBytes; ++i) - Data[Offset + i] |= static_cast<uint8_t>((Value >> (i * 8)) & 0xff); + Data[i] |= static_cast<uint8_t>((Value >> (i * 8)) & 0xff); } std::optional<MCFixupKind> diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp index 42c4d8b..ee8683a 100644 --- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp +++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp @@ -1393,6 +1393,75 @@ void AMDGPUInstPrinter::printMatrixBFMT(const MCInst *MI, unsigned OpNo, printMatrixFMT(MI, OpNo, STI, O, 'b'); } +void AMDGPUInstPrinter::printMatrixScale(const MCInst *MI, unsigned OpNo, + const MCSubtargetInfo &STI, + raw_ostream &O, char AorB) { + auto Imm = MI->getOperand(OpNo).getImm() & 1; + if (Imm == 0) + return; + + O << " matrix_" << AorB << "_scale:"; + switch (Imm) { + default: + O << Imm; + break; + case WMMA::MatrixScale::MATRIX_SCALE_ROW0: + O << "MATRIX_SCALE_ROW0"; + break; + case WMMA::MatrixScale::MATRIX_SCALE_ROW1: + O << "MATRIX_SCALE_ROW1"; + break; + } +} + +void AMDGPUInstPrinter::printMatrixAScale(const MCInst *MI, unsigned OpNo, + const MCSubtargetInfo &STI, + raw_ostream &O) { + printMatrixScale(MI, OpNo, STI, O, 'a'); +} + +void AMDGPUInstPrinter::printMatrixBScale(const MCInst *MI, unsigned OpNo, + const MCSubtargetInfo &STI, + raw_ostream &O) { + printMatrixScale(MI, OpNo, STI, O, 'b'); +} + +void AMDGPUInstPrinter::printMatrixScaleFmt(const MCInst *MI, unsigned OpNo, + const MCSubtargetInfo &STI, + raw_ostream &O, char AorB) { + auto Imm = MI->getOperand(OpNo).getImm() & 3; + if (Imm == 0) + return; + + O << " matrix_" << AorB << "_scale_fmt:"; + switch (Imm) { + default: + O << Imm; + break; + case WMMA::MatrixScaleFmt::MATRIX_SCALE_FMT_E8: + O << "MATRIX_SCALE_FMT_E8"; + break; + case WMMA::MatrixScaleFmt::MATRIX_SCALE_FMT_E5M3: + O << "MATRIX_SCALE_FMT_E5M3"; + break; + case WMMA::MatrixScaleFmt::MATRIX_SCALE_FMT_E4M3: + O << "MATRIX_SCALE_FMT_E4M3"; + break; + } +} + +void AMDGPUInstPrinter::printMatrixAScaleFmt(const MCInst *MI, unsigned OpNo, + const MCSubtargetInfo &STI, + raw_ostream &O) { + printMatrixScaleFmt(MI, OpNo, STI, O, 'a'); +} + +void AMDGPUInstPrinter::printMatrixBScaleFmt(const MCInst *MI, unsigned OpNo, + const MCSubtargetInfo &STI, + raw_ostream &O) { + printMatrixScaleFmt(MI, OpNo, STI, O, 'b'); +} + void AMDGPUInstPrinter::printInterpSlot(const MCInst *MI, unsigned OpNum, const MCSubtargetInfo &STI, raw_ostream &O) { diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h index f6739b14..be32061c 100644 --- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h +++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h @@ -140,6 +140,19 @@ private: const MCSubtargetInfo &STI, raw_ostream &O); void printMatrixBFMT(const MCInst *MI, unsigned OpNo, const MCSubtargetInfo &STI, raw_ostream &O); + void printMatrixScale(const MCInst *MI, unsigned OpNo, + const MCSubtargetInfo &STI, raw_ostream &O, char AorB); + void printMatrixAScale(const MCInst *MI, unsigned OpNo, + const MCSubtargetInfo &STI, raw_ostream &O); + void printMatrixBScale(const MCInst *MI, unsigned OpNo, + const MCSubtargetInfo &STI, raw_ostream &O); + void printMatrixScaleFmt(const MCInst *MI, unsigned OpNo, + const MCSubtargetInfo &STI, 
raw_ostream &O, + char AorB); + void printMatrixAScaleFmt(const MCInst *MI, unsigned OpNo, + const MCSubtargetInfo &STI, raw_ostream &O); + void printMatrixBScaleFmt(const MCInst *MI, unsigned OpNo, + const MCSubtargetInfo &STI, raw_ostream &O); void printInterpSlot(const MCInst *MI, unsigned OpNo, const MCSubtargetInfo &STI, raw_ostream &O); void printInterpAttr(const MCInst *MI, unsigned OpNo, diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCAsmInfo.cpp b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCAsmInfo.cpp index ffdac8b..fa0c95f 100644 --- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCAsmInfo.cpp +++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCAsmInfo.cpp @@ -75,8 +75,9 @@ unsigned AMDGPUMCAsmInfo::getMaxInstLength(const MCSubtargetInfo *STI) const { if (STI->hasFeature(AMDGPU::FeatureNSAEncoding)) return 20; - // VOP3PX encoding. - if (STI->hasFeature(AMDGPU::FeatureGFX950Insts)) + // VOP3PX/VOP3PX2 encoding. + if (STI->hasFeature(AMDGPU::FeatureGFX950Insts) || + STI->hasFeature(AMDGPU::FeatureGFX1250Insts)) return 16; // 64-bit instruction with 32-bit literal. diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUTargetStreamer.cpp b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUTargetStreamer.cpp index 43ca548..68302f0 100644 --- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUTargetStreamer.cpp +++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUTargetStreamer.cpp @@ -872,14 +872,14 @@ void AMDGPUTargetELFStreamer::EmitAMDKernelCodeT(AMDGPUMCKernelCodeT &Header) { void AMDGPUTargetELFStreamer::EmitAMDGPUSymbolType(StringRef SymbolName, unsigned Type) { - MCSymbolELF *Symbol = cast<MCSymbolELF>( + auto *Symbol = static_cast<MCSymbolELF *>( getStreamer().getContext().getOrCreateSymbol(SymbolName)); Symbol->setType(Type); } void AMDGPUTargetELFStreamer::emitAMDGPULDS(MCSymbol *Symbol, unsigned Size, Align Alignment) { - MCSymbolELF *SymbolELF = cast<MCSymbolELF>(Symbol); + auto *SymbolELF = static_cast<MCSymbolELF *>(Symbol); SymbolELF->setType(ELF::STT_OBJECT); if (!SymbolELF->isBindingSet()) @@ -974,9 +974,9 @@ void AMDGPUTargetELFStreamer::EmitAmdhsaKernelDescriptor( auto &Streamer = getStreamer(); auto &Context = Streamer.getContext(); - MCSymbolELF *KernelCodeSymbol = cast<MCSymbolELF>( - Context.getOrCreateSymbol(Twine(KernelName))); - MCSymbolELF *KernelDescriptorSymbol = cast<MCSymbolELF>( + auto *KernelCodeSymbol = + static_cast<MCSymbolELF *>(Context.getOrCreateSymbol(Twine(KernelName))); + auto *KernelDescriptorSymbol = static_cast<MCSymbolELF *>( Context.getOrCreateSymbol(Twine(KernelName) + Twine(".kd"))); // Copy kernel descriptor symbol's binding, other and visibility from the diff --git a/llvm/lib/Target/AMDGPU/SIDefines.h b/llvm/lib/Target/AMDGPU/SIDefines.h index c564145..deadb7a 100644 --- a/llvm/lib/Target/AMDGPU/SIDefines.h +++ b/llvm/lib/Target/AMDGPU/SIDefines.h @@ -1018,6 +1018,17 @@ enum MatrixFMT : unsigned { MATRIX_FMT_BF6 = 3, MATRIX_FMT_FP4 = 4 }; + +enum MatrixScale : unsigned { + MATRIX_SCALE_ROW0 = 0, + MATRIX_SCALE_ROW1 = 1, +}; + +enum MatrixScaleFmt : unsigned { + MATRIX_SCALE_FMT_E8 = 0, + MATRIX_SCALE_FMT_E5M3 = 1, + MATRIX_SCALE_FMT_E4M3 = 2 +}; } // namespace WMMA namespace VOP3PEncoding { diff --git a/llvm/lib/Target/AMDGPU/SIFoldOperands.cpp b/llvm/lib/Target/AMDGPU/SIFoldOperands.cpp index e934152..0c653b1 100644 --- a/llvm/lib/Target/AMDGPU/SIFoldOperands.cpp +++ b/llvm/lib/Target/AMDGPU/SIFoldOperands.cpp @@ -1169,11 +1169,18 @@ void SIFoldOperandsImpl::foldOperand( // Grab the use operands first SmallVector<MachineOperand *, 4> 
UsesToProcess( llvm::make_pointer_range(MRI->use_nodbg_operands(RegSeqDstReg))); - for (auto *RSUse : UsesToProcess) { + for (unsigned I = 0; I != UsesToProcess.size(); ++I) { + MachineOperand *RSUse = UsesToProcess[I]; MachineInstr *RSUseMI = RSUse->getParent(); unsigned OpNo = RSUseMI->getOperandNo(RSUse); if (SplatRC) { + if (RSUseMI->isCopy()) { + Register DstReg = RSUseMI->getOperand(0).getReg(); + append_range(UsesToProcess, + make_pointer_range(MRI->use_nodbg_operands(DstReg))); + continue; + } if (tryFoldRegSeqSplat(RSUseMI, OpNo, SplatVal, SplatRC)) { FoldableDef SplatDef(SplatVal, SplatRC); appendFoldCandidate(FoldList, RSUseMI, OpNo, SplatDef); diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp index 4d67e4a..63826b7 100644 --- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp +++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp @@ -2098,10 +2098,17 @@ bool SITargetLowering::isNonGlobalAddrSpace(unsigned AS) { bool SITargetLowering::isFreeAddrSpaceCast(unsigned SrcAS, unsigned DestAS) const { - // Flat -> private/local is a simple truncate. - // Flat -> global is no-op - if (SrcAS == AMDGPUAS::FLAT_ADDRESS) + if (SrcAS == AMDGPUAS::FLAT_ADDRESS) { + if (DestAS == AMDGPUAS::PRIVATE_ADDRESS && + Subtarget->hasGloballyAddressableScratch()) { + // Flat -> private requires subtracting src_flat_scratch_base_lo. + return false; + } + + // Flat -> private/local is a simple truncate. + // Flat -> global is no-op return true; + } const GCNTargetMachine &TM = static_cast<const GCNTargetMachine &>(getTargetMachine()); @@ -7650,6 +7657,9 @@ SDValue SITargetLowering::getSegmentAperture(unsigned AS, const SDLoc &DL, const unsigned ApertureRegNo = (AS == AMDGPUAS::LOCAL_ADDRESS) ? AMDGPU::SRC_SHARED_BASE : AMDGPU::SRC_PRIVATE_BASE; + assert((ApertureRegNo != AMDGPU::SRC_PRIVATE_BASE || + !Subtarget->hasGloballyAddressableScratch()) && + "Cannot use src_private_base with globally addressable scratch!"); // Note: this feature (register) is broken. When used as a 32-bit operand, // it returns a wrong value (all zeroes?). The real value is in the upper 32 // bits. @@ -7760,6 +7770,18 @@ SDValue SITargetLowering::lowerADDRSPACECAST(SDValue Op, DestAS == AMDGPUAS::PRIVATE_ADDRESS) { SDValue Ptr = DAG.getNode(ISD::TRUNCATE, SL, MVT::i32, Src); + if (DestAS == AMDGPUAS::PRIVATE_ADDRESS && + Subtarget->hasGloballyAddressableScratch()) { + // flat -> private with globally addressable scratch: subtract + // src_flat_scratch_base_lo. 
+ SDValue FlatScratchBaseLo( + DAG.getMachineNode( + AMDGPU::S_MOV_B32, SL, MVT::i32, + DAG.getRegister(AMDGPU::SRC_FLAT_SCRATCH_BASE_LO, MVT::i32)), + 0); + Ptr = DAG.getNode(ISD::SUB, SL, MVT::i32, Ptr, FlatScratchBaseLo); + } + if (IsNonNull || isKnownNonNull(Op, DAG, TM, SrcAS)) return Ptr; @@ -7776,11 +7798,40 @@ SDValue SITargetLowering::lowerADDRSPACECAST(SDValue Op, if (DestAS == AMDGPUAS::FLAT_ADDRESS) { if (SrcAS == AMDGPUAS::LOCAL_ADDRESS || SrcAS == AMDGPUAS::PRIVATE_ADDRESS) { - - SDValue Aperture = getSegmentAperture(SrcAS, SL, DAG); - SDValue CvtPtr = - DAG.getNode(ISD::BUILD_VECTOR, SL, MVT::v2i32, Src, Aperture); - CvtPtr = DAG.getNode(ISD::BITCAST, SL, MVT::i64, CvtPtr); + SDValue CvtPtr; + if (SrcAS == AMDGPUAS::PRIVATE_ADDRESS && + Subtarget->hasGloballyAddressableScratch()) { + // For wave32: Addr = (TID[4:0] << 52) + FLAT_SCRATCH_BASE + privateAddr + // For wave64: Addr = (TID[5:0] << 51) + FLAT_SCRATCH_BASE + privateAddr + SDValue AllOnes = DAG.getSignedTargetConstant(-1, SL, MVT::i32); + SDValue ThreadID = DAG.getConstant(0, SL, MVT::i32); + ThreadID = DAG.getNode( + ISD::INTRINSIC_WO_CHAIN, SL, MVT::i32, + DAG.getTargetConstant(Intrinsic::amdgcn_mbcnt_lo, SL, MVT::i32), + AllOnes, ThreadID); + if (Subtarget->isWave64()) + ThreadID = DAG.getNode( + ISD::INTRINSIC_WO_CHAIN, SL, MVT::i32, + DAG.getTargetConstant(Intrinsic::amdgcn_mbcnt_hi, SL, MVT::i32), + AllOnes, ThreadID); + SDValue ShAmt = DAG.getShiftAmountConstant( + 57 - 32 - Subtarget->getWavefrontSizeLog2(), MVT::i32, SL); + SDValue SrcHi = DAG.getNode(ISD::SHL, SL, MVT::i32, ThreadID, ShAmt); + CvtPtr = DAG.getNode(ISD::BUILD_VECTOR, SL, MVT::v2i32, Src, SrcHi); + CvtPtr = DAG.getNode(ISD::BITCAST, SL, MVT::i64, CvtPtr); + // Accessing src_flat_scratch_base_lo as a 64-bit operand gives the full + // 64-bit hi:lo value. + SDValue FlatScratchBase = { + DAG.getMachineNode( + AMDGPU::S_MOV_B64, SL, MVT::i64, + DAG.getRegister(AMDGPU::SRC_FLAT_SCRATCH_BASE, MVT::i64)), + 0}; + CvtPtr = DAG.getNode(ISD::ADD, SL, MVT::i64, CvtPtr, FlatScratchBase); + } else { + SDValue Aperture = getSegmentAperture(SrcAS, SL, DAG); + CvtPtr = DAG.getNode(ISD::BUILD_VECTOR, SL, MVT::v2i32, Src, Aperture); + CvtPtr = DAG.getNode(ISD::BITCAST, SL, MVT::i64, CvtPtr); + } if (IsNonNull || isKnownNonNull(Op, DAG, TM, SrcAS)) return CvtPtr; @@ -9424,15 +9475,29 @@ SDValue SITargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, case Intrinsic::amdgcn_is_shared: case Intrinsic::amdgcn_is_private: { SDLoc SL(Op); - unsigned AS = (IntrinsicID == Intrinsic::amdgcn_is_shared) - ? AMDGPUAS::LOCAL_ADDRESS - : AMDGPUAS::PRIVATE_ADDRESS; - SDValue Aperture = getSegmentAperture(AS, SL, DAG); SDValue SrcVec = DAG.getNode(ISD::BITCAST, DL, MVT::v2i32, Op.getOperand(1)); - SDValue SrcHi = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, MVT::i32, SrcVec, DAG.getConstant(1, SL, MVT::i32)); + + unsigned AS = (IntrinsicID == Intrinsic::amdgcn_is_shared) + ? AMDGPUAS::LOCAL_ADDRESS + : AMDGPUAS::PRIVATE_ADDRESS; + if (AS == AMDGPUAS::PRIVATE_ADDRESS && + Subtarget->hasGloballyAddressableScratch()) { + SDValue FlatScratchBaseHi( + DAG.getMachineNode( + AMDGPU::S_MOV_B32, DL, MVT::i32, + DAG.getRegister(AMDGPU::SRC_FLAT_SCRATCH_BASE_HI, MVT::i32)), + 0); + // Test bits 63..58 against the aperture address. 
+ return DAG.getSetCC( + SL, MVT::i1, + DAG.getNode(ISD::XOR, SL, MVT::i32, SrcHi, FlatScratchBaseHi), + DAG.getConstant(1u << 26, SL, MVT::i32), ISD::SETULT); + } + + SDValue Aperture = getSegmentAperture(AS, SL, DAG); return DAG.getSetCC(SL, MVT::i1, SrcHi, Aperture, ISD::SETEQ); } case Intrinsic::amdgcn_perm: diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.td b/llvm/lib/Target/AMDGPU/SIInstrInfo.td index a3e20ba..c552f1a 100644 --- a/llvm/lib/Target/AMDGPU/SIInstrInfo.td +++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.td @@ -908,6 +908,32 @@ def SupportedRoundMode : TImmLeaf<i32, [{ Imm == (int)RoundingMode::TowardNegative; }]>; +def VOP3PModsNeg : SDNodeXForm<timm, [{ + unsigned Mods = SISrcMods::OP_SEL_1; + if (N->getZExtValue()) + Mods ^= SISrcMods::NEG; + return CurDAG->getTargetConstant(Mods, SDLoc(N), MVT::i32); +}]>; + +def VOP3PModsNegs : SDNodeXForm<timm, [{ + unsigned Mods = SISrcMods::OP_SEL_1; + if (N->getZExtValue()) + Mods ^= (SISrcMods::NEG | SISrcMods::NEG_HI); + return CurDAG->getTargetConstant(Mods, SDLoc(N), MVT::i32); +}]>; + +def VOP3PModsNegAbs : SDNodeXForm<timm, [{ + unsigned Val = N->getZExtValue(); + unsigned Mods = SISrcMods::OP_SEL_1; // default: none + if (Val == 1) // neg + Mods ^= SISrcMods::NEG; + if (Val == 2) // abs + Mods ^= SISrcMods::ABS; + if (Val == 3) // neg and abs + Mods ^= (SISrcMods::NEG | SISrcMods::ABS); + return CurDAG->getTargetConstant(Mods, SDLoc(N), MVT::i32); +}]>; + class bitextract_imm<int bitnum> : SDNodeXForm<imm, [{ uint64_t Imm = N->getZExtValue(); unsigned Bit = (Imm >> }] # bitnum # [{ ) & 1; @@ -1310,6 +1336,12 @@ def bitop3_0 : DefaultOperand<BitOp3, 0>; def MatrixAFMT : CustomOperand<i32, 1, "MatrixAFMT">; def MatrixBFMT : CustomOperand<i32, 1, "MatrixBFMT">; +def MatrixAScale : CustomOperand<i32, 1, "MatrixAScale">; +def MatrixBScale : CustomOperand<i32, 1, "MatrixBScale">; + +def MatrixAScaleFmt : CustomOperand<i32, 1, "MatrixAScaleFmt">; +def MatrixBScaleFmt : CustomOperand<i32, 1, "MatrixBScaleFmt">; + def MatrixAReuse : NamedBitOperand<"matrix_a_reuse">; def MatrixBReuse : NamedBitOperand<"matrix_b_reuse">; @@ -1647,9 +1679,6 @@ def VOP3OMods : ComplexPattern<untyped, 3, "SelectVOP3OMods">; def VOP3PMods : ComplexPattern<untyped, 2, "SelectVOP3PMods">; def VOP3PModsDOT : ComplexPattern<untyped, 2, "SelectVOP3PModsDOT">; -def VOP3PModsNeg : ComplexPattern<untyped, 1, "SelectVOP3PModsNeg">; -def VOP3PModsNegs : ComplexPattern<untyped, 1, "SelectVOP3PModsNegs">; // chfang: not use complex pattern? 
-def VOP3PModsNegAbs : ComplexPattern<untyped, 1, "SelectVOP3PModsNegAbs">; def WMMAOpSelVOP3PMods : ComplexPattern<untyped, 1, "SelectWMMAOpSelVOP3PMods">; def WMMAModsF32NegAbs : ComplexPattern<untyped, 2, "SelectWMMAModsF32NegAbs">; @@ -1774,6 +1803,7 @@ class getVALUDstForVT<ValueType VT, bit IsTrue16 = 0, bit IsVOP3Encoding = 0> { !eq(VT.Size, 256) : VOPDstOperand<VReg_256>, !eq(VT.Size, 192) : VOPDstOperand<VReg_192>, !eq(VT.Size, 128) : VOPDstOperand<VReg_128>, + !eq(VT.Size, 96) : VOPDstOperand<VReg_96>, !eq(VT.Size, 64) : VOPDstOperand<VReg_64>, !eq(VT.Size, 32) : VOPDstOperand<VGPR_32>, !eq(VT.Size, 16) : op16, @@ -1924,6 +1954,7 @@ class getVOP3DPPSrcForVT<ValueType VT, bit IsFake16 = 1> { !eq(VT, v2f16) : VCSrc_v2f16, !eq(VT, v2bf16) : VCSrc_v2bf16, !eq(VT, f32) : VCSrc_f32, + !eq(VT, v2i32) : VCSrc_v2b32, 1 : VCSrc_b32); } @@ -2678,6 +2709,8 @@ class VOPProfile <list<ValueType> _ArgVT, bit _EnableClamp = 0> { field bit HasNeg = HasModifiers; field bit HasMatrixReuse = 0; field bit HasMatrixFMT = 0; + field bit HasMatrixScale = 0; + field bit HasMatrixReuse = 0; field bit HasSrc0Mods = HasModifiers; field bit HasSrc1Mods = !if(HasModifiers, !or(HasSrc1FloatMods, HasSrc1IntMods), 0); @@ -2935,6 +2968,9 @@ def VOP_V2BF16_F32_F32_I32 : VOPProfile <[v2bf16, f32, f32, i32]>; def VOP_V2F16_F32_F32_I32 : VOPProfile <[v2f16, f32, f32, i32]>; def VOP_V6I32_V32F16_F32 : VOPProfile<[v6i32, v32f16, f32, untyped]>; def VOP_V6I32_V32BF16_F32 : VOPProfile<[v6i32, v32bf16, f32, untyped]>; +def VOP_V3I32_V16F16_F32 : VOPProfile<[v3i32, v16f16, f32, untyped]>; +def VOP_V3I32_V16BF16_F32 : VOPProfile<[v3i32, v16bf16, f32, untyped]>; +def VOP_V3I32_V16F32_F32 : VOPProfile<[v3i32, v16f32, f32, untyped]>; def VOP_V6I32_V16F32_V16F32_F32 : VOPProfile<[v6i32, v16f32, v16f32, f32]>; def VOP_V2F16_I32_F32 : VOPProfile<[v2f16, i32, f32, untyped]>; def VOP_V2I16_F32_F32_F32 : VOPProfile<[v2i16, f32, f32, f32]>; @@ -2948,6 +2984,8 @@ def VOP_BF16_F32_I32 : VOPProfile<[bf16, f32, i32, untyped]>; def VOP_F16_F32_I32 : VOPProfile<[f16, f32, i32, untyped]>; def VOP_I32_BF16_I32_F32 : VOPProfile<[i32, bf16, i32, f32]>; def VOP_I32_F16_I32_F32 : VOPProfile<[i32, f16, i32, f32]>; +def VOP_V16F16_V3I32_I32 : VOPProfile<[v16f16, v3i32, i32, untyped]>; +def VOP_V16BF16_V3I32_I32 : VOPProfile<[v16bf16, v3i32, i32, untyped]>; def VOP_V8F16_V2I32_I32 : VOPProfile<[v8f16, v2i32, i32, untyped]>; def VOP_V8BF16_V2I32_I32 : VOPProfile<[v8bf16, v2i32, i32, untyped]>; def VOP_V8F16_I32_I32 : VOPProfile<[v8f16, i32, i32, untyped]>; @@ -2955,11 +2993,26 @@ def VOP_V8BF16_I32_I32 : VOPProfile<[v8bf16, i32, i32, untyped]>; def VOP_V16F32_V3I32_I32 : VOPProfile<[v16f32, v3i32, i32, untyped]>; def VOP_V8F32_V2I32_I32 : VOPProfile<[v8f32, v2i32, i32, untyped]>; def VOP_V8F32_I32_I32 : VOPProfile<[v8f32, i32, i32, untyped]>; +def VOP_V2I32_V8BF16_F32 : VOPProfile<[v2i32, v8bf16, f32, untyped]>; +def VOP_V2I32_V8F16_F32 : VOPProfile<[v2i32, v8f16, f32, untyped]>; +def VOP_V2I32_V8F32_F32 : VOPProfile<[v2i32, v8f32, f32, untyped]>; +def VOP_I32_V8F32_F32 : VOPProfile<[i32, v8f32, f32, untyped]>; +def VOP_I32_V8F16_F32 : VOPProfile<[i32, v8f16, f32, untyped]>; +def VOP_I32_V8BF16_F32 : VOPProfile<[i32, v8bf16, f32, untyped]>; def VOP_I32_F32_I32_F32 : VOPProfile<[i32, f32, i32, f32]>; def VOP_V6I32_V32BF16_I32_F32 : VOPProfile<[v6i32, v32bf16, i32, f32]>; def VOP_V6I32_V32F16_I32_F32 : VOPProfile<[v6i32, v32f16, i32, f32]>; def VOP_V6I32_V32F32_I32_F32 : VOPProfile<[v6i32, v32f32, i32, f32]>; +def VOP_V3I32_V16F16_I32_F32 : 
VOPProfile<[v3i32, v16f16, i32, f32]>; +def VOP_V3I32_V16BF16_I32_F32 : VOPProfile<[v3i32, v16bf16, i32, f32]>; +def VOP_V3I32_V16F32_I32_F32 : VOPProfile<[v3i32, v16f32, i32, f32]>; +def VOP_V2I32_V8BF16_I32_F32 : VOPProfile<[v2i32, v8bf16, i32, f32]>; +def VOP_V2I32_V8F16_I32_F32 : VOPProfile<[v2i32, v8f16, i32, f32]>; +def VOP_V2I32_V8F32_I32_F32 : VOPProfile<[v2i32, v8f32, i32, f32]>; +def VOP_I32_V8F32_I32_F32 : VOPProfile<[i32, v8f32, i32, f32]>; +def VOP_I32_V8F16_I32_F32 : VOPProfile<[i32, v8f16, i32, f32]>; +def VOP_I32_V8BF16_I32_F32 : VOPProfile<[i32, v8bf16, i32, f32]>; def VOP_I64_I64_I32 : VOPProfile <[i64, i64, i32, untyped]>; def VOP_I64_I32_I64 : VOPProfile <[i64, i32, i64, untyped]>; diff --git a/llvm/lib/Target/AMDGPU/SIInstructions.td b/llvm/lib/Target/AMDGPU/SIInstructions.td index 54fa192..bd5dfa9 100644 --- a/llvm/lib/Target/AMDGPU/SIInstructions.td +++ b/llvm/lib/Target/AMDGPU/SIInstructions.td @@ -3543,14 +3543,21 @@ def : GCNPat < (vecTy (UniformBinFrag<build_vector> (Ty undef), (Ty SReg_32:$src1))), (S_LSHL_B32 SReg_32:$src1, (i32 16)) >; -} def : GCNPat < (vecTy (DivergentBinFrag<build_vector> (Ty undef), (Ty VGPR_32:$src1))), (vecTy (V_LSHLREV_B32_e64 (i32 16), VGPR_32:$src1)) >; +} // End True16Predicate = ... } // End foreach Ty = ... -} +} // End AddedComplexity = 1 + +let True16Predicate = UseRealTrue16Insts in +def : GCNPat < + (v2i16 (DivergentBinFrag<build_vector> (i16 undef), (i16 (trunc i32:$src1)))), + (REG_SEQUENCE VGPR_32, (i16 (IMPLICIT_DEF)), lo16, + (i16 (EXTRACT_SUBREG VGPR_32:$src1, lo16)), hi16) +>; let SubtargetPredicate = HasVOP3PInsts in { foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in @@ -3599,7 +3606,11 @@ def : GCNPat < >; def : GCNPat < (vecTy (DivergentBinFrag<build_vector> (Ty VGPR_16:$src0), (Ty undef))), - (REG_SEQUENCE VGPR_32, $src0, lo16, (IMPLICIT_DEF), hi16) + (REG_SEQUENCE VGPR_32, $src0, lo16, (Ty (IMPLICIT_DEF)), hi16) +>; +def : GCNPat < + (vecTy (DivergentBinFrag<build_vector> (Ty undef), (Ty VGPR_16:$src1))), + (REG_SEQUENCE VGPR_32, (Ty (IMPLICIT_DEF)), lo16, (Ty VGPR_16:$src1), hi16) >; } diff --git a/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp b/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp index f3acc5c..ae0f304 100644 --- a/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp +++ b/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp @@ -598,6 +598,8 @@ BitVector SIRegisterInfo::getReservedRegs(const MachineFunction &MF) const { reserveRegisterTuples(Reserved, AMDGPU::SRC_SHARED_LIMIT); reserveRegisterTuples(Reserved, AMDGPU::SRC_PRIVATE_BASE); reserveRegisterTuples(Reserved, AMDGPU::SRC_PRIVATE_LIMIT); + reserveRegisterTuples(Reserved, AMDGPU::SRC_FLAT_SCRATCH_BASE_LO); + reserveRegisterTuples(Reserved, AMDGPU::SRC_FLAT_SCRATCH_BASE_HI); // Reserve async counters pseudo registers reserveRegisterTuples(Reserved, AMDGPU::ASYNCcnt); diff --git a/llvm/lib/Target/AMDGPU/SIRegisterInfo.td b/llvm/lib/Target/AMDGPU/SIRegisterInfo.td index 36d1a3b..81655f5 100644 --- a/llvm/lib/Target/AMDGPU/SIRegisterInfo.td +++ b/llvm/lib/Target/AMDGPU/SIRegisterInfo.td @@ -246,6 +246,22 @@ defm SRC_SHARED_LIMIT : ApertureRegister<"src_shared_limit", 236>; defm SRC_PRIVATE_BASE : ApertureRegister<"src_private_base", 237>; defm SRC_PRIVATE_LIMIT : ApertureRegister<"src_private_limit", 238>; +let isConstant = true in { + defm SRC_FLAT_SCRATCH_BASE_LO : SIRegLoHi16<"src_flat_scratch_base_lo", 230>; + defm SRC_FLAT_SCRATCH_BASE_HI : SIRegLoHi16<"src_flat_scratch_base_hi", 231>; + + // Using src_flat_scratch_base_lo in a 64-bit context gets the full 64-bit + // 
hi:lo value. + def SRC_FLAT_SCRATCH_BASE : + RegisterWithSubRegs<"src_flat_scratch_base_lo", + [SRC_FLAT_SCRATCH_BASE_LO, + SRC_FLAT_SCRATCH_BASE_HI]> { + let Namespace = "AMDGPU"; + let SubRegIndices = [sub0, sub1]; + let HWEncoding = SRC_FLAT_SCRATCH_BASE_LO.HWEncoding; + } +} + defm SRC_POPS_EXITING_WAVE_ID : SIRegLoHi16<"src_pops_exiting_wave_id", 239>; // Not addressable @@ -765,7 +781,7 @@ def SReg_32_XM0_XEXEC : SIRegisterClass<"AMDGPU", [i32, f32, i16, f16, bf16, v2i SGPR_NULL, SGPR_NULL_HI, TTMP_32, TMA_LO, TMA_HI, TBA_LO, TBA_HI, SRC_SHARED_BASE_LO, SRC_SHARED_LIMIT_LO, SRC_PRIVATE_BASE_LO, SRC_PRIVATE_LIMIT_LO, SRC_SHARED_BASE_HI, SRC_SHARED_LIMIT_HI, SRC_PRIVATE_BASE_HI, SRC_PRIVATE_LIMIT_HI, SRC_POPS_EXITING_WAVE_ID, - SRC_VCCZ, SRC_EXECZ, SRC_SCC)> { + SRC_VCCZ, SRC_EXECZ, SRC_SCC, SRC_FLAT_SCRATCH_BASE_LO, SRC_FLAT_SCRATCH_BASE_HI)> { let AllocationPriority = 0; } @@ -776,7 +792,8 @@ def SReg_LO16 : SIRegisterClass<"AMDGPU", [i16, f16, bf16], 16, SRC_SHARED_LIMIT_LO_LO16, SRC_PRIVATE_BASE_LO_LO16, SRC_PRIVATE_LIMIT_LO_LO16, SRC_SHARED_BASE_HI_LO16, SRC_SHARED_LIMIT_HI_LO16, SRC_PRIVATE_BASE_HI_LO16, SRC_PRIVATE_LIMIT_HI_LO16, SRC_POPS_EXITING_WAVE_ID_LO16, SRC_VCCZ_LO16, - SRC_EXECZ_LO16, SRC_SCC_LO16, EXEC_LO_LO16, EXEC_HI_LO16, M0_CLASS_LO16)> { + SRC_EXECZ_LO16, SRC_SCC_LO16, EXEC_LO_LO16, EXEC_HI_LO16, M0_CLASS_LO16, + SRC_FLAT_SCRATCH_BASE_LO_LO16, SRC_FLAT_SCRATCH_BASE_HI_LO16)> { let Size = 16; let isAllocatable = 0; let BaseClassOrder = 16; @@ -849,7 +866,8 @@ def TTMP_64 : SIRegisterClass<"AMDGPU", [v2i32, i64, f64, v4i16, v4f16, v4bf16], def SReg_64_XEXEC_XNULL : SIRegisterClass<"AMDGPU", [v2i32, i64, v2f32, f64, i1, v4i16, v4f16, v4bf16], 32, (add SGPR_64, VCC, FLAT_SCR, XNACK_MASK, SRC_SHARED_BASE, - SRC_SHARED_LIMIT, SRC_PRIVATE_BASE, SRC_PRIVATE_LIMIT, TTMP_64, TBA, TMA)> { + SRC_SHARED_LIMIT, SRC_PRIVATE_BASE, SRC_PRIVATE_LIMIT, TTMP_64, TBA, TMA, + SRC_FLAT_SCRATCH_BASE)> { let CopyCost = 1; let AllocationPriority = 1; let HasSGPR = 1; @@ -1302,6 +1320,7 @@ def VCSrc_f64 : SrcRegOrImm9 <VS_64, "OPERAND_REG_INLINE_C_FP64">; def VCSrc_v2b16 : SrcRegOrImm9 <VS_32, "OPERAND_REG_INLINE_C_V2INT16">; def VCSrc_v2bf16: SrcRegOrImm9 <VS_32, "OPERAND_REG_INLINE_C_V2BF16">; def VCSrc_v2f16 : SrcRegOrImm9 <VS_32, "OPERAND_REG_INLINE_C_V2FP16">; +def VCSrc_v2b32 : SrcRegOrImm9 <VS_64, "OPERAND_REG_INLINE_C_V2INT32">; // True 16 Operands def VCSrcT_b16 : SrcRegOrImm9_t16 <"OPERAND_REG_INLINE_C_INT16">; diff --git a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp index 65fa088..00dcb9b 100644 --- a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp +++ b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp @@ -2654,6 +2654,8 @@ bool isInlineValue(unsigned Reg) { case AMDGPU::SRC_PRIVATE_BASE: case AMDGPU::SRC_PRIVATE_LIMIT_LO: case AMDGPU::SRC_PRIVATE_LIMIT: + case AMDGPU::SRC_FLAT_SCRATCH_BASE_LO: + case AMDGPU::SRC_FLAT_SCRATCH_BASE_HI: case AMDGPU::SRC_POPS_EXITING_WAVE_ID: return true; case AMDGPU::SRC_VCCZ: diff --git a/llvm/lib/Target/AMDGPU/VOP1Instructions.td b/llvm/lib/Target/AMDGPU/VOP1Instructions.td index f621f85..b128207 100644 --- a/llvm/lib/Target/AMDGPU/VOP1Instructions.td +++ b/llvm/lib/Target/AMDGPU/VOP1Instructions.td @@ -107,18 +107,6 @@ class VOP1_DPP_Pseudo <string OpName, VOPProfile P, list<dag> pattern=[]> : VOP_DPP_Pseudo <OpName, P, pattern> { } -class getVOP1Pat <SDPatternOperator node, VOPProfile P> : LetDummies { - list<dag> ret = - !if(P.HasModifiers, - [(set P.DstVT:$vdst, (node (P.Src0VT (VOP3Mods 
P.Src0VT:$src0, i32:$src0_modifiers))))], - !if(P.HasOMod, - [(set P.DstVT:$vdst, (node (P.Src0VT (VOP3OMods P.Src0VT:$src0, - i1:$clamp, i32:$omod))))], - [(set P.DstVT:$vdst, (node (P.Src0VT P.Src0RC32:$src0)))] - ) - ); -} - multiclass VOP1Inst <string opName, VOPProfile P, SDPatternOperator node = null_frag, int VOPDOp = -1> { // We only want to set this on the basic, non-SDWA or DPP forms. diff --git a/llvm/lib/Target/AMDGPU/VOP3Instructions.td b/llvm/lib/Target/AMDGPU/VOP3Instructions.td index 19ce7f5..f4b6af6 100644 --- a/llvm/lib/Target/AMDGPU/VOP3Instructions.td +++ b/llvm/lib/Target/AMDGPU/VOP3Instructions.td @@ -1726,6 +1726,12 @@ multiclass VOP3CvtScaleSelInst<string OpName, VOPProfile P, SDPatternOperator no } } +let HasExtVOP3DPP = 0, HasModifiers = 0 in { +def VOP3_V2I32_I32_I32_V2I32 : VOP3_Profile<VOPProfile<[v2i32, i32, i32, v2i32]>>; +def VOP3_V3I32_I32_I64_V2I32 : VOP3_Profile<VOPProfile<[v3i32, i32, i64, v2i32]>>; +def VOP3_V4I32_I64_I64_V2I32 : VOP3_Profile<VOPProfile<[v4i32, i64, i64, v2i32]>>; +} + let Src0RC64 = VSrc_NoInline_v2f16 in { def VOP3_CVT_PK_F8_F16_Profile : VOP3_Profile<VOP_I16_V2F16>; def VOP3_CVT_PK_F8_F16_True16_Profile : VOP3_Profile_True16<VOP3_CVT_PK_F8_F16_Profile>; @@ -1771,6 +1777,12 @@ let SubtargetPredicate = isGFX1250Plus in { defm V_CVT_SCALE_PK8_BF16_BF8 : VOP3CvtScaleSelInst<"v_cvt_scale_pk8_bf16_bf8", VOP_V8BF16_V2I32_I32, int_amdgcn_cvt_scale_pk8_bf16_bf8>; defm V_CVT_SCALE_PK8_F32_FP8 : VOP3CvtScaleSelInst<"v_cvt_scale_pk8_f32_fp8", VOP_V8F32_V2I32_I32, int_amdgcn_cvt_scale_pk8_f32_fp8>; defm V_CVT_SCALE_PK8_F32_BF8 : VOP3CvtScaleSelInst<"v_cvt_scale_pk8_f32_bf8", VOP_V8F32_V2I32_I32, int_amdgcn_cvt_scale_pk8_f32_bf8>; + defm V_CVT_SCALE_PK16_F16_FP6 : VOP3CvtScaleSelInst<"v_cvt_scale_pk16_f16_fp6", VOP_V16F16_V3I32_I32, int_amdgcn_cvt_scale_pk16_f16_fp6>; + defm V_CVT_SCALE_PK16_BF16_FP6 : VOP3CvtScaleSelInst<"v_cvt_scale_pk16_bf16_fp6", VOP_V16BF16_V3I32_I32, int_amdgcn_cvt_scale_pk16_bf16_fp6>; + defm V_CVT_SCALE_PK16_F16_BF6 : VOP3CvtScaleSelInst<"v_cvt_scale_pk16_f16_bf6", VOP_V16F16_V3I32_I32, int_amdgcn_cvt_scale_pk16_f16_bf6>; + defm V_CVT_SCALE_PK16_BF16_BF6 : VOP3CvtScaleSelInst<"v_cvt_scale_pk16_bf16_bf6", VOP_V16BF16_V3I32_I32, int_amdgcn_cvt_scale_pk16_bf16_bf6>; + defm V_CVT_SCALE_PK16_F32_FP6 : VOP3CvtScaleSelInst<"v_cvt_scale_pk16_f32_fp6", VOP_V16F32_V3I32_I32, int_amdgcn_cvt_scale_pk16_f32_fp6>; + defm V_CVT_SCALE_PK16_F32_BF6 : VOP3CvtScaleSelInst<"v_cvt_scale_pk16_f32_bf6", VOP_V16F32_V3I32_I32, int_amdgcn_cvt_scale_pk16_f32_bf6>; } // End Constraints = "@earlyclobber $vdst" defm V_CVT_SCALE_PK8_F16_FP4 : VOP3CvtScaleSelInst<"v_cvt_scale_pk8_f16_fp4", VOP_V8F16_I32_I32, int_amdgcn_cvt_scale_pk8_f16_fp4>; @@ -1778,6 +1790,44 @@ let SubtargetPredicate = isGFX1250Plus in { defm V_CVT_SCALE_PK8_F32_FP4 : VOP3CvtScaleSelInst<"v_cvt_scale_pk8_f32_fp4", VOP_V8F32_I32_I32, int_amdgcn_cvt_scale_pk8_f32_fp4>; } // End ReadsModeReg = 0 + let Constraints = "@earlyclobber $vdst" in { + let WaveSizePredicate = isWave32 in { + defm V_CVT_SCALEF32_PK8_FP8_BF16 : VOP3Inst<"v_cvt_scalef32_pk8_fp8_bf16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V2I32_V8BF16_F32>, int_amdgcn_cvt_scalef32_pk8_fp8_bf16>; + defm V_CVT_SCALEF32_PK8_BF8_BF16 : VOP3Inst<"v_cvt_scalef32_pk8_bf8_bf16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V2I32_V8BF16_F32>, int_amdgcn_cvt_scalef32_pk8_bf8_bf16>; + defm V_CVT_SCALEF32_PK8_FP8_F16 : VOP3Inst<"v_cvt_scalef32_pk8_fp8_f16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V2I32_V8F16_F32>, int_amdgcn_cvt_scalef32_pk8_fp8_f16>; 
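Note on the profile shapes used in this block: the packed conversion results are carried as raw 32-bit chunks, so 8 fp8/bf8 elements need two dwords (v2i32), 8 fp4 elements fit in a single i32, and 16 fp6/bf6 elements need three dwords (v3i32). A minimal sketch of that arithmetic (names invented for illustration, not from the source):

// Number of 32-bit result registers needed for a packed conversion.
constexpr unsigned packedResultDwords(unsigned NumElts, unsigned BitsPerElt) {
  return (NumElts * BitsPerElt + 31) / 32;
}

static_assert(packedResultDwords(8, 8) == 2, "pk8 fp8/bf8 -> v2i32");
static_assert(packedResultDwords(8, 4) == 1, "pk8 fp4 -> i32");
static_assert(packedResultDwords(16, 6) == 3, "pk16 fp6/bf6 -> v3i32");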
+ defm V_CVT_SCALEF32_PK8_BF8_F16 : VOP3Inst<"v_cvt_scalef32_pk8_bf8_f16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V2I32_V8F16_F32>, int_amdgcn_cvt_scalef32_pk8_bf8_f16>; + defm V_CVT_SCALEF32_PK8_FP8_F32 : VOP3Inst<"v_cvt_scalef32_pk8_fp8_f32", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V2I32_V8F32_F32>, int_amdgcn_cvt_scalef32_pk8_fp8_f32>; + defm V_CVT_SCALEF32_PK8_BF8_F32 : VOP3Inst<"v_cvt_scalef32_pk8_bf8_f32", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V2I32_V8F32_F32>, int_amdgcn_cvt_scalef32_pk8_bf8_f32>; + defm V_CVT_SCALEF32_PK8_FP4_F32 : VOP3Inst<"v_cvt_scalef32_pk8_fp4_f32", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_I32_V8F32_F32>, int_amdgcn_cvt_scalef32_pk8_fp4_f32>; + defm V_CVT_SCALEF32_PK8_FP4_F16 : VOP3Inst<"v_cvt_scalef32_pk8_fp4_f16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_I32_V8F16_F32>, int_amdgcn_cvt_scalef32_pk8_fp4_f16>; + defm V_CVT_SCALEF32_PK8_FP4_BF16 : VOP3Inst<"v_cvt_scalef32_pk8_fp4_bf16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_I32_V8BF16_F32>, int_amdgcn_cvt_scalef32_pk8_fp4_bf16>; + } // End WaveSizePredicate = isWave32 + defm V_CVT_SCALEF32_PK16_FP6_F32 : VOP3Inst<"v_cvt_scalef32_pk16_fp6_f32", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V3I32_V16F32_F32>, int_amdgcn_cvt_scalef32_pk16_fp6_f32>; + defm V_CVT_SCALEF32_PK16_BF6_F32 : VOP3Inst<"v_cvt_scalef32_pk16_bf6_f32", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V3I32_V16F32_F32>, int_amdgcn_cvt_scalef32_pk16_bf6_f32>; + defm V_CVT_SCALEF32_PK16_FP6_F16 : VOP3Inst<"v_cvt_scalef32_pk16_fp6_f16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V3I32_V16F16_F32>, int_amdgcn_cvt_scalef32_pk16_fp6_f16>; + defm V_CVT_SCALEF32_PK16_BF6_F16 : VOP3Inst<"v_cvt_scalef32_pk16_bf6_f16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V3I32_V16F16_F32>, int_amdgcn_cvt_scalef32_pk16_bf6_f16>; + defm V_CVT_SCALEF32_PK16_FP6_BF16 : VOP3Inst<"v_cvt_scalef32_pk16_fp6_bf16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V3I32_V16BF16_F32>, int_amdgcn_cvt_scalef32_pk16_fp6_bf16>; + defm V_CVT_SCALEF32_PK16_BF6_BF16 : VOP3Inst<"v_cvt_scalef32_pk16_bf6_bf16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V3I32_V16BF16_F32>, int_amdgcn_cvt_scalef32_pk16_bf6_bf16>; + + let WaveSizePredicate = isWave32 in { + defm V_CVT_SCALEF32_SR_PK8_FP8_BF16 : VOP3Inst<"v_cvt_scalef32_sr_pk8_fp8_bf16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V2I32_V8BF16_I32_F32>, int_amdgcn_cvt_scalef32_sr_pk8_fp8_bf16>; + defm V_CVT_SCALEF32_SR_PK8_BF8_BF16 : VOP3Inst<"v_cvt_scalef32_sr_pk8_bf8_bf16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V2I32_V8BF16_I32_F32>, int_amdgcn_cvt_scalef32_sr_pk8_bf8_bf16>; + defm V_CVT_SCALEF32_SR_PK8_FP8_F16 : VOP3Inst<"v_cvt_scalef32_sr_pk8_fp8_f16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V2I32_V8F16_I32_F32>, int_amdgcn_cvt_scalef32_sr_pk8_fp8_f16>; + defm V_CVT_SCALEF32_SR_PK8_BF8_F16 : VOP3Inst<"v_cvt_scalef32_sr_pk8_bf8_f16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V2I32_V8F16_I32_F32>, int_amdgcn_cvt_scalef32_sr_pk8_bf8_f16>; + defm V_CVT_SCALEF32_SR_PK8_FP8_F32 : VOP3Inst<"v_cvt_scalef32_sr_pk8_fp8_f32", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V2I32_V8F32_I32_F32>, int_amdgcn_cvt_scalef32_sr_pk8_fp8_f32>; + defm V_CVT_SCALEF32_SR_PK8_BF8_F32 : VOP3Inst<"v_cvt_scalef32_sr_pk8_bf8_f32", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V2I32_V8F32_I32_F32>, int_amdgcn_cvt_scalef32_sr_pk8_bf8_f32>; + defm V_CVT_SCALEF32_SR_PK8_FP4_F32 : VOP3Inst<"v_cvt_scalef32_sr_pk8_fp4_f32", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_I32_V8F32_I32_F32>, int_amdgcn_cvt_scalef32_sr_pk8_fp4_f32>; + defm V_CVT_SCALEF32_SR_PK8_FP4_F16 : VOP3Inst<"v_cvt_scalef32_sr_pk8_fp4_f16", 
VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_I32_V8F16_I32_F32>, int_amdgcn_cvt_scalef32_sr_pk8_fp4_f16>; + defm V_CVT_SCALEF32_SR_PK8_FP4_BF16 : VOP3Inst<"v_cvt_scalef32_sr_pk8_fp4_bf16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_I32_V8BF16_I32_F32>, int_amdgcn_cvt_scalef32_sr_pk8_fp4_bf16>; + } // End WaveSizePredicate = isWave32 + defm V_CVT_SCALEF32_SR_PK16_BF6_BF16 : VOP3Inst<"v_cvt_scalef32_sr_pk16_bf6_bf16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V3I32_V16BF16_I32_F32>, int_amdgcn_cvt_scalef32_sr_pk16_bf6_bf16>; + defm V_CVT_SCALEF32_SR_PK16_BF6_F16 : VOP3Inst<"v_cvt_scalef32_sr_pk16_bf6_f16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V3I32_V16F16_I32_F32>, int_amdgcn_cvt_scalef32_sr_pk16_bf6_f16>; + defm V_CVT_SCALEF32_SR_PK16_BF6_F32 : VOP3Inst<"v_cvt_scalef32_sr_pk16_bf6_f32", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V3I32_V16F32_I32_F32>, int_amdgcn_cvt_scalef32_sr_pk16_bf6_f32>; + defm V_CVT_SCALEF32_SR_PK16_FP6_BF16 : VOP3Inst<"v_cvt_scalef32_sr_pk16_fp6_bf16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V3I32_V16BF16_I32_F32>, int_amdgcn_cvt_scalef32_sr_pk16_fp6_bf16>; + defm V_CVT_SCALEF32_SR_PK16_FP6_F16 : VOP3Inst<"v_cvt_scalef32_sr_pk16_fp6_f16", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V3I32_V16F16_I32_F32>, int_amdgcn_cvt_scalef32_sr_pk16_fp6_f16>; + defm V_CVT_SCALEF32_SR_PK16_FP6_F32 : VOP3Inst<"v_cvt_scalef32_sr_pk16_fp6_f32", VOP3_CVT_SCALEF32_PK_F864_Profile<VOP_V3I32_V16F32_I32_F32>, int_amdgcn_cvt_scalef32_sr_pk16_fp6_f32>; + } // End Constraints = "@earlyclobber $vdst" + let True16Predicate = UseRealTrue16Insts in { def : Cvt_SR_F8_ByteSel_Pat<int_amdgcn_cvt_sr_fp8_f16, V_CVT_SR_FP8_F16_t16_e64, f16>; def : Cvt_SR_F8_ByteSel_Pat<int_amdgcn_cvt_sr_bf8_f16, V_CVT_SR_BF8_F16_t16_e64, f16>; @@ -1788,6 +1838,12 @@ let SubtargetPredicate = isGFX1250Plus in { } } // End SubtargetPredicate = isGFX1250Plus +let SubtargetPredicate = HasTensorCvtLutInsts in { + defm V_PERM_PK16_B4_U4 : VOP3Inst<"v_perm_pk16_b4_u4", VOP3_V2I32_I32_I32_V2I32, int_amdgcn_perm_pk16_b4_u4>; + defm V_PERM_PK16_B6_U4 : VOP3Inst<"v_perm_pk16_b6_u4", VOP3_V3I32_I32_I64_V2I32, int_amdgcn_perm_pk16_b6_u4>; + defm V_PERM_PK16_B8_U4 : VOP3Inst<"v_perm_pk16_b8_u4", VOP3_V4I32_I64_I64_V2I32, int_amdgcn_perm_pk16_b8_u4>; +} // End SubtargetPredicate = HasTensorCvtLutInsts + class Cvt_Scale_Sr_F32ToBF16F16_Pat<SDPatternOperator node, VOP3_Pseudo inst, ValueType DstTy> : GCNPat< (DstTy (node DstTy:$vdst_in, f32:$src0, i32:$src1, timm:$word_sel)), (inst (DstSelToOpSelXForm $word_sel), $src0, 0, $src1, VGPR_32:$vdst_in) @@ -2186,6 +2242,9 @@ let AssemblerPredicate = isGFX11Plus in { } // These instructions differ from GFX12 variant by supporting DPP: +defm V_PERM_PK16_B4_U4 : VOP3Only_Real_Base_gfx1250<0x23f>; +defm V_PERM_PK16_B6_U4 : VOP3Only_Real_Base_gfx1250<0x242>; +defm V_PERM_PK16_B8_U4 : VOP3Only_Real_Base_gfx1250<0x243>; defm V_LSHL_ADD_U64 : VOP3Only_Realtriple_gfx1250<0x252>; defm V_ASHR_PK_I8_I32 : VOP3Only_Realtriple_gfx1250<0x290>; defm V_ASHR_PK_U8_I32 : VOP3Only_Realtriple_gfx1250<0x291>; @@ -2198,6 +2257,42 @@ defm V_CVT_SCALE_PK8_F32_FP8 : VOP3Only_ScaleSel_Real_gfx1250<0x2aa>; defm V_CVT_SCALE_PK8_F16_BF8 : VOP3Only_ScaleSel_Real_gfx1250<0x2ab>; defm V_CVT_SCALE_PK8_BF16_BF8 : VOP3Only_ScaleSel_Real_gfx1250<0x2ac>; defm V_CVT_SCALE_PK8_F32_BF8 : VOP3Only_ScaleSel_Real_gfx1250<0x2ad>; +defm V_CVT_SCALEF32_PK8_FP4_F32 : VOP3Only_Real_Base_gfx1250<0x2b0>; +defm V_CVT_SCALEF32_PK8_FP4_F16 : VOP3Only_Real_Base_gfx1250<0x2b3>; +defm V_CVT_SCALEF32_PK8_FP8_BF16 : VOP3Only_Real_Base_gfx1250<0x2b4>; +defm 
V_CVT_SCALEF32_PK8_BF8_BF16 : VOP3Only_Real_Base_gfx1250<0x2b5>; +defm V_CVT_SCALEF32_PK8_FP4_BF16 : VOP3Only_Real_Base_gfx1250<0x2b8>; +defm V_CVT_SCALEF32_PK8_FP8_F32 : VOP3Only_Real_Base_gfx1250<0x2c3>; +defm V_CVT_SCALEF32_PK8_FP8_F16 : VOP3Only_Real_Base_gfx1250<0x2c4>; +defm V_CVT_SCALEF32_PK8_BF8_F32 : VOP3Only_Real_Base_gfx1250<0x2c5>; +defm V_CVT_SCALEF32_PK8_BF8_F16 : VOP3Only_Real_Base_gfx1250<0x2c6>; +defm V_CVT_SCALE_PK16_F16_FP6 : VOP3Only_ScaleSel_Real_gfx1250<0x2c7>; +defm V_CVT_SCALE_PK16_BF16_FP6 : VOP3Only_ScaleSel_Real_gfx1250<0x2c8>; +defm V_CVT_SCALE_PK16_F32_FP6 : VOP3Only_ScaleSel_Real_gfx1250<0x2c9>; +defm V_CVT_SCALE_PK16_F16_BF6 : VOP3Only_ScaleSel_Real_gfx1250<0x2ca>; +defm V_CVT_SCALE_PK16_BF16_BF6 : VOP3Only_ScaleSel_Real_gfx1250<0x2cb>; +defm V_CVT_SCALE_PK16_F32_BF6 : VOP3Only_ScaleSel_Real_gfx1250<0x2cc>; +defm V_CVT_SCALEF32_PK16_FP6_F32 : VOP3Only_Real_Base_gfx1250<0x2cd>; +defm V_CVT_SCALEF32_PK16_BF6_F32 : VOP3Only_Real_Base_gfx1250<0x2ce>; +defm V_CVT_SCALEF32_PK16_FP6_F16 : VOP3Only_Real_Base_gfx1250<0x2cf>; +defm V_CVT_SCALEF32_PK16_BF6_F16 : VOP3Only_Real_Base_gfx1250<0x2d0>; +defm V_CVT_SCALEF32_PK16_FP6_BF16 : VOP3Only_Real_Base_gfx1250<0x2d1>; +defm V_CVT_SCALEF32_PK16_BF6_BF16 : VOP3Only_Real_Base_gfx1250<0x2d2>; +defm V_CVT_SCALEF32_SR_PK16_FP6_F32 : VOP3Only_Real_Base_gfx1250<0x2d3>; +defm V_CVT_SCALEF32_SR_PK16_BF6_F32 : VOP3Only_Real_Base_gfx1250<0x2d4>; +defm V_CVT_SCALEF32_SR_PK16_FP6_F16 : VOP3Only_Real_Base_gfx1250<0x2d5>; +defm V_CVT_SCALEF32_SR_PK16_BF6_F16 : VOP3Only_Real_Base_gfx1250<0x2d6>; +defm V_CVT_SCALEF32_SR_PK16_FP6_BF16 : VOP3Only_Real_Base_gfx1250<0x2d7>; +defm V_CVT_SCALEF32_SR_PK16_BF6_BF16 : VOP3Only_Real_Base_gfx1250<0x2d8>; +defm V_CVT_SCALEF32_SR_PK8_FP4_F32 : VOP3Only_Real_Base_gfx1250<0x297>; +defm V_CVT_SCALEF32_SR_PK8_FP8_F32 : VOP3Only_Real_Base_gfx1250<0x298>; +defm V_CVT_SCALEF32_SR_PK8_BF8_F32 : VOP3Only_Real_Base_gfx1250<0x299>; +defm V_CVT_SCALEF32_SR_PK8_FP4_F16 : VOP3Only_Real_Base_gfx1250<0x2b9>; +defm V_CVT_SCALEF32_SR_PK8_FP4_BF16 : VOP3Only_Real_Base_gfx1250<0x2bc>; +defm V_CVT_SCALEF32_SR_PK8_FP8_F16 : VOP3Only_Real_Base_gfx1250<0x2bf>; +defm V_CVT_SCALEF32_SR_PK8_FP8_BF16 : VOP3Only_Real_Base_gfx1250<0x2c0>; +defm V_CVT_SCALEF32_SR_PK8_BF8_F16 : VOP3Only_Real_Base_gfx1250<0x2c1>; +defm V_CVT_SCALEF32_SR_PK8_BF8_BF16 : VOP3Only_Real_Base_gfx1250<0x2c2>; defm V_CVT_PK_BF16_F32 : VOP3Only_Realtriple_gfx1250<0x36d>; defm V_CVT_SR_PK_BF16_F32 : VOP3Only_Realtriple_gfx1250<0x36e>; defm V_CVT_PK_F16_F32 : VOP3Only_Realtriple_gfx1250<0x36f>; diff --git a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td index 95fcd4a..ce280d4 100644 --- a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td +++ b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td @@ -557,11 +557,11 @@ multiclass VOP3PDOTIUInst <string OpName, SDPatternOperator intrinsic_node> { null_frag, 1>; // Dot-iu instructions consider input as signed if imod neg bits are set. Thus // Dot-iu Intrinsics have extra operands and require separate codegen pattern. 
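Note: the comment above concerns the dot-iu instructions (e.g. v_dot4_i32_iu8), whose intrinsics carry explicit per-operand sign flags; with this patch those flags are turned into NEG modifiers by the VOP3PModsNeg output transform instead of a ComplexPattern. As a rough reference model only (this is not the backend's code, and clamp handling is omitted), the operation being selected looks approximately like:

#include <cstdint>

// Mixed-signedness dot product of four packed 8-bit elements plus an
// accumulator; each operand is sign- or zero-extended per its flag.
int32_t dot4_iu8_reference(uint32_t A, uint32_t B, int32_t C,
                           bool ASigned, bool BSigned) {
  int64_t Acc = C;
  for (int I = 0; I < 4; ++I) {
    uint8_t AByte = uint8_t(A >> (8 * I));
    uint8_t BByte = uint8_t(B >> (8 * I));
    int32_t AElt = ASigned ? int32_t(int8_t(AByte)) : int32_t(AByte);
    int32_t BElt = BSigned ? int32_t(int8_t(BByte)) : int32_t(BByte);
    Acc += int64_t(AElt) * BElt;
  }
  return static_cast<int32_t>(Acc); // wrapping result, clamp not modeled
}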
- def : GCNPat < (intrinsic_node (VOP3PModsNeg i32:$src0_mods), i32:$src0, - (VOP3PModsNeg i32:$src1_mods), i32:$src1, + def : GCNPat < (intrinsic_node timm:$src0_mods, i32:$src0, + timm:$src1_mods, i32:$src1, i32:$src2, (i1 timm:$clamp)), - (!cast<Instruction>(NAME) $src0_mods, i32:$src0, - $src1_mods, i32:$src1, + (!cast<Instruction>(NAME) (VOP3PModsNeg $src0_mods), i32:$src0, + (VOP3PModsNeg $src1_mods), i32:$src1, (i32 8), i32:$src2, i1:$clamp) >; } @@ -1302,11 +1302,11 @@ class WMMAOpSelPat<Instruction Inst, SDPatternOperator node, VOPProfile P> : class WMMAUIClampPat<Instruction Inst, SDPatternOperator node, VOPProfile P> : GCNPat < (P.DstVT (node - (VOP3PModsNeg i32:$src0_modifiers), (P.Src0VT P.Src0VT:$src0), - (VOP3PModsNeg i32:$src1_modifiers), (P.Src1VT P.Src1VT:$src1), + timm:$src0_modifiers, (P.Src0VT P.Src0VT:$src0), + timm:$src1_modifiers, (P.Src1VT P.Src1VT:$src1), (P.Src2VT P.Src2VT:$src2), (i1 timm:$clamp) )), - (P.DstVT (Inst i32:$src0_modifiers, P.Src0VT:$src0, i32:$src1_modifiers, P.Src1VT:$src1, (i32 8), P.Src2VT:$src2, i1:$clamp)) + (P.DstVT (Inst (VOP3PModsNeg $src0_modifiers), P.Src0VT:$src0, (VOP3PModsNeg $src1_modifiers), P.Src1VT:$src1, (i32 8), P.Src2VT:$src2, i1:$clamp)) >; class WMMAOpcodeMapping<Instruction TwoAddr, Instruction ThreeAddr> { @@ -1407,9 +1407,9 @@ let WaveSizePredicate = isWave64 in { } class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType, - bit _IsIU, bit _IsFP8BF8XF32, bit _Has_ImodOp = 0, - bit _HasMatrixFMT = 0, bit _HasMatrixReuse = 0, - bit _IsF4 = 0> + bit _IsIU, bit _IsFP8BF8XF32, bit _Has_ImodOp = 0, + bit _HasMatrixFMT = 0, bit _HasMatrixScale = 0, + bit _Scale16 = 0, bit _HasMatrixReuse = 0, bit _IsF4 = 0> : VOP3P_Profile<VOPProfile<ArgTy>> { bit IsIU = _IsIU; bit NoABMods = !or(_IsFP8BF8XF32, _IsF4); // No IMOD support for A and B @@ -1417,6 +1417,8 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType, int IndexType = _IndexType; let HasMatrixFMT = _HasMatrixFMT; + let HasMatrixScale = _HasMatrixScale; + bit Scale16 = _Scale16; let HasMatrixReuse = _HasMatrixReuse; bit HasIModOp = _Has_ImodOp; @@ -1455,6 +1457,7 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType, IsC_F16: "_f16", IsC_BF16: "_bf16", 1: "_b32"))); + ValueType ScaleTy = !if(Scale16, i64, i32); // For f16 and bf16 matrices A and B, each element can be modified by // fneg(neg_lo,neg_hi = 1). 
For f32 and f64, neg_lo[0:1] is allowed, but @@ -1516,6 +1519,13 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType, !eq(IndexType, 32): (ins IndexKey32bit:$index_key_32bit)); dag MatrixFMT = !if(HasMatrixFMT, (ins MatrixAFMT:$matrix_a_fmt, MatrixBFMT:$matrix_b_fmt), (ins)); + dag MatrixScaleSrc = !if(HasMatrixScale, + !if(Scale16, (ins VCSrc_b64:$scale_src0, VCSrc_b64:$scale_src1), + (ins VCSrc_b32:$scale_src0, VCSrc_b32:$scale_src1)), + (ins)); + dag MatrixScale = !if(HasMatrixScale, (ins MatrixAScale:$matrix_a_scale, MatrixBScale:$matrix_b_scale, + MatrixAScaleFmt:$matrix_a_scale_fmt, MatrixBScaleFmt:$matrix_b_scale_fmt), + (ins)); dag MatrixReuse = !if(HasMatrixReuse, (ins MatrixAReuse:$matrix_a_reuse, MatrixBReuse:$matrix_b_reuse), (ins)); dag Clamp = !if(HasClamp, (ins Clamp0:$clamp), (ins)); dag Neg = !cond(!and(NegLoAny, NegHiAny) : (ins neg_lo0:$neg_lo, neg_hi0:$neg_hi), @@ -1529,7 +1539,7 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType, (ins VRegSrc_64:$src2), (ins VRegSrc_32:$src2)), IndexKey)), - MatrixFMT, MatrixReuse, Clamp, Neg); + MatrixScaleSrc, MatrixFMT, MatrixScale, MatrixReuse, Clamp, Neg); // asm @@ -1538,57 +1548,59 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType, !eq(IndexType, 16) : "$index_key_16bit", !eq(IndexType, 32) : "$index_key_32bit"); string MatrxFMTAsm = !if(HasMatrixFMT, "$matrix_a_fmt$matrix_b_fmt", ""); + string MatrixScaleSrcAsm = !if(HasMatrixScale, ", $scale_src0, $scale_src1", ""); + string MatrixScaleAsm = !if(HasMatrixScale, "$matrix_a_scale$matrix_b_scale$matrix_a_scale_fmt$matrix_b_scale_fmt", ""); string MatrixReuseAsm = !if(HasMatrixReuse, "$matrix_a_reuse$matrix_b_reuse", ""); string ClampAsm = !if(HasClamp, "$clamp", ""); string NegAsm = !cond(!and(NegLoAny, NegHiAny) : "$neg_lo$neg_hi", !and(NegLoAny, !not(NegHiAny)) : "$neg_lo", !and(!not(NegLoAny), !not(NegHiAny)) : ""); - let AsmVOP3P = "$vdst, $src0, $src1, $src2"#IndexKeyAsm#MatrxFMTAsm#MatrixReuseAsm#NegAsm#ClampAsm; + let AsmVOP3P = "$vdst, $src0, $src1, $src2"#IndexKeyAsm#MatrixScaleSrcAsm#MatrxFMTAsm#MatrixScaleAsm#MatrixReuseAsm#NegAsm#ClampAsm; // isel patterns bit IsAB_BF16_IMod0 = !and(IsAB_BF16, !not(HasIModOp)); bit IsAB_F16_IMod0 = !and(IsAB_F16, !not(HasIModOp)); bit IsAB_F32F64_IMod1 = !and(!or(IsAB_F64, IsAB_F32), HasIModOp); bit IsAB_F16BF16_IMod1 = !and(!or(IsAB_F16, IsAB_BF16), HasIModOp); - dag Src0InPat = !cond(IsAB_F32F64_IMod1 : (ins (VOP3PModsNeg i32:$src0_modifiers), Src0VT:$src0), - IsAB_F16BF16_IMod1 : (ins (VOP3PModsNegs i32:$src0_modifiers), Src0VT:$src0), + dag Src0InPat = !cond(IsAB_F32F64_IMod1 : (ins timm:$src0_modifiers, Src0VT:$src0), + IsAB_F16BF16_IMod1 : (ins timm:$src0_modifiers, Src0VT:$src0), IsAB_F16_IMod0 : (ins (Src0VT (WMMAModsF16Neg Src0VT:$src0, i32:$src0_modifiers))), IsAB_BF16_IMod0 : (ins Src0VT:$src0), - IsIU : (ins (VOP3PModsNeg i32:$src0_modifiers), Src0VT:$src0), + IsIU : (ins timm:$src0_modifiers, Src0VT:$src0), HasMatrixFMT : (ins timm:$matrix_a_fmt, Src0VT:$src0), NoABMods : (ins Src0VT:$src0)); - dag Src0OutPat = !cond(IsAB_F32F64_IMod1 : (ins i32:$src0_modifiers, Src0VT:$src0), - IsAB_F16BF16_IMod1 : (ins i32:$src0_modifiers, Src0VT:$src0), + dag Src0OutPat = !cond(IsAB_F32F64_IMod1 : (ins (VOP3PModsNeg $src0_modifiers), Src0VT:$src0), + IsAB_F16BF16_IMod1 : (ins (VOP3PModsNegs $src0_modifiers), Src0VT:$src0), IsAB_F16_IMod0 : (ins i32:$src0_modifiers, Src0VT:$src0), IsAB_BF16_IMod0 : (ins (i32 8), Src0VT:$src0), - IsIU : (ins 
i32:$src0_modifiers, Src0VT:$src0), + IsIU : (ins (VOP3PModsNeg $src0_modifiers), Src0VT:$src0), NoABMods : (ins Src0VT:$src0)); - dag Src1InPat = !cond(IsAB_F32F64_IMod1 : (ins (VOP3PModsNeg i32:$src1_modifiers), Src1VT:$src1), - IsAB_F16BF16_IMod1 : (ins (VOP3PModsNegs i32:$src1_modifiers), Src1VT:$src1), + dag Src1InPat = !cond(IsAB_F32F64_IMod1 : (ins timm:$src1_modifiers, Src1VT:$src1), + IsAB_F16BF16_IMod1 : (ins timm:$src1_modifiers, Src1VT:$src1), IsAB_F16_IMod0 : (ins (Src1VT (WMMAModsF16Neg Src1VT:$src1, i32:$src1_modifiers))), IsAB_BF16_IMod0 : (ins Src1VT:$src1), - IsIU : (ins (VOP3PModsNeg i32:$src1_modifiers), Src1VT:$src1), + IsIU : (ins timm:$src1_modifiers, Src1VT:$src1), HasMatrixFMT : (ins timm:$matrix_b_fmt, Src1VT:$src1), NoABMods : (ins Src1VT:$src1)); - dag Src1OutPat = !cond(IsAB_F32F64_IMod1 : (ins i32:$src1_modifiers, Src1VT:$src1), - IsAB_F16BF16_IMod1 : (ins i32:$src1_modifiers, Src1VT:$src1), + dag Src1OutPat = !cond(IsAB_F32F64_IMod1 : (ins (VOP3PModsNeg $src1_modifiers), Src1VT:$src1), + IsAB_F16BF16_IMod1 : (ins (VOP3PModsNegs $src1_modifiers), Src1VT:$src1), IsAB_F16_IMod0 : (ins i32:$src1_modifiers, Src1VT:$src1), IsAB_BF16_IMod0 : (ins (i32 8), Src1VT:$src1), - IsIU : (ins i32:$src1_modifiers, Src1VT:$src1), + IsIU : (ins (VOP3PModsNeg $src1_modifiers), Src1VT:$src1), NoABMods : (ins Src1VT:$src1)); bit IsC_IMod1 = !and(HasIModOp, IsWMMA, !not(IsIU), !not(IsXF32)); bit IsC_F32_IMod0 = !and(IsC_F32, !not(HasIModOp)); bit IsC_F16_IMod0 = !and(IsC_F16, !not(HasIModOp)); bit IsC_BF16_IMod0 = !and(IsC_BF16, !not(HasIModOp)); bit IsIUXF32 = !or(IsIU, IsXF32); - dag Src2InPatWmma = !cond(IsC_IMod1 : (ins (VOP3PModsNegAbs i32:$src2_modifiers), Src2VT:$src2), + dag Src2InPatWmma = !cond(IsC_IMod1 : (ins timm:$src2_modifiers, Src2VT:$src2), IsC_F32_IMod0 : (ins (Src2VT (WMMAModsF32NegAbs Src2VT:$src2, i32:$src2_modifiers))), IsC_F16_IMod0 : (ins (Src2VT (WMMAModsF16NegAbs Src2VT:$src2, i32:$src2_modifiers))), IsC_BF16_IMod0 : (ins Src2VT:$src2), IsIUXF32 : (ins Src2VT:$src2), IsSWMMAC : (ins)); - dag Src2OutPatWmma = !cond(IsC_IMod1 : (ins i32:$src2_modifiers, Src2VT:$src2), + dag Src2OutPatWmma = !cond(IsC_IMod1 : (ins (VOP3PModsNegAbs $src2_modifiers), Src2VT:$src2), IsC_F32_IMod0 : (ins i32:$src2_modifiers, Src2VT:$src2), IsC_F16_IMod0 : (ins i32:$src2_modifiers, Src2VT:$src2), IsC_BF16_IMod0 : (ins (i32 8), Src2VT:$src2), @@ -1604,22 +1616,29 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType, !eq(IndexType, 16): (ins i32:$src2, i32:$index_key_16bit), !eq(IndexType, 32): (ins i64:$src2, i32:$index_key_32bit)); dag MatrixFMTOutPat = !if(HasMatrixFMT, (ins i32:$matrix_a_fmt, i32:$matrix_b_fmt), (ins)); - dag Src2InlineInPat = !con(!if(IsC_IMod1, (ins (VOP3PModsNegAbs i32:$src2_modifiers)), (ins)), (ins (Src2VT (WMMAVISrc Src2VT:$src2)))); - dag Src2InlineOutPat = !con(!if(IsIUXF32, (ins), !if(IsC_IMod1, (ins i32:$src2_modifiers), (ins (i32 8)))), (ins Src2VT:$src2)); + dag Src2InlineInPat = !con(!if(IsC_IMod1, (ins timm:$src2_modifiers), (ins)), (ins (Src2VT (WMMAVISrc Src2VT:$src2)))); + dag Src2InlineOutPat = !con(!if(IsIUXF32, (ins), !if(IsC_IMod1, (ins (VOP3PModsNegAbs $src2_modifiers)), (ins (i32 8)))), (ins Src2VT:$src2)); + dag MatrixScaleInPat = !if(HasMatrixScale, (ins timm:$matrix_a_scale, timm:$matrix_a_scale_fmt, ScaleTy:$scale_src0, + timm:$matrix_b_scale, timm:$matrix_b_scale_fmt, ScaleTy:$scale_src1), + (ins)); dag MatrixReuseInPat = !if(HasMatrixReuse, (ins timm:$matrix_a_reuse, timm:$matrix_b_reuse), (ins)); + dag 
MatrixScaleOutSrcPat = !if(HasMatrixScale, (ins ScaleTy:$scale_src0, ScaleTy:$scale_src1), (ins)); + dag MatrixScaleOutModPat = !if(HasMatrixScale, (ins i32:$matrix_a_scale, i32:$matrix_b_scale, i32:$matrix_a_scale_fmt, i32:$matrix_b_scale_fmt), (ins)); dag MatrixReuseOutModPat = !if(HasMatrixReuse, (ins i1:$matrix_a_reuse, i1:$matrix_b_reuse), (ins)); - dag WmmaInPat = !con(Src0InPat, Src1InPat, Src2InPatWmma, MatrixReuseInPat, ClampPat); - dag WmmaOutPat = !con(Src0OutPat, Src1OutPat, Src2OutPatWmma, MatrixFMTOutPat, MatrixReuseOutModPat, ClampPat); + dag WmmaInPat = !con(Src0InPat, Src1InPat, Src2InPatWmma, MatrixScaleInPat, MatrixReuseInPat, ClampPat); + dag WmmaOutPat = !con(Src0OutPat, Src1OutPat, Src2OutPatWmma, MatrixScaleOutSrcPat, MatrixFMTOutPat, + MatrixScaleOutModPat, MatrixReuseOutModPat, ClampPat); dag SwmmacInPat = !con(Src0InPat, Src1InPat, (ins Src2VT:$srcTiedDef), IndexInPat, MatrixReuseInPat, ClampPat); dag SwmmacOutPat = !con(Src0OutPat, Src1OutPat, (ins Src2VT:$srcTiedDef), IndexOutPat, MatrixReuseOutModPat, ClampPat); // wmma pattern where src2 is inline imm uses _threeaddr pseudo, // can't use _twoaddr since it would violate src2 tied to vdst constraint. - dag WmmaInlineInPat = !con(Src0InPat, Src1InPat, Src2InlineInPat, MatrixReuseInPat, ClampPat); - dag WmmaInlineOutPat = !con(Src0OutPat, Src1OutPat, Src2InlineOutPat, MatrixFMTOutPat, MatrixReuseOutModPat, ClampPat); + dag WmmaInlineInPat = !con(Src0InPat, Src1InPat, Src2InlineInPat, MatrixScaleInPat, MatrixReuseInPat, ClampPat); + dag WmmaInlineOutPat = !con(Src0OutPat, Src1OutPat, Src2InlineOutPat, MatrixScaleOutSrcPat, + MatrixFMTOutPat, MatrixScaleOutModPat, MatrixReuseOutModPat, ClampPat); } def WMMAInstInfoTable : GenericTable { @@ -1645,11 +1664,15 @@ multiclass WMMAInstGFX12<string Instr, VOP3PWMMA_Profile WMMAProfile, string Pse let Constraints = WMMAConstraints2Addr, isConvertibleToThreeAddress = 1 in def _twoaddr : VOP3P_Pseudo<Instr, WMMAProfile>, WMMAInstInfo { let PseudoInstr = Instr#PseudoInstrSuffix; + let FixedSize = WMMAProfile.HasMatrixScale; + let Size = !if(WMMAProfile.HasMatrixScale, 16, 8); } let Constraints = WMMAConstraints3Addr, SchedRW = [Write32Bit, Write32Bit] in def _threeaddr : VOP3P_Pseudo<Instr, WMMAProfile>, WMMAInstInfo { let PseudoInstr = Instr#PseudoInstrSuffix; + let FixedSize = WMMAProfile.HasMatrixScale; + let Size = !if(WMMAProfile.HasMatrixScale, 16, 8); } } @@ -1728,39 +1751,55 @@ def F32_FP8BF8_SWMMAC_w64 : VOP3PWMMA_Profile<[v4f32, i32, v2i32, v4f32], 1, // *** IU4X32_SWMMAC_w64 lanes 0-31 will have 8xi4 remaining lanes are ignored // for matrix A, index is i16; Matrix B uses all lanes -def F32_F32_WMMA_w32 : VOP3PWMMA_Profile<[v8f32, v2f32, v2f32, v8f32], 0, 0, 0, 0, 1, 0, 1>; -def F32_BF16X32_WMMA_w32 : VOP3PWMMA_Profile<[v8f32, v16bf16, v16bf16, v8f32], 0, 0, 0, 0, 1, 0, 1>; -def F32_F16X32_WMMA_w32 : VOP3PWMMA_Profile<[v8f32, v16f16, v16f16, v8f32], 0, 0, 0, 0, 1, 0, 1>; -def F16_F16X32_WMMA_w32 : VOP3PWMMA_Profile<[v8f16, v16f16, v16f16, v8f16], 0, 0, 0, 0, 1, 0, 1>; -def BF16_BF16X32_WMMA_w32 : VOP3PWMMA_Profile<[v8bf16, v16bf16, v16bf16, v8bf16], 0, 0, 0, 0, 1, 0, 1>; -def BF16F32_BF16_WMMA_w32 : VOP3PWMMA_Profile<[v8bf16, v16bf16, v16bf16, v8f32], 0, 0, 0, 0, 1, 0, 1>; -def F32_FP8BF8X64_WMMA_w32 : VOP3PWMMA_Profile<[v8f32, v8i32, v8i32, v8f32], 0, 0, 0, 1, 1, 0, 1>; -def F32_FP8BF8X128_WMMA_w32 : VOP3PWMMA_Profile<[v8f32, v16i32, v16i32, v8f32], 0, 0, 0, 1, 1, 0, 1>; -def F16_FP8BF8X64_WMMA_w32 : VOP3PWMMA_Profile<[v8f16, v8i32, v8i32, v8f16], 0, 0, 0, 1, 1, 0, 
1>; -def F16_FP8BF8X128_WMMA_w32 : VOP3PWMMA_Profile<[v8f16, v16i32, v16i32, v8f16], 0, 0, 0, 1, 1, 0, 1>; -def F32_32X16X128_F4_WMMA_w32 : VOP3PWMMA_Profile<[v16f32, v16i32, v8i32, v16f32], 0, 0, 0, 0, 1, 0, 0, 1>; -def I32_IU8X64_WMMA_w32 : VOP3PWMMA_Profile<[v8i32, v8i32, v8i32, v8i32], 0, 0, 1, 0, 1, 0, 1>; -def F32_F16X64_SWMMAC_w32 : VOP3PWMMA_Profile<[v8f32, v16f16, v32f16, v8f32], 1, 16, 0, 0, 1, 0, 1>; -def F32_BF16X64_SWMMAC_w32 : VOP3PWMMA_Profile<[v8f32, v16bf16, v32bf16, v8f32], 1, 16, 0, 0, 1, 0, 1>; -def F16_F16X64_SWMMAC_w32 : VOP3PWMMA_Profile<[v8f16, v16f16, v32f16, v8f16], 1, 16, 0, 0, 1, 0, 1>; -def BF16_BF16X64_SWMMAC_w32 : VOP3PWMMA_Profile<[v8bf16, v16bf16, v32bf16, v8bf16], 1, 16, 0, 0, 1, 0, 1>; -def F32_FP8BF8X128_SWMMAC_w32 : VOP3PWMMA_Profile<[v8f32, v8i32, v16i32, v8f32], 1, 32, 0, 1, 1, 0, 1>; -def F16_FP8BF8X128_SWMMAC_w32 : VOP3PWMMA_Profile<[v8f16, v8i32, v16i32, v8f16], 1, 32, 0, 1, 1, 0, 1>; -def I32_IU8X128_SWMMAC_w32 : VOP3PWMMA_Profile<[v8i32, v8i32, v16i32, v8i32], 1, 32, 1, 0, 1, 0, 1>; - -multiclass WMMA_F8F6F4_Profiles<bit HasMatrixReuse> { - def _f8_f8_w32 : VOP3PWMMA_Profile<[v8f32, v16i32, v16i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixReuse>; - def _f8_f6_w32 : VOP3PWMMA_Profile<[v8f32, v16i32, v12i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixReuse>; - def _f8_f4_w32 : VOP3PWMMA_Profile<[v8f32, v16i32, v8i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixReuse>; - def _f6_f8_w32 : VOP3PWMMA_Profile<[v8f32, v12i32, v16i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixReuse>; - def _f6_f6_w32 : VOP3PWMMA_Profile<[v8f32, v12i32, v12i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixReuse>; - def _f6_f4_w32 : VOP3PWMMA_Profile<[v8f32, v12i32, v8i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixReuse>; - def _f4_f8_w32 : VOP3PWMMA_Profile<[v8f32, v8i32, v16i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixReuse>; - def _f4_f6_w32 : VOP3PWMMA_Profile<[v8f32, v8i32, v12i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixReuse>; - def _f4_f4_w32 : VOP3PWMMA_Profile<[v8f32, v8i32, v8i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixReuse>; -} - -defm F32_16X16X128_F8F6F4 : WMMA_F8F6F4_Profiles<0>; +def F32_F32_WMMA_w32 : VOP3PWMMA_Profile<[v8f32, v2f32, v2f32, v8f32], 0, 0, 0, 0, 1, 0, 0, 0, 1>; +def F32_BF16X32_WMMA_w32 : VOP3PWMMA_Profile<[v8f32, v16bf16, v16bf16, v8f32], 0, 0, 0, 0, 1, 0, 0, 0, 1>; +def F32_F16X32_WMMA_w32 : VOP3PWMMA_Profile<[v8f32, v16f16, v16f16, v8f32], 0, 0, 0, 0, 1, 0, 0, 0, 1>; +def F16_F16X32_WMMA_w32 : VOP3PWMMA_Profile<[v8f16, v16f16, v16f16, v8f16], 0, 0, 0, 0, 1, 0, 0, 0, 1>; +def BF16_BF16X32_WMMA_w32 : VOP3PWMMA_Profile<[v8bf16, v16bf16, v16bf16, v8bf16], 0, 0, 0, 0, 1, 0, 0, 0, 1>; +def BF16F32_BF16_WMMA_w32 : VOP3PWMMA_Profile<[v8bf16, v16bf16, v16bf16, v8f32], 0, 0, 0, 0, 1, 0, 0, 0, 1>; +def F32_FP8BF8X64_WMMA_w32 : VOP3PWMMA_Profile<[v8f32, v8i32, v8i32, v8f32], 0, 0, 0, 1, 1, 0, 0, 0, 1>; +def F32_FP8BF8X128_WMMA_w32 : VOP3PWMMA_Profile<[v8f32, v16i32, v16i32, v8f32], 0, 0, 0, 1, 1, 0, 0, 0, 1>; +def F16_FP8BF8X64_WMMA_w32 : VOP3PWMMA_Profile<[v8f16, v8i32, v8i32, v8f16], 0, 0, 0, 1, 1, 0, 0, 0, 1>; +def F16_FP8BF8X128_WMMA_w32 : VOP3PWMMA_Profile<[v8f16, v16i32, v16i32, v8f16], 0, 0, 0, 1, 1, 0, 0, 0, 1>; +def F32_32X16X128_F4_WMMA_w32 : VOP3PWMMA_Profile<[v16f32, v16i32, v8i32, v16f32], 0, 0, 0, 0, 1, 0, 0, 0, 0, 1>; +def I32_IU8X64_WMMA_w32 : VOP3PWMMA_Profile<[v8i32, v8i32, v8i32, v8i32], 0, 0, 1, 0, 1, 0, 0, 0, 1>; +def F32_32X16X128_F4_SCALE_w32 : VOP3PWMMA_Profile<[v16f32, v16i32, v8i32, v16f32], 0, 0, 0, 1, 1, 0, 1, 0, 1>; +def F32_32X16X128_F4_SCALE16_w32 : VOP3PWMMA_Profile<[v16f32, 
v16i32, v8i32, v16f32], 0, 0, 0, 1, 1, 0, 1, 1, 1>; +def F32_F16X64_SWMMAC_w32 : VOP3PWMMA_Profile<[v8f32, v16f16, v32f16, v8f32], 1, 16, 0, 0, 1, 0, 0, 0, 1>; +def F32_BF16X64_SWMMAC_w32 : VOP3PWMMA_Profile<[v8f32, v16bf16, v32bf16, v8f32], 1, 16, 0, 0, 1, 0, 0, 0, 1>; +def F16_F16X64_SWMMAC_w32 : VOP3PWMMA_Profile<[v8f16, v16f16, v32f16, v8f16], 1, 16, 0, 0, 1, 0, 0, 0, 1>; +def BF16_BF16X64_SWMMAC_w32 : VOP3PWMMA_Profile<[v8bf16, v16bf16, v32bf16, v8bf16], 1, 16, 0, 0, 1, 0, 0, 0, 1>; +def F32_FP8BF8X128_SWMMAC_w32 : VOP3PWMMA_Profile<[v8f32, v8i32, v16i32, v8f32], 1, 32, 0, 1, 1, 0, 0, 0, 1>; +def F16_FP8BF8X128_SWMMAC_w32 : VOP3PWMMA_Profile<[v8f16, v8i32, v16i32, v8f16], 1, 32, 0, 1, 1, 0, 0, 0, 1>; +def I32_IU8X128_SWMMAC_w32 : VOP3PWMMA_Profile<[v8i32, v8i32, v16i32, v8i32], 1, 32, 1, 0, 1, 0, 0, 0, 1>; + +multiclass WMMA_F8F6F4_Profiles<bit HasMatrixScale, bit Scale16, bit HasMatrixReuse> { + def _f8_f8_w32 : VOP3PWMMA_Profile<[v8f32, v16i32, v16i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>; + def _f8_f6_w32 : VOP3PWMMA_Profile<[v8f32, v16i32, v12i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>; + def _f8_f4_w32 : VOP3PWMMA_Profile<[v8f32, v16i32, v8i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>; + def _f6_f8_w32 : VOP3PWMMA_Profile<[v8f32, v12i32, v16i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>; + def _f6_f6_w32 : VOP3PWMMA_Profile<[v8f32, v12i32, v12i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>; + def _f6_f4_w32 : VOP3PWMMA_Profile<[v8f32, v12i32, v8i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>; + def _f4_f8_w32 : VOP3PWMMA_Profile<[v8f32, v8i32, v16i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>; + def _f4_f6_w32 : VOP3PWMMA_Profile<[v8f32, v8i32, v12i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>; + def _f4_f4_w32 : VOP3PWMMA_Profile<[v8f32, v8i32, v8i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>; +} + +defm F32_16X16X128_F8F6F4 : WMMA_F8F6F4_Profiles<0, 0, 0>; +defm F32_16X16X128_F8F6F4_SCALE : WMMA_F8F6F4_Profiles<1, 0, 1>; +defm F32_16X16X128_F8F6F4_SCALE16 : WMMA_F8F6F4_Profiles<1, 1, 1>; + +class VOP_WMMA_LD_SCALE<ValueType vt, RegisterOperand RC> : VOP3P_Profile<VOPProfile<[untyped, vt, vt, untyped]>> { + let HasMatrixScale = 1; + let HasMatrixReuse = 1; + let HasNeg = 0; + let Src0RC64 = RC; + let Src1RC64 = RC; + let Ins64 = (ins Src0RC64:$src0, Src1RC64:$src1, MatrixAScale:$matrix_a_scale, MatrixBScale:$matrix_b_scale, + MatrixAScaleFmt:$matrix_a_scale_fmt, MatrixBScaleFmt:$matrix_b_scale_fmt, + MatrixAReuse:$matrix_a_reuse, MatrixBReuse:$matrix_b_reuse); + let AsmVOP3P = " $src0, $src1$matrix_a_scale$matrix_b_scale$matrix_a_scale_fmt$matrix_b_scale_fmt$matrix_a_reuse$matrix_b_reuse"; +} multiclass WMMAInst_SrcFormats_mc<string OpName, string Profile> { foreach I = ["f8_f8", "f8_f6", "f8_f4", "f6_f8", "f6_f6", "f6_f4", "f4_f8", "f4_f6", "f4_f4"] in { @@ -1813,9 +1852,15 @@ defm V_SWMMAC_F32_16X16X64_F16_w32 : SWMMACInstGFX12<"v_swmmac_f32_16x16x64 defm V_SWMMAC_F16_16X16X64_F16_w32 : SWMMACInstGFX12<"v_swmmac_f16_16x16x64_f16", F16_F16X64_SWMMAC_w32, "_w32">; defm V_WMMA_F32_16X16X128_F8F6F4 : WMMAInst_SrcFormats_mc<"v_wmma_f32_16x16x128_f8f6f4", "F32_16X16X128_F8F6F4">; +defm V_WMMA_SCALE_F32_16X16X128_F8F6F4 : WMMAInst_SrcFormats_mc<"v_wmma_scale_f32_16x16x128_f8f6f4", "F32_16X16X128_F8F6F4_SCALE">; +defm V_WMMA_SCALE16_F32_16X16X128_F8F6F4 : 
WMMAInst_SrcFormats_mc<"v_wmma_scale16_f32_16x16x128_f8f6f4", "F32_16X16X128_F8F6F4_SCALE16">; +defm V_WMMA_SCALE_F32_32X16X128_F4_w32 : WMMAInstGFX12<"v_wmma_scale_f32_32x16x128_f4", F32_32X16X128_F4_SCALE_w32, "_w32">; +defm V_WMMA_SCALE16_F32_32X16X128_F4_w32 : WMMAInstGFX12<"v_wmma_scale16_f32_32x16x128_f4", F32_32X16X128_F4_SCALE16_w32, "_w32">; } // End is_wmma_xdl = 1. +defm V_WMMA_LD_SCALE_PAIRED_B32 : VOP3PInst<"v_wmma_ld_scale_paired_b32", VOP_WMMA_LD_SCALE<i32, VCSrc_b32>>; +defm V_WMMA_LD_SCALE16_PAIRED_B64 : VOP3PInst<"v_wmma_ld_scale16_paired_b64", VOP_WMMA_LD_SCALE<i64, VCSrc_b64>>; } // End SubtargetPredicate = isGFX125xOnly } // End WaveSizePredicate = isWave32 @@ -1970,9 +2015,13 @@ let SubtargetPredicate = isGFX125xOnly in { defm : WMMAPat<"V_WMMA_F32_16X16X128_BF8_FP8_w32", int_amdgcn_wmma_f32_16x16x128_bf8_fp8, F32_FP8BF8X128_WMMA_w32>; defm : WMMAPat<"V_WMMA_F32_16X16X128_BF8_BF8_w32", int_amdgcn_wmma_f32_16x16x128_bf8_bf8, F32_FP8BF8X128_WMMA_w32>; defm : WMMAPat<"V_WMMA_F32_32X16X128_F4_w32", int_amdgcn_wmma_f32_32x16x128_f4, F32_32X16X128_F4_WMMA_w32>; + defm : WMMAPat<"V_WMMA_SCALE_F32_32X16X128_F4_w32", int_amdgcn_wmma_scale_f32_32x16x128_f4, F32_32X16X128_F4_SCALE_w32>; + defm : WMMAPat<"V_WMMA_SCALE16_F32_32X16X128_F4_w32", int_amdgcn_wmma_scale16_f32_32x16x128_f4, F32_32X16X128_F4_SCALE16_w32>; foreach I = ["f8_f8", "f8_f6", "f8_f4", "f6_f8", "f6_f6", "f6_f4", "f4_f8", "f4_f6", "f4_f4"] in { defm : WMMAPat<"V_WMMA_F32_16X16X128_F8F6F4_" # I # "_w32", int_amdgcn_wmma_f32_16x16x128_f8f6f4, !cast<VOP3PWMMA_Profile>("F32_16X16X128_F8F6F4_" # I # "_w32")>; + defm : WMMAPat<"V_WMMA_SCALE_F32_16X16X128_F8F6F4_" # I # "_w32", int_amdgcn_wmma_scale_f32_16x16x128_f8f6f4, !cast<VOP3PWMMA_Profile>("F32_16X16X128_F8F6F4_SCALE_" # I # "_w32")>; + defm : WMMAPat<"V_WMMA_SCALE16_F32_16X16X128_F8F6F4_" # I # "_w32", int_amdgcn_wmma_scale16_f32_16x16x128_f8f6f4, !cast<VOP3PWMMA_Profile>("F32_16X16X128_F8F6F4_SCALE16_" # I # "_w32")>; } def : SWMMACPat<V_SWMMAC_F32_16X16X64_BF16_w32_twoaddr, int_amdgcn_swmmac_f32_16x16x64_bf16, F32_BF16X64_SWMMAC_w32>; @@ -2105,6 +2154,82 @@ multiclass VOP3P_Real_WMMA_gfx1250_SrcFormats<bits<8> op, string WMMAP> { } } +class VOP3PX2e <bits<8> op, bits<8> LdScaleOp, VOP3PWMMA_Profile P> : Enc128, VOP3Pe_Base { + bits<9> scale_src0; + bits<9> scale_src1; + + // Inst{7-0} = unused + let Inst{10-8} = {0, matrix_b_scale_fmt{1-0}}; // neg_hi + let Inst{11} = matrix_a_scale{0}; // scale_op_sel(0) + let Inst{12} = 0; // scale_op_sel(1) + let Inst{13} = matrix_a_reuse; // scale_op_sel(2) + let Inst{14} = matrix_b_reuse; // scale_op_sel_hi(2) + let Inst{15} = 0; // scale_clamp + let Inst{31-24} = 0xcc; // Encoding + let Inst{23-16} = LdScaleOp; + let Inst{40-32} = scale_src0; + let Inst{49-41} = scale_src1; + let Inst{58-50} = 0; // scale src2 + let Inst{59} = matrix_b_scale{0}; // scale_op_sel_hi(0) + let Inst{60} = 0; // scale_op_sel_hi(1) + let Inst{63-61} = {0, matrix_a_scale_fmt{1-0}}; // neg (lo) + + // The high half of the encoding is the unscaled wmma op. 
+ let Inst{71-64} = vdst; + + let Inst{72} = !if(P.NegHi01, src0_modifiers{1}, 0); // neg_hi src0 + let Inst{73} = !if(P.NegHi01, src1_modifiers{1}, 0); // neg_hi src1 + let Inst{74} = !if(P.NegHi2, src2_modifiers{1}, 0); // neg_hi src2 + + let Inst{77-75} = !if(P.HasMatrixFMT, matrix_a_fmt{2-0}, 0); // op_sel + + let Inst{78,124,123} = !if(P.HasMatrixFMT, matrix_b_fmt{2-0}, 7); // op_sel_hi + let Inst{79} = !if(P.HasClamp, clamp{0}, 0); + + let Inst{87-80} = op; + let Inst{95-88} = 0xcc; //encoding + let Inst{104-96} = !if(P.HasSrc0, src0, 0); + let Inst{113-105} = !if(P.HasSrc1, src1, 0); + let Inst{122-114} = !if(P.HasSrc2, src2, 0); + + // neg_lo + let Inst{125} = !if(P.NegLo01, src0_modifiers{0}, 0); + let Inst{126} = !if(P.NegLo01, src1_modifiers{0}, 0); + let Inst{127} = !if(P.NegLo2, src2_modifiers{0}, 0); +} + +multiclass VOP3PX2_Real_ScaledWMMA_F4<bits<8> op, bits<8> LdScaleOp, VOP3PWMMA_Profile WMMAP> { + defvar PS = !cast<VOP3P_Pseudo>(NAME # "_twoaddr"); + let SubtargetPredicate = isGFX1250Plus, WaveSizePredicate = isWave32, + DecoderNamespace = "GFX1250" in { + def _gfx1250 : VOP3P_Real_Gen<PS, GFX1250Gen, PS.Mnemonic>, + VOP3PX2e <op, LdScaleOp, WMMAP>; + } +} + +multiclass VOP3PX2_Real_ScaledWMMA<bits<8> op, bits<8> LdScaleOp, VOP3PWMMA_Profile WMMAP> { + defvar PS = !cast<VOP3P_Pseudo>(NAME # "_twoaddr"); + defvar asmName = !substr(PS.Mnemonic, 0, !sub(!size(PS.Mnemonic), !size("_f8_f8_w32"))); + defvar psName = !substr(NAME, 0, !sub(!size(PS.Mnemonic), !size("_f8_f8_w32"))); + let SubtargetPredicate = isGFX1250Plus, WaveSizePredicate = isWave32, + DecoderNamespace = "GFX1250" in { + def _gfx1250 : VOP3P_Real_Gen<PS, GFX1250Gen, asmName>, + VOP3PX2e <op, LdScaleOp, WMMAP>, + MFMA_F8F6F4_WithSizeTable_Helper<PS, psName # "_f8_f8_w32_gfx1250"> { + let AsmString = asmName # PS.AsmOperands; + } + } +} + +multiclass VOP3PX2_Real_ScaledWMMA_SrcFormats<bits<8> op, bits<8> LdScaleOp, string WMMAP> { + defm _f8_f8_w32 : VOP3PX2_Real_ScaledWMMA<op, LdScaleOp, !cast<VOP3PWMMA_Profile>(WMMAP # "_f8_f8_w32")>; + foreach I = ["f8_f6", "f8_f4", "f6_f8", "f6_f6", "f6_f4", "f4_f8", "f4_f6", "f4_f4"] in { + let isAsmParserOnly = true in { // Disable ambiguous disassembly. 
+ defm _#I#_w32 : VOP3PX2_Real_ScaledWMMA<op, LdScaleOp, !cast<VOP3PWMMA_Profile>(WMMAP # "_" # I # "_w32")>; + } + } +} + defm V_WMMA_F32_16X16X16_F16_w32 : VOP3P_Real_WMMA_gfx12 <0x040, F32_F16_WMMA_w32>; defm V_WMMA_F32_16X16X16_BF16_w32 : VOP3P_Real_WMMA_gfx12 <0x041, F32_BF16_WMMA_w32>; defm V_WMMA_F16_16X16X16_F16_w32 : VOP3P_Real_WMMA_gfx12 <0x042, F16_F16_WMMA_w32>; @@ -2180,6 +2305,11 @@ defm V_WMMA_F16_16X16X128_BF8_BF8_w32 : VOP3P_Real_WMMA_gfx1250 <0x087, F16_FP8B defm V_WMMA_F32_32X16X128_F4_w32 : VOP3P_Real_WMMA_gfx1250 <0x088, F32_32X16X128_F4_WMMA_w32>; defm V_WMMA_F32_16X16X128_F8F6F4 : VOP3P_Real_WMMA_gfx1250_SrcFormats<0x033, "F32_16X16X128_F8F6F4">; +defm V_WMMA_SCALE_F32_16X16X128_F8F6F4 : VOP3PX2_Real_ScaledWMMA_SrcFormats<0x033, 0x35, "F32_16X16X128_F8F6F4_SCALE">; +defm V_WMMA_SCALE16_F32_16X16X128_F8F6F4 : VOP3PX2_Real_ScaledWMMA_SrcFormats<0x033, 0x3a, "F32_16X16X128_F8F6F4_SCALE16">; + +defm V_WMMA_SCALE_F32_32X16X128_F4_w32 : VOP3PX2_Real_ScaledWMMA_F4<0x088, 0x35, F32_32X16X128_F4_SCALE_w32>; +defm V_WMMA_SCALE16_F32_32X16X128_F4_w32 : VOP3PX2_Real_ScaledWMMA_F4<0x088, 0x3a, F32_32X16X128_F4_SCALE16_w32>; defm V_SWMMAC_F32_16X16X64_F16_w32 : VOP3P_Real_WMMA_gfx1250 <0x065, F32_F16X64_SWMMAC_w32>; defm V_SWMMAC_F32_16X16X64_BF16_w32 : VOP3P_Real_WMMA_gfx1250 <0x066, F32_BF16X64_SWMMAC_w32>; @@ -2283,6 +2413,9 @@ defm V_FMA_MIX_F32_BF16 : VOP3P_Realtriple<GFX1250Gen, 0x3d>; defm V_FMA_MIXLO_BF16 : VOP3P_Realtriple<GFX1250Gen, 0x3e>; defm V_FMA_MIXHI_BF16 : VOP3P_Realtriple<GFX1250Gen, 0x3f>; +defm V_WMMA_LD_SCALE_PAIRED_B32 : VOP3P_Real_gfx1250<0x35>; +defm V_WMMA_LD_SCALE16_PAIRED_B64 : VOP3P_Real_gfx1250<0x3a>; + let AssemblerPredicate = isGFX1250Plus in def : AMDGPUMnemonicAlias<"v_fma_mix_f32_f16", "v_fma_mix_f32">; diff --git a/llvm/lib/Target/AMDGPU/VOPInstructions.td b/llvm/lib/Target/AMDGPU/VOPInstructions.td index f027ab0..3cad5a1 100644 --- a/llvm/lib/Target/AMDGPU/VOPInstructions.td +++ b/llvm/lib/Target/AMDGPU/VOPInstructions.td @@ -475,17 +475,24 @@ class VOP3Pe_Base { bits<1> index_key_32bit; bits<3> matrix_a_fmt; bits<3> matrix_b_fmt; + bits<1> matrix_a_scale; + bits<1> matrix_b_scale; + bits<2> matrix_a_scale_fmt; + bits<2> matrix_b_scale_fmt; bits<1> matrix_a_reuse; bits<1> matrix_b_reuse; } class VOP3Pe <VOPProfile P> : Enc64, VOP3Pe_Base { let Inst{7-0} = !if(P.HasDst, vdst, 0); - let Inst{8} = !if(P.HasSrc0Mods, src0_modifiers{1}, 0); // neg_hi src0 - let Inst{9} = !if(P.HasSrc1Mods, src1_modifiers{1}, 0); // neg_hi src1 + let Inst{8} = !if(P.HasSrc0Mods, src0_modifiers{1}, + !if(P.HasMatrixScale, matrix_b_scale_fmt{0}, 0)); // neg_hi src0 + let Inst{9} = !if(P.HasSrc1Mods, src1_modifiers{1}, + !if(P.HasMatrixScale, matrix_b_scale_fmt{1}, 0)); // neg_hi src1 let Inst{10} = !if(P.HasSrc2Mods, src2_modifiers{1}, 0); // neg_hi src2 - let Inst{11} = !if(!and(P.HasSrc0, P.HasOpSel), src0_modifiers{2}, 0); // op_sel(0) + let Inst{11} = !if(!and(P.HasSrc0, P.HasOpSel), src0_modifiers{2}, + !if(P.HasMatrixScale, matrix_a_scale{0}, 0)); // op_sel(0) let Inst{12} = !if(!and(P.HasSrc1, P.HasOpSel), src1_modifiers{2}, 0); // op_sel(1) let Inst{13} = !if(!and(P.HasSrc2, P.HasOpSel), src2_modifiers{2}, !if(P.HasMatrixReuse, matrix_a_reuse, 0)); // op_sel(2) @@ -500,10 +507,17 @@ class VOP3Pe <VOPProfile P> : Enc64, VOP3Pe_Base { let Inst{40-32} = !if(P.HasSrc0, src0, 0); let Inst{49-41} = !if(P.HasSrc1, src1, 0); let Inst{58-50} = !if(P.HasSrc2, src2, 0); - let Inst{59} = !if(!and(P.HasSrc0, P.HasOpSel), src0_modifiers{3}, !if(P.IsDOT, 1, ?)); // 
op_sel_hi(0) - let Inst{60} = !if(!and(P.HasSrc1, P.HasOpSel), src1_modifiers{3}, !if(P.IsDOT, 1, ?)); // op_sel_hi(1) - let Inst{61} = !if(P.HasSrc0Mods, src0_modifiers{0}, 0); // neg (lo) - let Inst{62} = !if(P.HasSrc1Mods, src1_modifiers{0}, 0); // neg (lo) + let Inst{59} = !cond(!and(P.HasSrc0, P.HasOpSel) : src0_modifiers{3}, + P.IsDOT : 1, + P.HasMatrixScale : matrix_b_scale{0}, + 1: ?); // op_sel_hi(0) + let Inst{60} = !if(!and(P.HasSrc1, P.HasOpSel), src1_modifiers{3}, + !if(P.HasMatrixScale, 0, + !if(P.IsDOT, 1, ?))); // op_sel_hi(1) + let Inst{61} = !if(P.HasSrc0Mods, src0_modifiers{0}, + !if(P.HasMatrixScale, matrix_a_scale_fmt{0}, 0)); // neg (lo) + let Inst{62} = !if(P.HasSrc1Mods, src1_modifiers{0}, + !if(P.HasMatrixScale, matrix_a_scale_fmt{1}, 0)); // neg (lo) let Inst{63} = !if(P.HasSrc2Mods, src2_modifiers{0}, 0); // neg (lo) } |
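Summary note: the VOP3Pe changes above overlay the new matrix-scale controls on modifier bit positions when a profile sets HasMatrixScale, and the scaled WMMA forms become fixed 16-byte encodings (a ld_scale-style low word fused with the ordinary wmma word via VOP3PX2e, matching the FixedSize/Size = 16 settings earlier in the patch). A rough C++ illustration of the field layout; all helper names are invented for this sketch, only the bit positions are taken from the TableGen above:

#include <cstdint>

// Place Val into Word at bits [Lsb, Lsb+Width).
static void setField(uint64_t &Word, unsigned Lsb, unsigned Width, uint64_t Val) {
  uint64_t Mask = ((1ull << Width) - 1) << Lsb;
  Word = (Word & ~Mask) | ((Val << Lsb) & Mask);
}

// 64-bit VOP3Pe word: with HasMatrixScale and no source modifiers, the
// scale controls reuse the modifier bit positions.
uint64_t overlayMatrixScale(uint64_t Inst, unsigned MatrixAScale,
                            unsigned MatrixBScale, unsigned MatrixAScaleFmt,
                            unsigned MatrixBScaleFmt) {
  setField(Inst, 8, 1, MatrixBScaleFmt & 1);         // neg_hi src0 slot
  setField(Inst, 9, 1, (MatrixBScaleFmt >> 1) & 1);  // neg_hi src1 slot
  setField(Inst, 11, 1, MatrixAScale & 1);           // op_sel(0) slot
  setField(Inst, 59, 1, MatrixBScale & 1);           // op_sel_hi(0) slot
  setField(Inst, 61, 1, MatrixAScaleFmt & 1);        // neg (lo) src0 slot
  setField(Inst, 62, 1, (MatrixAScaleFmt >> 1) & 1); // neg (lo) src1 slot
  return Inst;
}

// 128-bit VOP3PX2e encoding: the low dwords carry the ld_scale opcode and
// the two scale sources, the high dwords are the unscaled wmma word.
struct VOP3PX2Words {
  uint64_t Lo = 0; // Inst{63-0}
  uint64_t Hi = 0; // Inst{127-64}
};

VOP3PX2Words encodeScaledWMMA(unsigned WmmaOp, unsigned LdScaleOp,
                              unsigned Vdst, unsigned Src0, unsigned Src1,
                              unsigned Src2, unsigned ScaleSrc0,
                              unsigned ScaleSrc1) {
  VOP3PX2Words E;
  setField(E.Lo, 16, 8, LdScaleOp); // Inst{23-16}
  setField(E.Lo, 24, 8, 0xcc);      // Inst{31-24}, VOP3P encoding byte
  setField(E.Lo, 32, 9, ScaleSrc0); // Inst{40-32}
  setField(E.Lo, 41, 9, ScaleSrc1); // Inst{49-41}
  setField(E.Hi, 0, 8, Vdst);       // Inst{71-64}
  setField(E.Hi, 16, 8, WmmaOp);    // Inst{87-80}
  setField(E.Hi, 24, 8, 0xcc);      // Inst{95-88}
  setField(E.Hi, 32, 9, Src0);      // Inst{104-96}
  setField(E.Hi, 41, 9, Src1);      // Inst{113-105}
  setField(E.Hi, 50, 9, Src2);      // Inst{122-114}
  return E;
}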