diff options
author | Guray Ozen <guray.ozen@gmail.com> | 2024-01-22 08:37:20 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-22 08:37:20 +0100 |
commit | 12c241b3654800ab708607dbc1998975c893fc14 (patch) | |
tree | 21b157d58b5d3603a8341c9c4c7a57e0a00ef0dc /mlir | |
parent | 21830c913505b1fd2cf10e454253483180c7e10b (diff) | |
download | llvm-12c241b3654800ab708607dbc1998975c893fc14.zip llvm-12c241b3654800ab708607dbc1998975c893fc14.tar.gz llvm-12c241b3654800ab708607dbc1998975c893fc14.tar.bz2 |
[MLIR][NVVM] Explicit Data Type for Output in `wgmma.mma_async` (#78713)
The current implementation of the `nvvm.wgmma.mma_async` Op deduces the data
type of the output matrix from the data type of the struct member, which can be
non-intuitive, especially in cases where types like `2xf16` are packed
into `i32`.
This PR addresses this issue by improving the Op to include an explicit
data type for the output matrix.
The modified Op now includes an explicit data type for Matrix-D (`<f16>`)
and looks as follows:
```
%result = llvm.mlir.undef : !llvm.struct<(struct<(i32, i32, ...
nvvm.wgmma.mma_async
%descA, %descB, %result,
#nvvm.shape<m = 64, n = 32, k = 16>,
D [<f16>, #nvvm.wgmma_scale_out<zero>],
A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>],
B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
```
Diffstat (limited to 'mlir')
-rw-r--r-- | mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 10 | ||||
-rw-r--r-- | mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 15 | ||||
-rw-r--r-- | mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 117 | ||||
-rw-r--r-- | mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 18 | ||||
-rw-r--r-- | mlir/test/Conversion/NVVMToLLVM/invalid.mlir | 42 | ||||
-rw-r--r-- | mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir | 56 | ||||
-rw-r--r-- | mlir/test/python/dialects/nvvm.py | 3 |
7 files changed, 135 insertions, 126 deletions
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 7140e61..b1bd3a9 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -1833,11 +1833,14 @@ def WGMMATypeB1 : I32EnumAttrCase<"b1", 4>; def WGMMATypeBF16 : I32EnumAttrCase<"bf16", 5>; def WGMMATypeF8E4M3 : I32EnumAttrCase<"e4m3", 6>; def WGMMATypeF8E5M2 : I32EnumAttrCase<"e5m2", 7>; +def WGMMATypeF32 : I32EnumAttrCase<"f32", 8>; +def WGMMATypeS32 : I32EnumAttrCase<"s32", 9>; + def WGMMATypes : I32EnumAttr<"WGMMATypes", "NVVM WGMMA types", [WGMMATypeF16, WGMMATypeTF32, WGMMATypeU8, WGMMATypeS8, WGMMATypeB1, WGMMATypeBF16, WGMMATypeF8E4M3, - WGMMATypeF8E5M2]> { + WGMMATypeF8E5M2, WGMMATypeF32, WGMMATypeS32]> { let genSpecializedAttr = 0; let cppNamespace = "::mlir::NVVM"; } @@ -1859,6 +1862,7 @@ def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async", NVVM_MMAShapeAttr:$shape, WGMMATypesAttr:$typeA, WGMMATypesAttr:$typeB, + WGMMATypesAttr:$typeD, WGMMAScaleOutAttr:$scaleD, WGMMAScaleInAttr:$scaleA, WGMMAScaleInAttr:$scaleB, @@ -1868,8 +1872,8 @@ def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async", ); let assemblyFormat = [{ - $descriptorA `,` $descriptorB `,` $shape `,` - `D` `[` $inouts `,` $scaleD (`,` $satfinite^)? `]` `,` + $descriptorA `,` $descriptorB `,` $inouts `,` $shape `,` + `D` `[` $typeD `,` $scaleD (`,` $satfinite^)? 
`]` `,` `A` `[` $typeA `,` $scaleA `,` $layoutA `]` `,` `B` `[` $typeB `,` $scaleB `,` $layoutB `]` attr-dict `:` diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index ab4dea9..43d05b8 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -1267,10 +1267,11 @@ struct NVGPUWarpgroupMmaOpLowering } /// Generates WGMMATypesAttr from MLIR Type - NVVM::WGMMATypesAttr generateWgmmaType(Type type) const { - auto getWgmmaType = [](Type elemType) { + NVVM::WGMMATypesAttr generateWgmmaType(Type type, + bool useF32 = false) const { + auto getWgmmaType = [=](Type elemType) { if (elemType.isF32() || elemType.isTF32()) - return NVVM::WGMMATypes::tf32; + return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32; if (elemType.isF16()) return NVVM::WGMMATypes::f16; if (elemType.isBF16()) @@ -1285,6 +1286,8 @@ struct NVGPUWarpgroupMmaOpLowering return NVVM::WGMMATypes::s8; if (elemType.isUnsignedInteger(8)) return NVVM::WGMMATypes::u8; + if (elemType.isInteger(32)) + return NVVM::WGMMATypes::s32; llvm_unreachable("unsupported type"); }; return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type)); @@ -1397,6 +1400,9 @@ struct NVGPUWarpgroupMmaOpLowering Type elemB = op.getDescriptorB().getType().getTensor().getElementType(); NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB); + Type elemD = op.getMatrixC().getType().getFragmented().getElementType(); + NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true); + NVVM::MMAShapeAttr shape = generateWgmmaShape(); NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut(); NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn(); @@ -1408,7 +1414,8 @@ struct NVGPUWarpgroupMmaOpLowering return b.create<NVVM::WgmmaMmaAsyncOp>( matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA, - itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow); + itypeB, itypeD, scaleOut, scaleIn, scaleIn, 
layoutA, layoutB, + overflow); } /// Generates multiple wgmma instructions to complete the given GEMM shape diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index aa49c4d..a855e4b 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -755,37 +755,44 @@ FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) { return failure(); } -LogicalResult isAllowedWGMMADataType(Type typeD, NVVM::WGMMATypes typeA, +LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD, + NVVM::WGMMATypes typeA, NVVM::WGMMATypes typeB) { switch (typeA) { case NVVM::WGMMATypes::f16: - if ((typeD.isF32() || typeD.isF16()) && typeB == NVVM::WGMMATypes::f16) + if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) && + typeB == NVVM::WGMMATypes::f16) return success(); break; case NVVM::WGMMATypes::tf32: - if (typeD.isF32() && typeB == NVVM::WGMMATypes::tf32) + if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32) return success(); break; case NVVM::WGMMATypes::u8: case NVVM::WGMMATypes::s8: - if (typeD.isInteger(32) && + if (typeD == NVVM::WGMMATypes::s32 && (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8)) return success(); break; case NVVM::WGMMATypes::b1: - if (typeD.isInteger(32) && typeB == NVVM::WGMMATypes::b1) + if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1) return success(); break; case NVVM::WGMMATypes::bf16: - if ((typeD.isF32() || typeD.isF16()) && typeB == NVVM::WGMMATypes::bf16) + if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) && + typeB == NVVM::WGMMATypes::bf16) return success(); break; case NVVM::WGMMATypes::e4m3: case NVVM::WGMMATypes::e5m2: - if ((typeD.isF32() || typeD.isF16()) && + if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) && (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3)) return success(); break; + case WGMMATypes::f32: + case 
WGMMATypes::s32: + llvm_unreachable("unsupported input types"); + break; } return failure(); } @@ -799,19 +806,24 @@ LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) { 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256}; switch (typeA) { - case mlir::NVVM::WGMMATypes::f16: - case mlir::NVVM::WGMMATypes::tf32: - case mlir::NVVM::WGMMATypes::bf16: - case mlir::NVVM::WGMMATypes::e4m3: - case mlir::NVVM::WGMMATypes::e5m2: + case WGMMATypes::f16: + case WGMMATypes::tf32: + case WGMMATypes::bf16: + case WGMMATypes::e4m3: + case WGMMATypes::e5m2: if (llvm::is_contained(allowedN, sizeN)) return success(); break; - case mlir::NVVM::WGMMATypes::u8: - case mlir::NVVM::WGMMATypes::s8: - case mlir::NVVM::WGMMATypes::b1: + case WGMMATypes::u8: + case WGMMATypes::s8: + case WGMMATypes::b1: if (llvm::is_contained(allowedNshort, sizeN)) return success(); + break; + case WGMMATypes::f32: + case WGMMATypes::s32: + llvm_unreachable("unsupported input types"); + break; } return failure(); } @@ -821,27 +833,29 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() { auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType()); if (!stype) return emitOpError() << "expected results to be struct"; - Type outputType = stype.getBody().front(); int outputSize = stype.getBody().size(); + WGMMATypes typeD = getTypeD(); + WGMMATypes typeA = getTypeA(); + WGMMATypes typeB = getTypeB(); + for (Type t : stype.getBody()) { - if (t != outputType) + if (t != stype.getBody().front()) return emitOpError() << "all elements in struct must be same type but there is " << t; } - if (!outputType.isF32() && !outputType.isInteger(32) && !outputType.isF16()) { + if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 && + typeD != WGMMATypes::s32) { return emitOpError() << "does not support the given output type " - << outputType; + << NVVM::stringifyWGMMATypes(typeD); } - if (outputType.isInteger(32) && (getScaleA() == NVVM::WGMMAScaleIn::neg || - getScaleB() == NVVM::WGMMAScaleIn::neg)) { + 
if (typeD == WGMMATypes::s32 && + (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) { return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg"; } - mlir::NVVM::WGMMATypes typeA = getTypeA(); - mlir::NVVM::WGMMATypes typeB = getTypeB(); - if (failed(isAllowedWGMMADataType(outputType, typeA, typeB))) { - return emitOpError() << outputType + if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) { + return emitOpError() << NVVM::stringifyWGMMATypes(typeD) << " += " << NVVM::stringifyWGMMATypes(typeA) << " * " << NVVM::stringifyWGMMATypes(typeB) << ", it is not supported."; @@ -866,8 +880,7 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() { } // Check transpose (only available for f16/bf16) - if ((typeA != mlir::NVVM::WGMMATypes::f16 && - typeA != mlir::NVVM::WGMMATypes::bf16) && + if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) && (getLayoutA() == mlir::NVVM::MMALayout::col || getLayoutB() == mlir::NVVM::MMALayout::col)) { return emitOpError() @@ -876,29 +889,29 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() { << " for input types " << stringifyWGMMATypes(typeA) << " and " << stringifyWGMMATypes(typeB) << " requires transpose. 
However, this is only supported for: " - << stringifyMMATypes(mlir::NVVM::MMATypes::f16) << " and " - << stringifyMMATypes(mlir::NVVM::MMATypes::bf16); + << stringifyMMATypes(MMATypes::f16) << " and " + << stringifyMMATypes(MMATypes::bf16); } // Check result registers - int expectedOutput; - if (outputType.isF32() || outputType.isInteger(32)) + int expectedOutput = 0; + if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32) expectedOutput = getShape().getN() / 2; - if (outputType.isF16()) + if (typeD == WGMMATypes::f16) expectedOutput = getShape().getN() / 4; if (outputSize != expectedOutput) { return emitOpError() << "results " << expectedOutput << ", however output struct has " << outputSize << " elements"; } - // Check satfinite (only availalbe for s32 accumulator) - if (!outputType.isInteger(32) && + // Check satfinite (only available for s32 accumulator) + if (typeD != WGMMATypes::s32 && getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) == NVVM::MMAIntOverflow::satfinite) { return emitOpError() << " `satfinite` can be only used with s32 accumulator, however " "the current accumulator is " - << outputType; + << NVVM::stringifyWGMMATypes(typeD); } return success(); @@ -907,27 +920,15 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() { std::string NVVM::WgmmaMmaAsyncOp::getPtx() { int m = getShape().getM(), n = getShape().getN(), k = getShape().getK(); - bool isF16 = getTypeA() == mlir::NVVM::WGMMATypes::f16 || - getTypeA() == mlir::NVVM::WGMMATypes::bf16; + bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16; - Value outValue = getResults() ? 
getResults() : getInouts(); - auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType()); - Type outputType = stype.getBody().front(); - std::string outputTypeName; - if (outputType.isF16()) - outputTypeName = "f16"; - else if (outputType.isF32()) - outputTypeName = "f32"; - else if (outputType.isInteger(32)) - outputTypeName = "s32"; - else - assert(false && "unsupported output type"); + StringRef outputTypeName = stringifyWGMMATypes(getTypeD()); - int expectedOutputRegisters; - if (outputType.isF32() || outputType.isInteger(32)) - expectedOutputRegisters = getShape().getN() / 2; - if (outputType.isF16()) + int expectedOutputRegisters = 0; + if (getTypeD() == WGMMATypes::f16) expectedOutputRegisters = getShape().getN() / 4; + else + expectedOutputRegisters = getShape().getN() / 2; std::string ptx; llvm::raw_string_ostream ss(ptx); @@ -958,7 +959,7 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() { ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p"; - if (!outputType.isInteger(32)) { + if (getTypeD() != WGMMATypes::s32) { ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4); } // Don't add transpose parameters unless needed. @@ -975,11 +976,7 @@ void NVVM::WgmmaMmaAsyncOp::getAsmValues( RewriterBase &rewriter, llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &asmValues) { - Value outValue = getResults() ? 
getResults() : getInouts(); - auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType()); - Type outputType = stype.getBody().front(); - bool isF16 = getTypeA() == mlir::NVVM::WGMMATypes::f16 || - getTypeA() == mlir::NVVM::WGMMATypes::bf16; + bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16; if (getResults()) asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write}); if (getInouts()) @@ -988,7 +985,7 @@ void NVVM::WgmmaMmaAsyncOp::getAsmValues( asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read}); asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())), mlir::NVVM::PTXRegisterMod::Read}); - if (!outputType.isInteger(32)) { + if (getTypeD() != WGMMATypes::s32) { asmValues.push_back( {makeConstantI32(rewriter, getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1), diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir index b495363..b25dd76 100644 --- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir +++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir @@ -880,41 +880,41 @@ func.func @warpgroup_mma_128_128_64( // CHECK: nvvm.wgmma.fence.aligned // CHECK: %[[UD:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> // CHECK: %[[S2:.+]] = llvm.extractvalue %[[ARG]][0] : 
!llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> -// CHECK: %[[S4:.+]] = nvvm.wgmma.mma_async %[[S0]], %[[S1]], <m = 64, n = 128, k = 16>, D[%[[S2]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> +// CHECK: %[[S4:.+]] = nvvm.wgmma.mma_async %[[S0]], %[[S1]], %[[S2]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, 
f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> // CHECK: %[[S5:.+]] = llvm.mlir.constant(2 : i32) : i64 // CHECK: %[[S6:.+]] = llvm.add %[[S0]], %[[S5]] : i64 // CHECK: %[[S7:.+]] = llvm.mlir.constant(128 : i32) : i64 // CHECK: %[[S8:.+]] = llvm.add %[[S1]], %[[S7]] : i64 -// CHECK: %[[S9:.+]] = nvvm.wgmma.mma_async %[[S6]], %[[S8]], <m = 64, n = 128, k = 16>, D[%[[S4]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct +// CHECK: %[[S9:.+]] = nvvm.wgmma.mma_async %[[S6]], %[[S8]], %[[S4]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct // CHECK: %[[S10:.+]] = llvm.mlir.constant(4 : i32) : i64 // CHECK: %[[S11:.+]] = llvm.add %[[S0]], %[[S10]] : i64 // CHECK: %[[S12:.+]] = llvm.mlir.constant(256 : i32) : i64 // CHECK: %[[S13:.+]] = llvm.add %[[S1]], %[[S12]] : i64 -// CHECK: %[[S14:.+]] = nvvm.wgmma.mma_async %[[S11]], %[[S13]], <m = 64, n = 128, k = 16>, D[%[[S9]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct +// CHECK: %[[S14:.+]] = nvvm.wgmma.mma_async %[[S11]], %[[S13]], %[[S9]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct // CHECK: %[[S15:.+]] = llvm.mlir.constant(6 : i32) : i64 // CHECK: %[[S16:.+]] = llvm.add %[[S0]], %[[S15]] : i64 // CHECK: %[[S17:.+]] = llvm.mlir.constant(384 : i32) : i64 // CHECK: %[[S18:.+]] = llvm.add %[[S1]], %[[S17]] : i64 -// CHECK: %[[S19:.+]] = nvvm.wgmma.mma_async %[[S16]], %[[S18]], <m = 64, n = 128, k = 16>, D[%[[S14]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct +// 
CHECK: %[[S19:.+]] = nvvm.wgmma.mma_async %[[S16]], %[[S18]], %[[S14]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct // CHECK: %[[S3:.+]] = llvm.extractvalue %[[ARG]][1] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> // CHECK: %[[S21:.+]] = llvm.mlir.constant(512 : i32) : i64 // CHECK: %[[S22:.+]] = llvm.add %[[S0]], %[[S21]] : i64 -// CHECK: %[[S23:.+]] = nvvm.wgmma.mma_async %[[S22]], %[[S1]], <m = 64, n = 128, k = 16>, D[%[[S3]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct +// CHECK: %[[S23:.+]] = nvvm.wgmma.mma_async %[[S22]], %[[S1]], %[[S3]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct // CHECK: %[[S24:.+]] = llvm.mlir.constant(514 : i32) : i64 // CHECK: %[[S25:.+]] = llvm.add %[[S0]], %[[S24]] : i64 // CHECK: %[[S26:.+]] = llvm.mlir.constant(128 : i32) : i64 // CHECK: %[[S27:.+]] = llvm.add %[[S1]], %[[S26]] : i64 -// CHECK: %[[S28:.+]] = nvvm.wgmma.mma_async %[[S25]], %[[S27]], <m = 64, n = 128, k = 16>, D[%[[S23]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct +// CHECK: %[[S28:.+]] = nvvm.wgmma.mma_async %[[S25]], %[[S27]], %[[S23]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, 
<one>, <col>] : !llvm.struct // CHECK: %[[S29:.+]] = llvm.mlir.constant(516 : i32) : i64 // CHECK: %[[S30:.+]] = llvm.add %[[S0]], %[[S29]] : i64 // CHECK: %[[S31:.+]] = llvm.mlir.constant(256 : i32) : i64 // CHECK: %[[S32:.+]] = llvm.add %[[S1]], %[[S31]] : i64 -// CHECK: %[[S33:.+]] = nvvm.wgmma.mma_async %[[S30]], %[[S32]], <m = 64, n = 128, k = 16>, D[%[[S28]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct +// CHECK: %[[S33:.+]] = nvvm.wgmma.mma_async %[[S30]], %[[S32]], %[[S28]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct // CHECK: %[[S34:.+]] = llvm.mlir.constant(518 : i32) : i64 // CHECK: %[[S35:.+]] = llvm.add %[[S0]], %[[S34]] : i64 // CHECK: %[[S36:.+]] = llvm.mlir.constant(384 : i32) : i64 // CHECK: %[[S37:.+]] = llvm.add %[[S1]], %[[S36]] : i64 -// CHECK: %[[S38:.+]] = nvvm.wgmma.mma_async %[[S35]], %[[S37]], <m = 64, n = 128, k = 16>, D[%[[S33]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct +// CHECK: %[[S38:.+]] = nvvm.wgmma.mma_async %[[S35]], %[[S37]], %[[S33]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct // CHECK: %[[S40:.+]] = llvm.insertvalue %[[S19]], %[[UD]][0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> // 
CHECK: %[[S41:.+]] = llvm.insertvalue %[[S38]], %[[S40]][1] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> // CHECK: nvvm.wgmma.commit.group.sync.aligned @@ -1299,7 +1299,7 @@ func.func @warpgroup_matrix_multiply_m128n128k64( // CHECK: nvvm.wgmma.fence.aligned // CHECK: %[[S137:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> // CHECK: %[[S138:.+]] = llvm.extractvalue %136[0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, 
f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> -// CHECK: %[[S139:.+]] = nvvm.wgmma.mma_async %0, %1, <m = 64, n = 128, k = 16>, D[%[[S138]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> +// CHECK: %[[S139:.+]] = nvvm.wgmma.mma_async %0, %1, %[[S138]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, 
f32, f32, f32, f32)> // CHECK: nvvm.wgmma.mma_async // CHECK: nvvm.wgmma.mma_async // CHECK: %[[S154:.+]] = nvvm.wgmma.mma_async diff --git a/mlir/test/Conversion/NVVMToLLVM/invalid.mlir b/mlir/test/Conversion/NVVMToLLVM/invalid.mlir index 34c8de9..9ebe3a0 100644 --- a/mlir/test/Conversion/NVVMToLLVM/invalid.mlir +++ b/mlir/test/Conversion/NVVMToLLVM/invalid.mlir @@ -4,9 +4,9 @@ func.func @wgmma_f32_f16_f16(%descA : i64, %descB : i64) -> !mat64f32{ %result = llvm.mlir.undef : !mat64f32 // expected-error @+1 {{'nvvm.wgmma.mma_async' op results 64, however output struct has 7 elements}} - %res = nvvm.wgmma.mma_async %descA, %descB, + %res = nvvm.wgmma.mma_async %descA, %descB, %result, #nvvm.shape<m = 64, n = 128, k = 16>, - D [%result, <zero>], + D [<f32>, <zero>], A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>], B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>] : !mat64f32 -> !mat64f32 @@ -17,10 +17,10 @@ func.func @wgmma_f32_f16_f16(%descA : i64, %descB : i64) -> !mat64f32{ func.func @wgmma_f32_satfinite(%descA : i64, %descB : i64) { %result = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> - // expected-error @+1 {{`satfinite` can be only used with s32 accumulator, however the current accumulator is 'f32'}} - %res = nvvm.wgmma.mma_async %descA, %descB, + // expected-error @+1 {{`satfinite` can be only used with s32 accumulator, however the current accumulator is f32}} + %res = nvvm.wgmma.mma_async %descA, %descB, %result, #nvvm.shape<m = 64, n = 16, k = 16>, - D [%result, <zero>, <satfinite>], + D [<f32>, <zero>, <satfinite>], A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>], B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> @@ -33,9 +33,9 @@ func.func @wgmma_f32_satfinite(%descA : i64, %descB : i64) { func.func @wgmma_f32_m32(%descA : i64, %descB : i64) { %result = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> // expected-error @+1 {{shape 'm' must be 64}} - %res = 
nvvm.wgmma.mma_async %descA, %descB, + %res = nvvm.wgmma.mma_async %descA, %descB, %result, #nvvm.shape<m = 32, n = 16, k = 16>, - D [%result, <zero>], + D [<f32>, <zero>], A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>], B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> @@ -48,9 +48,9 @@ func.func @wgmma_f32_m32(%descA : i64, %descB : i64) { func.func @wgmma_f32_m32(%descA : i64, %descB : i64) { %result = llvm.mlir.undef : !llvm.struct<(f32, f32, i32, f32, f32, f32, f32, f32)> // expected-error @+1 {{op all elements in struct must be same type but there is 'i32'}} - %res = nvvm.wgmma.mma_async %descA, %descB, + %res = nvvm.wgmma.mma_async %descA, %descB, %result, #nvvm.shape<m = 64, n = 16, k = 16>, - D [%result, <zero>], + D [<f32>, <zero>], A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>], B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>] : !llvm.struct<(f32, f32, i32, f32, f32, f32, f32, f32)> @@ -63,9 +63,9 @@ func.func @wgmma_f32_m32(%descA : i64, %descB : i64) { func.func @wgmma_f32_m32(%descA : i64, %descB : i64) { %result = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> // expected-error @+1 {{op shape 'k' must be 16 for input type f16}} - %res = nvvm.wgmma.mma_async %descA, %descB, + %res = nvvm.wgmma.mma_async %descA, %descB, %result, #nvvm.shape<m = 64, n = 16, k = 3>, - D [%result, <zero>], + D [<f32>, <zero>], A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>], B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> @@ -78,9 +78,9 @@ func.func @wgmma_f32_m32(%descA : i64, %descB : i64) { func.func @wgmma_transpose(%descA : i64, %descB : i64) { %result = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> // expected-error @+1 {{op given layouts layout_a = col and layout_b = col for input types tf32 and tf32 requires transpose. 
However, this is only supported for: f16 and bf16}} - %res = nvvm.wgmma.mma_async %descA, %descB, + %res = nvvm.wgmma.mma_async %descA, %descB, %result, #nvvm.shape<m = 64, n = 16, k = 8>, - D [%result, <zero>], + D [<f32>, <zero>], A [<tf32>, #nvvm.wgmma_scale_in<neg>, <col>], B [<tf32>, #nvvm.wgmma_scale_in<neg>, <col>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> @@ -92,10 +92,10 @@ func.func @wgmma_transpose(%descA : i64, %descB : i64) { func.func @wgmma_transpose(%descA : i64, %descB : i64) { %result = llvm.mlir.undef : !llvm.struct<(f16, f16, f16, f16)> - // expected-error @+1 {{'nvvm.wgmma.mma_async' op 'f16' += tf32 * tf32, it is not supported.}} - %res = nvvm.wgmma.mma_async %descA, %descB, + // expected-error @+1 {{'nvvm.wgmma.mma_async' op f16 += tf32 * tf32, it is not supported.}} + %res = nvvm.wgmma.mma_async %descA, %descB, %result, #nvvm.shape<m = 64, n = 16, k = 8>, - D [%result, <zero>], + D [<f16>, <zero>], A [<tf32>, #nvvm.wgmma_scale_in<neg>, <col>], B [<tf32>, #nvvm.wgmma_scale_in<neg>, <col>] :!llvm.struct<(f16, f16, f16, f16)> @@ -108,9 +108,9 @@ func.func @wgmma_transpose(%descA : i64, %descB : i64) { func.func @wgmma_f32_m32(%descA : i64, %descB : i64) { %result = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32)> // expected-error @+1 {{input struct and result struct must be the same type}} - %res = nvvm.wgmma.mma_async %descA, %descB, + %res = nvvm.wgmma.mma_async %descA, %descB, %result, #nvvm.shape<m = 64, n = 8, k = 16>, - D [%result, <zero>], + D [<f16>, <zero>], A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>], B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>] : !llvm.struct<(i32, i32, i32, i32)> @@ -122,10 +122,10 @@ func.func @wgmma_f32_m32(%descA : i64, %descB : i64) { func.func @wgmma_f32_m32(%descA : i64, %descB : i64) { %result = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> - // expected-error @+1 {{op 'f32' += bf16 * f16, it is not supported}} - %res = nvvm.wgmma.mma_async %descA, %descB, + // 
expected-error @+1 {{op f32 += bf16 * f16, it is not supported}} + %res = nvvm.wgmma.mma_async %descA, %descB, %result, #nvvm.shape<m = 64, n = 8, k = 16>, - D [%result, <zero>], + D [<f32>, <zero>], A [<bf16>, #nvvm.wgmma_scale_in<neg>, <col>], B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir index a9487bd..9c7c27c 100644 --- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir +++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir @@ -329,9 +329,9 @@ func.func @wgmma_f32_f16_f16(%descA : i64, %descB : i64) -> !mat64f32{ // CHECK-SAME: %[[V0_2]], %{{.*}}, %{{.*}}, %{{.*}}, %[[V4_2]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[V11_2]], %{{.*}}, %[[V13_2]], %{{.*}}, %{{.*}}, %[[DESCa]], %[[DESCb]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} %result = llvm.mlir.undef : !mat64f32 %result1 = nvvm.wgmma.mma_async - %descA, %descB, + %descA, %descB, %result, #nvvm.shape<m = 64, n = 32, k = 16>, - D [%result, #nvvm.wgmma_scale_out<zero>], + D [<f32>, #nvvm.wgmma_scale_out<zero>], A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>], B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>] :!mat64f32 -> !mat64f32 @@ -339,9 +339,9 @@ func.func @wgmma_f32_f16_f16(%descA : i64, %descB : i64) -> !mat64f32{ %descAnext = arith.addi %descA, %c2 : i64 %descBnext = arith.addi %descB, %c2 : i64 %result2 = nvvm.wgmma.mma_async - %descAnext, %descBnext, + %descAnext, %descBnext, %result1, #nvvm.shape<m = 64, n = 32, k = 16>, - D [%result1, #nvvm.wgmma_scale_out<zero>], + D [<f32>, #nvvm.wgmma_scale_out<zero>], A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>], B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>] : !mat64f32 -> !mat64f32 @@ -393,21 +393,21 @@ func.func @wgmma_s32_s8_s8_satfinite(%descA : i64, %descB : i64) -> !mat16i32{ // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite // CHECK-SAME: {$0, $1, $2, $3}, 
$8, $9, p;\0A}\0A", "=r,=r,=r,=r,0,1,2,3,l,l,n" // CHECK-SAME: %[[V0_3]], %[[V1_3]], %[[V2_3]], %[[V3_3]], %[[ARG0]], %[[ARG1]], %{{.*}} - %result1 = nvvm.wgmma.mma_async %descA, %descB, + %result1 = nvvm.wgmma.mma_async %descA, %descB, %result, #nvvm.shape<m = 64, n = 8, k = 32>, - D [%result, #nvvm.wgmma_scale_out<one>, <satfinite>], + D [<s32>, #nvvm.wgmma_scale_out<one>, <satfinite>], A [<s8>, #nvvm.wgmma_scale_in<one>, <row>], B [<s8>, #nvvm.wgmma_scale_in<one>, <row>] : !mat16i32 -> !mat16i32 - %result2 = nvvm.wgmma.mma_async %descA, %descB, + %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1, #nvvm.shape<m = 64, n = 8, k = 32>, - D [%result1, #nvvm.wgmma_scale_out<one>, <satfinite>], + D [<s32>, #nvvm.wgmma_scale_out<one>, <satfinite>], A [<s8>, #nvvm.wgmma_scale_in<one>, <row>], B [<s8>, #nvvm.wgmma_scale_in<one>, <row>] : !mat16i32 -> !mat16i32 - %result3 = nvvm.wgmma.mma_async %descA, %descB, + %result3 = nvvm.wgmma.mma_async %descA, %descB, %result2, #nvvm.shape<m = 64, n = 8, k = 32>, - D [%result2, #nvvm.wgmma_scale_out<one>, <satfinite>], + D [<s32>, #nvvm.wgmma_scale_out<one>, <satfinite>], A [<s8>, #nvvm.wgmma_scale_in<one>, <row>], B [<s8>, #nvvm.wgmma_scale_in<one>, <row>] : !mat16i32 -> !mat16i32 @@ -454,21 +454,21 @@ func.func @wgmma_s32_u8_u8(%descA : i64, %descB : i64) -> !mat16i32 { // CHECK-SAME:}\0A", // CHECK-SAME:"=r,=r,=r,=r,0,1,2,3,l,l,n" %[[V0_3]], %[[V1_3]], %[[V2_3]], %[[V3_3]], %[[ARG0]], %[[ARG1]], %{{.*}} %result = llvm.mlir.undef : !mat16i32 - %result1 = nvvm.wgmma.mma_async %descA, %descB, + %result1 = nvvm.wgmma.mma_async %descA, %descB, %result, #nvvm.shape<m = 64, n = 8, k = 32>, - D [%result, #nvvm.wgmma_scale_out<one>], + D [<s32>, #nvvm.wgmma_scale_out<one>], A [<u8>, #nvvm.wgmma_scale_in<one>, <row>], B [<u8>, #nvvm.wgmma_scale_in<one>, <row>] : !mat16i32 -> !mat16i32 - %result2 = nvvm.wgmma.mma_async %descA, %descB, + %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1, #nvvm.shape<m = 64, n = 8, k = 32>, - D 
[%result1, #nvvm.wgmma_scale_out<one>], + D [<s32>, #nvvm.wgmma_scale_out<one>], A [<u8>, #nvvm.wgmma_scale_in<one>, <row>], B [<u8>, #nvvm.wgmma_scale_in<one>, <row>] : !mat16i32 -> !mat16i32 - %result3 = nvvm.wgmma.mma_async %descA, %descB, + %result3 = nvvm.wgmma.mma_async %descA, %descB, %result2, #nvvm.shape<m = 64, n = 8, k = 32>, - D [%result2, #nvvm.wgmma_scale_out<one>], + D [<s32>, #nvvm.wgmma_scale_out<one>], A [<u8>, #nvvm.wgmma_scale_in<one>, <row>], B [<u8>, #nvvm.wgmma_scale_in<one>, <row>] : !mat16i32 -> !mat16i32 @@ -496,15 +496,15 @@ func.func @wgmma_f32_tf32_tf32(%descA : i64, %descB : i64) -> !mat32f32 { // CHECK-SAME: setp.ne.b32 p, $66, 0; // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $64, $65, p, $67, $68;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n" %result = llvm.mlir.undef : !mat32f32 - %result1 = nvvm.wgmma.mma_async %descA, %descB, + %result1 = nvvm.wgmma.mma_async %descA, %descB, %result, #nvvm.shape<m = 64, n = 64, k = 8>, - D [%result, #nvvm.wgmma_scale_out<one>], + D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>], A [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], B [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>] : !mat32f32 -> !mat32f32 - %result2 = nvvm.wgmma.mma_async %descA, %descB, + %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1, #nvvm.shape<m = 64, n = 64, k = 8>, - D [%result1, #nvvm.wgmma_scale_out<one>], + D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>], A [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], B [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>] : !mat32f32 -> 
!mat32f32 @@ -529,15 +529,15 @@ func.func @wgmma_f32_e4m3_e4m3(%descA : i64, %descB : i64) -> !mat32f32 { // CHECK-SAME: "{\0A.reg .pred p;\0Asetp.ne.b32 p, $66, 0; // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $64, $65, p, $67, $68;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n" %result = llvm.mlir.undef : !mat32f32 - %result1 = nvvm.wgmma.mma_async %descA, %descB, + %result1 = nvvm.wgmma.mma_async %descA, %descB, %result, #nvvm.shape<m = 64, n = 64, k = 32>, - D [%result, #nvvm.wgmma_scale_out<one>], + D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>], A [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>] : !mat32f32 -> !mat32f32 - %result2 = nvvm.wgmma.mma_async %descA, %descB, + %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1, #nvvm.shape<m = 64, n = 64, k = 32>, - D [%result1, #nvvm.wgmma_scale_out<one>], + D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>], A [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>] : !mat32f32 -> !mat32f32 @@ -561,15 +561,15 @@ func.func @wgmma_f32_e5m2_e4m3(%descA : i64, %descB : i64) -> !mat32f32 { // CHECK-SAME: "{\0A.reg .pred p;\0Asetp.ne.b32 p, $66, 0; // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $64, $65, p, $67, $68;\0A}\0A", 
"=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n" %result = llvm.mlir.undef : !mat32f32 - %result1 = nvvm.wgmma.mma_async %descA, %descB, + %result1 = nvvm.wgmma.mma_async %descA, %descB, %result, #nvvm.shape<m = 64, n = 64, k = 32>, - D [%result, #nvvm.wgmma_scale_out<one>], + D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>], A [#nvvm.wgmma_type<e5m2>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>] : !mat32f32 -> !mat32f32 - %result2 = nvvm.wgmma.mma_async %descA, %descB, + %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1, #nvvm.shape<m = 64, n = 64, k = 32>, - D [%result1, #nvvm.wgmma_scale_out<one>], + D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>], A [#nvvm.wgmma_type<e5m2>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>] : !mat32f32 -> !mat32f32 diff --git a/mlir/test/python/dialects/nvvm.py b/mlir/test/python/dialects/nvvm.py index 36aaaea..0eef97d 100644 --- a/mlir/test/python/dialects/nvvm.py +++ b/mlir/test/python/dialects/nvvm.py @@ -32,7 +32,7 @@ def testSmoke(): nvvm.CpAsyncWaitGroupOp(5) # CHECK: %0 = llvm.mlir.undef : [[MAT_T:.*]] result = llvm.UndefOp(mat64f32_t) - # CHECK: %1 = nvvm.wgmma.mma_async %arg0, %arg1, <m = 64, n = 32, k = 16>, D[%0, <zero>], A[<f16>, <neg>, <col>], B[<f16>, <neg>, <col>] : [[MAT_T]] -> [[MAT_T]] + # CHECK: %1 = nvvm.wgmma.mma_async %arg0, %arg1, %0, <m = 64, n = 32, k = 16>, D[<f32>, <zero>], A[<f16>, <neg>, <col>], B[<f16>, <neg>, <col>] : [[MAT_T]] -> [[MAT_T]] result1 = nvvm.WgmmaMmaAsyncOp( results_=mat64f32_t, inouts=result, @@ -41,6 +41,7 @@ def testSmoke(): shape=shape_attr, typeA=nvvm.WGMMATypes.f16, typeB=nvvm.WGMMATypes.f16, + typeD=nvvm.WGMMATypes.f32, scaleD=nvvm.WGMMAScaleOut.zero, 
scaleA=nvvm.WGMMAScaleIn.neg, scaleB=nvvm.WGMMAScaleIn.neg, |