diff options
Diffstat (limited to 'mlir/lib/Conversion')
| -rw-r--r-- | mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 70 |
1 files changed, 40 insertions, 30 deletions
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 9b15435..478b6aa 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" @@ -993,28 +994,36 @@ mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) { /// on the architecture you are compiling for. static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma, Chipset chipset) { - auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType()); - auto sourceBVectorType = dyn_cast<VectorType>(wmma.getSourceB().getType()); - auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType()); - auto elemSourceType = sourceVectorType.getElementType(); - auto elemBSourceType = sourceBVectorType.getElementType(); - auto elemDestType = destVectorType.getElementType(); - - if (elemSourceType.isF16() && elemDestType.isF32()) - return ROCDL::wmma_f32_16x16x16_f16::getOperationName(); - if (elemSourceType.isBF16() && elemDestType.isF32()) - return ROCDL::wmma_f32_16x16x16_bf16::getOperationName(); - if (elemSourceType.isF16() && elemDestType.isF16()) - return ROCDL::wmma_f16_16x16x16_f16::getOperationName(); - if (elemSourceType.isBF16() && elemDestType.isBF16()) - return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName(); - if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) - return ROCDL::wmma_i32_16x16x16_iu8::getOperationName(); - if (chipset.majorVersion == 11) { - if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) - return ROCDL::wmma_i32_16x16x16_iu4::getOperationName(); + auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType()); + auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType()); + auto destVectorType = cast<VectorType>(wmma.getDestC().getType()); + Type elemSourceType = sourceVectorType.getElementType(); + Type elemBSourceType = sourceBVectorType.getElementType(); + Type elemDestType = destVectorType.getElementType(); + + const uint32_t k = wmma.getK(); + + if (k == 16) { + if (elemSourceType.isF16() && elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x16_f16::getOperationName(); + if (elemSourceType.isBF16() && elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x16_bf16::getOperationName(); + if (elemSourceType.isF16() && elemDestType.isF16()) + return ROCDL::wmma_f16_16x16x16_f16::getOperationName(); + if (elemSourceType.isBF16() && elemDestType.isBF16()) + return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName(); + if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) + return ROCDL::wmma_i32_16x16x16_iu8::getOperationName(); + if (chipset.majorVersion == 11) { + if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) + return ROCDL::wmma_i32_16x16x16_iu4::getOperationName(); + } } - if (chipset.majorVersion >= 12) { + if (chipset.majorVersion < 12) + return std::nullopt; + + // gfx12+ + if (k == 16) { if (isa<Float8E4M3FNType>(elemSourceType) && isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName(); @@ -1027,17 +1036,18 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma, if (isa<Float8E5M2Type>(elemSourceType) && isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName(); - if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) { - bool isWave64 = destVectorType.getNumElements() == 4; - // This is the ambiguous case. 8 inputs to the wave64 version means that - // we want the 16x16x32 version, but for wave32 they mean the short form. - bool has8Inputs = sourceVectorType.getNumElements() == 8; - if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs)) - return ROCDL::wmma_i32_16x16x32_iu4::getOperationName(); + if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) return ROCDL::wmma_i32_16x16x16_iu4::getOperationName(); - } + + return std::nullopt; } - return std::nullopt; + if (k == 32) { + if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) + return ROCDL::wmma_i32_16x16x32_iu4::getOperationName(); + return std::nullopt; + } + + llvm_unreachable("unhandled WMMA case"); } namespace { |
