aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp')
-rw-r--r--mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp80
1 files changed, 45 insertions, 35 deletions
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 85f0fd1d..478b6aa 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
@@ -993,28 +994,36 @@ mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
/// on the architecture you are compiling for.
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
Chipset chipset) {
- auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
- auto sourceBVectorType = dyn_cast<VectorType>(wmma.getSourceB().getType());
- auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
- auto elemSourceType = sourceVectorType.getElementType();
- auto elemBSourceType = sourceBVectorType.getElementType();
- auto elemDestType = destVectorType.getElementType();
-
- if (elemSourceType.isF16() && elemDestType.isF32())
- return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
- if (elemSourceType.isBF16() && elemDestType.isF32())
- return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
- if (elemSourceType.isF16() && elemDestType.isF16())
- return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
- if (elemSourceType.isBF16() && elemDestType.isBF16())
- return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
- if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
- return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
- if (chipset.majorVersion == 11) {
- if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
- return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+ auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
+ auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
+ auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
+ Type elemSourceType = sourceVectorType.getElementType();
+ Type elemBSourceType = sourceBVectorType.getElementType();
+ Type elemDestType = destVectorType.getElementType();
+
+ const uint32_t k = wmma.getK();
+
+ if (k == 16) {
+ if (elemSourceType.isF16() && elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
+ if (elemSourceType.isBF16() && elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
+ if (elemSourceType.isF16() && elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
+ if (elemSourceType.isBF16() && elemDestType.isBF16())
+ return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
+ if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
+ return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
+ if (chipset.majorVersion == 11) {
+ if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
+ return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+ }
}
- if (chipset.majorVersion >= 12) {
+ if (chipset.majorVersion < 12)
+ return std::nullopt;
+
+ // gfx12+
+ if (k == 16) {
if (isa<Float8E4M3FNType>(elemSourceType) &&
isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
@@ -1027,17 +1036,18 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
if (isa<Float8E5M2Type>(elemSourceType) &&
isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
- if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) {
- bool isWave64 = destVectorType.getNumElements() == 4;
- // This is the ambiguous case. 8 inputs to the wave64 version means that
- // we want the 16x16x32 version, but for wave32 they mean the short form.
- bool has8Inputs = sourceVectorType.getNumElements() == 8;
- if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
- return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+ if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
- }
+
+ return std::nullopt;
}
- return std::nullopt;
+ if (k == 32) {
+ if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
+ return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+ return std::nullopt;
+ }
+
+ llvm_unreachable("unhandled WMMA case");
}
namespace {
@@ -1927,16 +1937,16 @@ struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
else
llvm_unreachable("unsupported row length");
- const Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
- const Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});
+ Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
+ Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});
- const Value isEqual =
- rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, vdst0, v);
+ Value isEqual = LLVM::ICmpOp::create(rewriter, loc,
+ LLVM::ICmpPredicate::eq, vdst0, v);
// Per `permlane(16|32)` semantics: if the first extracted element equals
// 'v', the result is the second element; otherwise it is the first.
Value vdstNew =
- rewriter.create<LLVM::SelectOp>(loc, isEqual, vdst1, vdst0);
+ LLVM::SelectOp::create(rewriter, loc, isEqual, vdst1, vdst0);
permuted.emplace_back(vdstNew);
}