diff options
Diffstat (limited to 'llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp')
-rw-r--r-- | llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp | 41 |
1 files changed, 41 insertions, 0 deletions
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp index e2c2e89..f2207ff 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp @@ -1694,6 +1694,47 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const { NewII->takeName(&II); return IC.replaceInstUsesWith(II, NewII); } + case Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4: { + Value *Src0 = II.getArgOperand(1); + Value *Src1 = II.getArgOperand(3); + unsigned FmtA = cast<ConstantInt>(II.getArgOperand(0))->getZExtValue(); + uint64_t FmtB = cast<ConstantInt>(II.getArgOperand(2))->getZExtValue(); + auto *Src0Ty = cast<FixedVectorType>(Src0->getType()); + auto *Src1Ty = cast<FixedVectorType>(Src1->getType()); + + bool MadeChange = false; + unsigned Src0NumElts = AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(FmtA); + unsigned Src1NumElts = AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(FmtB); + + // Depending on the used format, fewer registers are required so shrink the + // vector type. + if (Src0Ty->getNumElements() > Src0NumElts) { + Src0 = IC.Builder.CreateExtractVector( + FixedVectorType::get(Src0Ty->getElementType(), Src0NumElts), Src0, + IC.Builder.getInt64(0)); + MadeChange = true; + } + + if (Src1Ty->getNumElements() > Src1NumElts) { + Src1 = IC.Builder.CreateExtractVector( + FixedVectorType::get(Src1Ty->getElementType(), Src1NumElts), Src1, + IC.Builder.getInt64(0)); + MadeChange = true; + } + + if (!MadeChange) + return std::nullopt; + + SmallVector<Value *, 13> Args(II.args()); + Args[1] = Src0; + Args[3] = Src1; + + CallInst *NewII = IC.Builder.CreateIntrinsic( + IID, {II.getArgOperand(5)->getType(), Src0->getType(), Src1->getType()}, + Args, &II); + NewII->takeName(&II); + return IC.replaceInstUsesWith(II, NewII); + } } if (const AMDGPU::ImageDimIntrinsicInfo *ImageDimIntr = AMDGPU::getImageDimIntrinsicInfo(II.getIntrinsicID())) { |