aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp87
-rw-r--r--mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir31
-rw-r--r--mlir/test/Integration/GPU/CUDA/TensorCore/transform-mma-sync-matmul-f16-f16-accum.mlir239
3 files changed, 357 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
index 887b375..b08b105 100644
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -122,6 +122,7 @@ private:
AffineExpr threadIDInGroup = dim % 4;
return {RowColIndexing{threadIDInGroup, groupID}};
}
+
/// From the NVIDIA doc:
/// groupID = %laneid >> 2
/// threadIDInGroup = %laneid % 4
@@ -139,6 +140,80 @@ private:
}
//===--------------------------------------------------------------------===//
+ // m16n8k16 f16 case.
+ //===--------------------------------------------------------------------===//
+ /// From the NVIDIA doc:
+ /// groupID = %laneid >> 2
+ /// threadIDInGroup = %laneid % 4
+ ///
+ /// row = groupID for ai where 0 <= i < 2 || 4 <= i < 6
+ /// groupID + 8 Otherwise
+ ///
+ /// col = (threadIDInGroup * 2) + (i & 0x1) for ai where i < 4
+ /// (threadIDInGroup * 2) + (i & 0x1) + 8 for ai where i >= 4
+ static SmallVector<RowColIndexing> m16n8k16f16Lhs(MLIRContext *ctx) {
+ auto dim = getAffineDimExpr(0, ctx);
+ AffineExpr groupID = dim.floorDiv(4);
+ AffineExpr threadIDInGroup = dim % 4;
+ // clang-format off
+ return {
+ RowColIndexing{groupID, threadIDInGroup * 2 + 0}, // i == 0
+ RowColIndexing{groupID, threadIDInGroup * 2 + 1}, // i == 1
+ RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0}, // i == 2
+ RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}, // i == 3
+ RowColIndexing{groupID, threadIDInGroup * 2 + 0 + 8}, // i == 4
+ RowColIndexing{groupID, threadIDInGroup * 2 + 1 + 8}, // i == 5
+ RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0 + 8}, // i == 6
+ RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1 + 8} // i == 7
+ };
+ // clang-format on
+ }
+
+ /// From the NVIDIA doc:
+ /// groupID = %laneid >> 2
+ /// threadIDInGroup = %laneid % 4
+ ///
+ /// row = (threadIDInGroup * 2) + (i & 0x1) for bi where i < 2
+ /// (threadIDInGroup * 2) + (i & 0x1) + 8 for bi where i >= 2
+ ///
+ /// col = groupID
+ static SmallVector<RowColIndexing> m16n8k16f16Rhs(MLIRContext *ctx) {
+ auto dim = getAffineDimExpr(0, ctx);
+ AffineExpr groupID = dim.floorDiv(4);
+ AffineExpr threadIDInGroup = dim % 4;
+ // clang-format off
+ return {
+ RowColIndexing{threadIDInGroup * 2 + 0, groupID}, // i == 0
+ RowColIndexing{threadIDInGroup * 2 + 1, groupID}, // i == 1
+ RowColIndexing{threadIDInGroup * 2 + 0 + 8, groupID}, // i == 2
+ RowColIndexing{threadIDInGroup * 2 + 1 + 8, groupID} // i == 3
+ };
+ // clang-format on
+ }
+
+ /// From the NVIDIA doc:
+ /// groupID = %laneid >> 2
+ /// threadIDInGroup = %laneid % 4
+ ///
+ /// row = groupID for ci where i < 2
+ /// groupID + 8 for ci where i >= 2
+ ///
+ /// col = (threadIDInGroup * 2) + (i & 0x1) for ci where i = {0,..,3}
+ static SmallVector<RowColIndexing> m16n8k16f16Res(MLIRContext *ctx) {
+ auto dim = getAffineDimExpr(0, ctx);
+ AffineExpr groupID = dim.floorDiv(4);
+ AffineExpr threadIDInGroup = dim % 4;
+ // clang-format off
+ return {
+ RowColIndexing{groupID, threadIDInGroup * 2 + 0}, // i == 0
+ RowColIndexing{groupID, threadIDInGroup * 2 + 1}, // i == 1
+ RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0}, // i == 2
+ RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1} // i == 3
+ };
+ // clang-format on
+ }
+
+ //===--------------------------------------------------------------------===//
/// Helper functions to create customizable load and stores operations. The
/// specific shapes of each MMA instruction are passed via the
/// IndexCalculator callback.
@@ -293,6 +368,7 @@ FailureOr<MmaSyncBuilder::MmaSyncInfo>
MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape,
TypeRange elementalTypes) {
// TODO: Tablegen all this.
+ Type f16 = b.getF16Type();
Type f32 = b.getF32Type();
if (opShape == ArrayRef<int64_t>{16, 8, 4} &&
elementalTypes == TypeRange{f32, f32, f32}) {
@@ -303,6 +379,17 @@ MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape,
SmallVector<int64_t>{opShape.begin(), opShape.end()},
/*tf32Enabled=*/true};
}
+ // This is the version with f16 accumulation.
+ // TODO: version with f32 accumulation.
+ if (opShape == ArrayRef<int64_t>{16, 8, 16} &&
+ elementalTypes == TypeRange{f16, f16, f16}) {
+ return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k16f16Lhs,
+ &MmaSyncBuilder::m16n8k16f16Rhs,
+ &MmaSyncBuilder::m16n8k16f16Res),
+ makeVectorShapes({4, 2}, {2, 2}, {2, 2}),
+ SmallVector<int64_t>{opShape.begin(), opShape.end()},
+ /*tf32Enabled=*/false};
+ }
return failure();
}
diff --git a/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir b/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir
index 55ff52b..241f218 100644
--- a/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir
+++ b/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir
@@ -80,3 +80,34 @@ transform.sequence failures(propagate) {
transform.nvgpu.rewrite_matmul_as_mma_sync %matmul
: (!transform.any_op) -> ()
}
+
+// -----
+
+// CHECK-LABEL: func.func @matmul_16x8x16xf16_global
+func.func @matmul_16x8x16xf16_global(
+ %A: memref<16x16xf16>, %B: memref<16x8xf16>, %C: memref<16x8xf16>) {
+
+ // CHECK-COUNT-8: memref.load {{.*}} : memref<16x16xf16>
+ // CHECK-COUNT-8: vector.insert {{.*}} : f16 into vector<4x2xf16>
+ // CHECK-COUNT-4: memref.load {{.*}} : memref<16x8xf16>
+ // CHECK-COUNT-4: vector.insert {{.*}} : f16 into vector<2x2xf16>
+ // CHECK-COUNT-4: memref.load {{.*}} : memref<16x8xf16>
+ // CHECK-COUNT-4: vector.insert {{.*}} : f16 into vector<2x2xf16>
+ //
+ // CHECK: nvgpu.mma.sync(%{{.*}}) {mmaShape = [16, 8, 16]}
+ // CHECK-SAME: : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ //
+ // CHECK-COUNT-4: vector.extract %{{.*}} : vector<2x2xf16>
+ // CHECK-COUNT-4: memref.store %{{.*}} : memref<16x8xf16>
+ linalg.matmul ins(%A, %B: memref<16x16xf16>, memref<16x8xf16>)
+ outs(%C: memref<16x8xf16>)
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.nvgpu.rewrite_matmul_as_mma_sync %matmul
+ : (!transform.any_op) -> ()
+}
diff --git a/mlir/test/Integration/GPU/CUDA/TensorCore/transform-mma-sync-matmul-f16-f16-accum.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/transform-mma-sync-matmul-f16-f16-accum.mlir
new file mode 100644
index 0000000..0a99380
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/TensorCore/transform-mma-sync-matmul-f16-f16-accum.mlir
@@ -0,0 +1,239 @@
+// RUN: mlir-opt %s \
+// RUN: -test-transform-dialect-interpreter \
+// RUN: -test-transform-dialect-erase-schedule \
+// RUN: -gpu-kernel-outlining \
+// RUN: -convert-scf-to-cf \
+// RUN: -convert-vector-to-llvm \
+// RUN: -convert-math-to-llvm \
+// RUN: -expand-strided-metadata \
+// RUN: -lower-affine \
+// RUN: -convert-index-to-llvm=index-bitwidth=32 \
+// RUN: -convert-arith-to-llvm \
+// RUN: -finalize-memref-to-llvm \
+// RUN: -convert-func-to-llvm \
+// RUN: -canonicalize \
+// RUN: | mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm,convert-nvgpu-to-nvvm{use-opaque-pointers=1},lower-affine,convert-scf-to-cf,convert-vector-to-llvm,convert-math-to-llvm,expand-strided-metadata,lower-affine,convert-index-to-llvm{index-bitwidth=32},convert-arith-to-llvm,reconcile-unrealized-casts,gpu-to-cubin{chip=sm_80}))' \
+// RUN: | mlir-opt -convert-index-to-llvm=index-bitwidth=32 \
+// RUN: -gpu-to-llvm \
+// RUN: -convert-func-to-llvm \
+// RUN: -reconcile-unrealized-casts \
+// RUN: | mlir-cpu-runner \
+// RUN: --shared-libs=%mlir_cuda_runtime \
+// RUN: --shared-libs=%mlir_runner_utils \
+// RUN: --entry-point-result=void \
+// RUN: | FileCheck %s
+
+!lhs_memref_type = memref<16x16xf16>
+!rhs_memref_type = memref<16x8xf16>
+!res_memref_type = memref<16x8xf16>
+
+func.func @compute_linspace_val(%ridx: index, %cidx: index, %strideCidx: index) -> f16 {
+ %r = arith.index_cast %ridx : index to i32
+ %c = arith.index_cast %cidx : index to i32
+ %strideC = arith.index_cast %strideCidx : index to i32
+ %2 = arith.muli %r, %strideC : i32
+ %3 = arith.addi %c, %2 : i32
+ %4 = arith.sitofp %3 : i32 to f16
+ %factor = arith.constant 64.0 : f16
+ %5 = arith.divf %4, %factor : f16
+ return %5: f16
+}
+
+func.func @print_lhs_as_memref_32(%lhs: !lhs_memref_type) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %M = memref.dim %lhs, %c0 : !lhs_memref_type
+ %N = memref.dim %lhs, %c1 : !lhs_memref_type
+ %tmp_alloc = memref.alloc(%M, %N) : memref<?x?xf32>
+ scf.for %m = %c0 to %M step %c1 {
+ scf.for %n = %c0 to %N step %c1 {
+ %f16 = memref.load %lhs[%m, %n] : !lhs_memref_type
+ %f32 = arith.extf %f16 : f16 to f32
+ memref.store %f32, %tmp_alloc[%m, %n] : memref<?x?xf32>
+ }
+ }
+ %casted = memref.cast %tmp_alloc : memref<?x?xf32> to memref<*xf32>
+ call @printMemrefF32(%casted) : (memref<*xf32>) -> ()
+ memref.dealloc %tmp_alloc : memref<?x?xf32>
+ return
+}
+
+func.func @print_rhs_as_memref_32(%rhs: !rhs_memref_type) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %M = memref.dim %rhs, %c0 : !rhs_memref_type
+ %N = memref.dim %rhs, %c1 : !rhs_memref_type
+ %tmp_alloc = memref.alloc(%M, %N) : memref<?x?xf32>
+ scf.for %m = %c0 to %M step %c1 {
+ scf.for %n = %c0 to %N step %c1 {
+ %f16 = memref.load %rhs[%m, %n] : !rhs_memref_type
+ %f32 = arith.extf %f16 : f16 to f32
+ memref.store %f32, %tmp_alloc[%m, %n] : memref<?x?xf32>
+ }
+ }
+ %casted = memref.cast %tmp_alloc : memref<?x?xf32> to memref<*xf32>
+ call @printMemrefF32(%casted) : (memref<*xf32>) -> ()
+ memref.dealloc %tmp_alloc : memref<?x?xf32>
+ return
+}
+
+func.func @print_res_as_memref_32(%res: !res_memref_type) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %M = memref.dim %res, %c0 : !res_memref_type
+ %N = memref.dim %res, %c1 : !res_memref_type
+ %tmp_alloc = memref.alloc(%M, %N) : memref<?x?xf32>
+ scf.for %m = %c0 to %M step %c1 {
+ scf.for %n = %c0 to %N step %c1 {
+ %f16 = memref.load %res[%m, %n] : !res_memref_type
+ %f32 = arith.extf %f16 : f16 to f32
+ memref.store %f32, %tmp_alloc[%m, %n] : memref<?x?xf32>
+ }
+ }
+ %casted = memref.cast %tmp_alloc : memref<?x?xf32> to memref<*xf32>
+ call @printMemrefF32(%casted) : (memref<*xf32>) -> ()
+ memref.dealloc %tmp_alloc : memref<?x?xf32>
+ return
+}
+
+func.func @main() {
+ %lhs = memref.alloc() : !lhs_memref_type
+ %rhs = memref.alloc() : !rhs_memref_type
+ %res = memref.alloc() : !res_memref_type
+
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %M = memref.dim %res, %c0 : !res_memref_type
+ %N = memref.dim %res, %c1 : !res_memref_type
+ %K = memref.dim %lhs, %c1 : !lhs_memref_type
+
+ %f1 = arith.constant 1.0e+00 : f16
+ %f0 = arith.constant 0.0e+00 : f16
+ %c32 = arith.constant 32 : index
+
+ // Intialize the lhs matrix with a linspace function.
+ scf.for %r = %c0 to %M step %c1 {
+ scf.for %c = %c0 to %K step %c1 {
+ %idx = func.call @compute_linspace_val(%r, %c, %K) : (index, index, index) -> f16
+ memref.store %idx, %lhs[%r, %c] : !lhs_memref_type
+ }
+ }
+ // Intialize the rhs matrix with a linspace function.
+ scf.for %r = %c0 to %K step %c1 {
+ scf.for %c = %c0 to %N step %c1 {
+ %idx = func.call @compute_linspace_val(%r, %c, %N) : (index, index, index) -> f16
+ memref.store %idx, %rhs[%r, %c] : !rhs_memref_type
+ }
+ }
+ // Intialize the rhs matrix with a linspace function.
+ scf.for %r = %c0 to %M step %c1 {
+ scf.for %c = %c0 to %N step %c1 {
+ %idx = func.call @compute_linspace_val(%r, %c, %N) : (index, index, index) -> f16
+ memref.store %idx, %res[%r, %c] : !res_memref_type
+ }
+ }
+
+ %ulhs = memref.cast %lhs : !lhs_memref_type to memref<*xf16>
+ %urhs = memref.cast %rhs : !rhs_memref_type to memref<*xf16>
+ %ures = memref.cast %res : !res_memref_type to memref<*xf16>
+ gpu.host_register %ulhs : memref<*xf16>
+ gpu.host_register %urhs : memref<*xf16>
+ gpu.host_register %ures : memref<*xf16>
+
+ // Print the memrefs before computation.
+ call @print_lhs_as_memref_32(%lhs) : (!lhs_memref_type) -> ()
+ // CHECK: [0, 0.015625, 0.03125, 0.046875, 0.0625, 0.078125, 0.09375, 0.109375, 0.125, 0.140625, 0.15625, 0.171875, 0.1875, 0.203125, 0.21875, 0.234375],
+ // CHECK: [0.25, 0.265625, 0.28125, 0.296875, 0.3125, 0.328125, 0.34375, 0.359375, 0.375, 0.390625, 0.40625, 0.421875, 0.4375, 0.453125, 0.46875, 0.484375],
+ // CHECK: [0.5, 0.515625, 0.53125, 0.546875, 0.5625, 0.578125, 0.59375, 0.609375, 0.625, 0.640625, 0.65625, 0.671875, 0.6875, 0.703125, 0.71875, 0.734375],
+ // CHECK: [0.75, 0.765625, 0.78125, 0.796875, 0.8125, 0.828125, 0.84375, 0.859375, 0.875, 0.890625, 0.90625, 0.921875, 0.9375, 0.953125, 0.96875, 0.984375],
+ // CHECK: [1, 1.01562, 1.03125, 1.04688, 1.0625, 1.07812, 1.09375, 1.10938, 1.125, 1.14062, 1.15625, 1.17188, 1.1875, 1.20312, 1.21875, 1.23438],
+ // CHECK: [1.25, 1.26562, 1.28125, 1.29688, 1.3125, 1.32812, 1.34375, 1.35938, 1.375, 1.39062, 1.40625, 1.42188, 1.4375, 1.45312, 1.46875, 1.48438],
+ // CHECK: [1.5, 1.51562, 1.53125, 1.54688, 1.5625, 1.57812, 1.59375, 1.60938, 1.625, 1.64062, 1.65625, 1.67188, 1.6875, 1.70312, 1.71875, 1.73438],
+ // CHECK: [1.75, 1.76562, 1.78125, 1.79688, 1.8125, 1.82812, 1.84375, 1.85938, 1.875, 1.89062, 1.90625, 1.92188, 1.9375, 1.95312, 1.96875, 1.98438],
+ // CHECK: [2, 2.01562, 2.03125, 2.04688, 2.0625, 2.07812, 2.09375, 2.10938, 2.125, 2.14062, 2.15625, 2.17188, 2.1875, 2.20312, 2.21875, 2.23438],
+ // CHECK: [2.25, 2.26562, 2.28125, 2.29688, 2.3125, 2.32812, 2.34375, 2.35938, 2.375, 2.39062, 2.40625, 2.42188, 2.4375, 2.45312, 2.46875, 2.48438],
+ // CHECK: [2.5, 2.51562, 2.53125, 2.54688, 2.5625, 2.57812, 2.59375, 2.60938, 2.625, 2.64062, 2.65625, 2.67188, 2.6875, 2.70312, 2.71875, 2.73438],
+ // CHECK: [2.75, 2.76562, 2.78125, 2.79688, 2.8125, 2.82812, 2.84375, 2.85938, 2.875, 2.89062, 2.90625, 2.92188, 2.9375, 2.95312, 2.96875, 2.98438],
+ // CHECK: [3, 3.01562, 3.03125, 3.04688, 3.0625, 3.07812, 3.09375, 3.10938, 3.125, 3.14062, 3.15625, 3.17188, 3.1875, 3.20312, 3.21875, 3.23438],
+ // CHECK: [3.25, 3.26562, 3.28125, 3.29688, 3.3125, 3.32812, 3.34375, 3.35938, 3.375, 3.39062, 3.40625, 3.42188, 3.4375, 3.45312, 3.46875, 3.48438],
+ // CHECK: [3.5, 3.51562, 3.53125, 3.54688, 3.5625, 3.57812, 3.59375, 3.60938, 3.625, 3.64062, 3.65625, 3.67188, 3.6875, 3.70312, 3.71875, 3.73438],
+ // CHECK: [3.75, 3.76562, 3.78125, 3.79688, 3.8125, 3.82812, 3.84375, 3.85938, 3.875, 3.89062, 3.90625, 3.92188, 3.9375, 3.95312, 3.96875, 3.98438]
+
+ call @print_rhs_as_memref_32(%rhs) : (!rhs_memref_type) -> ()
+ // CHECK: [0, 0.015625, 0.03125, 0.046875, 0.0625, 0.078125, 0.09375, 0.109375],
+ // CHECK: [0.125, 0.140625, 0.15625, 0.171875, 0.1875, 0.203125, 0.21875, 0.234375],
+ // CHECK: [0.25, 0.265625, 0.28125, 0.296875, 0.3125, 0.328125, 0.34375, 0.359375],
+ // CHECK: [0.375, 0.390625, 0.40625, 0.421875, 0.4375, 0.453125, 0.46875, 0.484375],
+ // CHECK: [0.5, 0.515625, 0.53125, 0.546875, 0.5625, 0.578125, 0.59375, 0.609375],
+ // CHECK: [0.625, 0.640625, 0.65625, 0.671875, 0.6875, 0.703125, 0.71875, 0.734375],
+ // CHECK: [0.75, 0.765625, 0.78125, 0.796875, 0.8125, 0.828125, 0.84375, 0.859375],
+ // CHECK: [0.875, 0.890625, 0.90625, 0.921875, 0.9375, 0.953125, 0.96875, 0.984375],
+ // CHECK: [1, 1.01562, 1.03125, 1.04688, 1.0625, 1.07812, 1.09375, 1.10938],
+ // CHECK: [1.125, 1.14062, 1.15625, 1.17188, 1.1875, 1.20312, 1.21875, 1.23438],
+ // CHECK: [1.25, 1.26562, 1.28125, 1.29688, 1.3125, 1.32812, 1.34375, 1.35938],
+ // CHECK: [1.375, 1.39062, 1.40625, 1.42188, 1.4375, 1.45312, 1.46875, 1.48438],
+ // CHECK: [1.5, 1.51562, 1.53125, 1.54688, 1.5625, 1.57812, 1.59375, 1.60938],
+ // CHECK: [1.625, 1.64062, 1.65625, 1.67188, 1.6875, 1.70312, 1.71875, 1.73438],
+ // CHECK: [1.75, 1.76562, 1.78125, 1.79688, 1.8125, 1.82812, 1.84375, 1.85938],
+ // CHECK: [1.875, 1.89062, 1.90625, 1.92188, 1.9375, 1.95312, 1.96875, 1.98438]
+
+ call @print_res_as_memref_32(%res) : (!res_memref_type) -> ()
+ // CHECK: [0, 0.015625, 0.03125, 0.046875, 0.0625, 0.078125, 0.09375, 0.109375],
+ // CHECK: [0.125, 0.140625, 0.15625, 0.171875, 0.1875, 0.203125, 0.21875, 0.234375],
+ // CHECK: [0.25, 0.265625, 0.28125, 0.296875, 0.3125, 0.328125, 0.34375, 0.359375],
+ // CHECK: [0.375, 0.390625, 0.40625, 0.421875, 0.4375, 0.453125, 0.46875, 0.484375],
+ // CHECK: [0.5, 0.515625, 0.53125, 0.546875, 0.5625, 0.578125, 0.59375, 0.609375],
+ // CHECK: [0.625, 0.640625, 0.65625, 0.671875, 0.6875, 0.703125, 0.71875, 0.734375],
+ // CHECK: [0.75, 0.765625, 0.78125, 0.796875, 0.8125, 0.828125, 0.84375, 0.859375],
+ // CHECK: [0.875, 0.890625, 0.90625, 0.921875, 0.9375, 0.953125, 0.96875, 0.984375],
+ // CHECK: [1, 1.01562, 1.03125, 1.04688, 1.0625, 1.07812, 1.09375, 1.10938],
+ // CHECK: [1.125, 1.14062, 1.15625, 1.17188, 1.1875, 1.20312, 1.21875, 1.23438],
+ // CHECK: [1.25, 1.26562, 1.28125, 1.29688, 1.3125, 1.32812, 1.34375, 1.35938],
+ // CHECK: [1.375, 1.39062, 1.40625, 1.42188, 1.4375, 1.45312, 1.46875, 1.48438],
+ // CHECK: [1.5, 1.51562, 1.53125, 1.54688, 1.5625, 1.57812, 1.59375, 1.60938],
+ // CHECK: [1.625, 1.64062, 1.65625, 1.67188, 1.6875, 1.70312, 1.71875, 1.73438],
+ // CHECK: [1.75, 1.76562, 1.78125, 1.79688, 1.8125, 1.82812, 1.84375, 1.85938],
+ // CHECK: [1.875, 1.89062, 1.90625, 1.92188, 1.9375, 1.95312, 1.96875, 1.98438]
+
+ gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+ threads(%tx, %ty, %tz) in (%block_x = %c32, %block_y = %c1, %block_z = %c1) {
+
+ linalg.matmul ins(%lhs, %rhs: !lhs_memref_type, !rhs_memref_type)
+ outs(%res: !res_memref_type)
+
+ gpu.terminator
+ }
+
+
+ // Print the result memref after computation.
+ // This has been verified against other f16 CUDA implementations.
+ call @print_res_as_memref_32(%res) : (!res_memref_type) -> ()
+ // CHECK: [2.42188, 2.4668, 2.51172, 2.55664, 2.60156, 2.64648, 2.69141, 2.73633],
+ // CHECK: [6.29688, 6.40625, 6.51172, 6.61719, 6.72656, 6.83594, 6.94141, 7.04688],
+ // CHECK: [10.1719, 10.3438, 10.5156, 10.6797, 10.8516, 11.0234, 11.1875, 11.3594],
+ // CHECK: [14.0469, 14.2812, 14.5156, 14.7422, 14.9766, 15.2109, 15.4375, 15.6719],
+ // CHECK: [17.9219, 18.2188, 18.5156, 18.8125, 19.0938, 19.3906, 19.6875, 19.9844],
+ // CHECK: [21.7969, 22.1562, 22.5156, 22.875, 23.2188, 23.5781, 23.9375, 24.2969],
+ // CHECK: [25.6719, 26.0938, 26.5156, 26.9375, 27.3438, 27.7656, 28.1875, 28.6094],
+ // CHECK: [29.5469, 30.0312, 30.5156, 31, 31.4688, 31.9531, 32.4375, 32.9375],
+ // CHECK: [33.4375, 33.9688, 34.5, 35.0625, 35.5938, 36.1562, 36.6875, 37.25],
+ // CHECK: [37.3125, 37.9062, 38.5, 39.125, 39.7188, 40.3438, 40.9375, 41.5625],
+ // CHECK: [41.1875, 41.8438, 42.5, 43.1875, 43.8438, 44.5312, 45.1875, 45.875],
+ // CHECK: [45.0625, 45.7812, 46.5, 47.25, 47.9688, 48.7188, 49.4375, 50.1875],
+ // CHECK: [48.9375, 49.7188, 50.5, 51.3125, 52.0938, 52.9062, 53.6875, 54.5],
+ // CHECK: [52.8125, 53.6562, 54.5, 55.375, 56.2188, 57.0938, 57.9375, 58.8125],
+ // CHECK: [56.6875, 57.5938, 58.5, 59.4375, 60.3438, 61.2812, 62.1875, 63.125],
+ // CHECK: [60.5625, 61.5312, 62.5, 63.5, 64.5, 65.4375, 66.4375, 67.4375]
+
+ return
+}
+
+func.func private @printMemrefF32(memref<*xf32>)
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.nvgpu.rewrite_matmul_as_mma_sync %matmul
+ : (!transform.any_op) -> ()
+}