about | summary | refs | log | tree | commit | diff
path: root/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp')
-rw-r--r-- llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp | 41
1 files changed, 41 insertions, 0 deletions
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
index e2c2e89..f2207ff 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
@@ -1694,6 +1694,47 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
NewII->takeName(&II);
return IC.replaceInstUsesWith(II, NewII);
}
+ case Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4: {
+   // Operands 0 and 2 are constant format selectors; operands 1 and 3 are the
+   // fixed-width source vectors those selectors describe.
+   Value *Src0 = II.getArgOperand(1);
+   Value *Src1 = II.getArgOperand(3);
+   // Both selectors feed the same helper below, so keep their types identical.
+   unsigned FmtA = cast<ConstantInt>(II.getArgOperand(0))->getZExtValue();
+   unsigned FmtB = cast<ConstantInt>(II.getArgOperand(2))->getZExtValue();
+   auto *Src0Ty = cast<FixedVectorType>(Src0->getType());
+   auto *Src1Ty = cast<FixedVectorType>(Src1->getType());
+
+   bool MadeChange = false;
+   // Element count each source vector is compared against for its format.
+   unsigned Src0NumElts = AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(FmtA);
+   unsigned Src1NumElts = AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(FmtB);
+
+   // Depending on the used format, fewer registers are required so shrink the
+   // vector type. Only the leading elements (starting at index 0) are kept.
+   if (Src0Ty->getNumElements() > Src0NumElts) {
+     Src0 = IC.Builder.CreateExtractVector(
+         FixedVectorType::get(Src0Ty->getElementType(), Src0NumElts), Src0,
+         IC.Builder.getInt64(0));
+     MadeChange = true;
+   }
+
+   if (Src1Ty->getNumElements() > Src1NumElts) {
+     Src1 = IC.Builder.CreateExtractVector(
+         FixedVectorType::get(Src1Ty->getElementType(), Src1NumElts), Src1,
+         IC.Builder.getInt64(0));
+     MadeChange = true;
+   }
+
+   // Neither operand was wider than its format needs; nothing to rewrite.
+   if (!MadeChange)
+     return std::nullopt;
+
+   // Re-issue the call with the original arguments, substituting only the
+   // (possibly shrunk) source vectors.
+   SmallVector<Value *, 13> Args(II.args());
+   Args[1] = Src0;
+   Args[3] = Src1;
+
+   // Overload types: the type of operand 5, then the two source vector types.
+   CallInst *NewII = IC.Builder.CreateIntrinsic(
+       IID, {II.getArgOperand(5)->getType(), Src0->getType(), Src1->getType()},
+       Args, &II);
+   NewII->takeName(&II);
+   return IC.replaceInstUsesWith(II, NewII);
+ }
}
if (const AMDGPU::ImageDimIntrinsicInfo *ImageDimIntr =
AMDGPU::getImageDimIntrinsicInfo(II.getIntrinsicID())) {