aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp')
-rw-r--r-- mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp 175
1 files changed, 146 insertions, 29 deletions
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 478b6aa..1eca43d 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -989,21 +989,17 @@ mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
smfma.getN(), smfma.getK(), 1u, chipset);
}
-/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
-/// if one exists. This includes checking to ensure the intrinsic is supported
-/// on the architecture you are compiling for.
-static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
- Chipset chipset) {
- auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
- auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
- auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
- Type elemSourceType = sourceVectorType.getElementType();
- Type elemBSourceType = sourceBVectorType.getElementType();
- Type elemDestType = destVectorType.getElementType();
-
- const uint32_t k = wmma.getK();
-
+/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
+/// for RDNA3/4 architectures.
+static std::optional<StringRef>
+wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType,
+ Type elemDestType, uint32_t k, bool isRDNA3) {
+ using fp8 = Float8E4M3FNType;
+ using bf8 = Float8E5M2Type;
+
+ // Handle k == 16 for RDNA3/4.
if (k == 16) {
+ // Common patterns for RDNA3 and RDNA4.
if (elemSourceType.isF16() && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
if (elemSourceType.isBF16() && elemDestType.isF32())
@@ -1014,39 +1010,160 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
- if (chipset.majorVersion == 11) {
+
+ // RDNA3 specific patterns.
+ if (isRDNA3) {
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+ return std::nullopt;
}
- }
- if (chipset.majorVersion < 12)
- return std::nullopt;
- // gfx12+
- if (k == 16) {
- if (isa<Float8E4M3FNType>(elemSourceType) &&
- isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
+ // RDNA4 specific patterns (fp8/bf8).
+ if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
+ elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
- if (isa<Float8E4M3FNType>(elemSourceType) &&
- isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
+ if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
+ elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
- if (isa<Float8E5M2Type>(elemSourceType) &&
- isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
+ if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
+ elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
- if (isa<Float8E5M2Type>(elemSourceType) &&
- isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
+ if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
+ elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
return std::nullopt;
}
- if (k == 32) {
+
+ // Handle k == 32 for RDNA4.
+ if (k == 32 && !isRDNA3) {
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+ }
+
+ llvm_unreachable("Unsupported k value");
+}
+
+/// Returns the `rocdl` WMMA intrinsic for gfx1250 that matches the A/B/C
+/// element types and `k`, or std::nullopt when no such intrinsic exists.
+static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType,
+                                                         Type elemBSourceType,
+                                                         Type elemDestType,
+                                                         uint32_t k) {
+  using fp8 = Float8E4M3FNType;
+  using bf8 = Float8E5M2Type;
+
+  if (k == 4) {
+    if (elemSourceType.isF32() && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x4_f32::getOperationName();
+
     return std::nullopt;
   }
+  if (k == 32) {
+    if (elemSourceType.isF16() && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
+    if (elemSourceType.isBF16() && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
+    if (elemSourceType.isF16() && elemDestType.isF16())
+      return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
+    if (elemSourceType.isBF16() && elemDestType.isBF16())
+      return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();
+
+    return std::nullopt;
+  }
+
+  if (k == 64) {
+    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
+    }
+    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
+    }
+    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
+    }
+    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
+    }
+    if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
+      return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();
+
+    return std::nullopt;
+  }
+
+  if (k == 128) {
+    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
+    }
+    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
+    }
+    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
+    }
+    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
+    }
+
+    return std::nullopt;
+  }
+  // Any other k has no gfx1250 intrinsic here; report "none" rather than UB.
+  return std::nullopt;
+}
+
+/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
+/// if one exists, and std::nullopt otherwise. This includes checking that the
+/// intrinsic is supported on the architecture you are compiling for.
+static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
+                                                  Chipset chipset) {
+  auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
+  auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
+  auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
+  Type elemSourceType = sourceVectorType.getElementType();
+  Type elemBSourceType = sourceBVectorType.getElementType();
+  Type elemDestType = destVectorType.getElementType();
+
+  const uint32_t k = wmma.getK();
+  const bool isRDNA3 = chipset.majorVersion == 11;
+  const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;
+
+  // Handle RDNA3 and RDNA4.
+  if (isRDNA3 || isRDNA4)
+    return wmmaOpToIntrinsicRDNA(elemSourceType, elemBSourceType, elemDestType,
+                                 k, isRDNA3);
+
+  // Handle gfx1250.
+  if (chipset == Chipset{12, 5, 0})
+    return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType,
+                                    elemDestType, k);
+  // No WMMA mapping is defined here for any other chipset: report "none".
-  llvm_unreachable("unhandled WMMA case");
+  return std::nullopt;
 }