author     Matthias Springer <mspringer@nvidia.com>   2025-02-14 18:38:59 +0100
committer  Matthias Springer <mspringer@nvidia.com>   2025-02-14 18:43:17 +0100
commit     277160aec0fed5bf5a5ae8e41a17519b29406deb (patch)
tree       90af198e31d4b97ae7d94ec402705e446bc972f4
parent     70b994bcfaafadd649818d2a7f90f8f5989ec6c1 (diff)
-rw-r--r--  mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp  23
-rw-r--r--  mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-32b.mlir      2
-rw-r--r--  mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir          4
-rw-r--r--  mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir    91
-rw-r--r--  mlir/test/Dialect/GPU/dynamic-shared-memory.mlir         2
5 files changed, 65 insertions, 57 deletions
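
Summary: untangle the GPU-to-NVVM lowering from unrelated conversions. LowerGpuOpsToNVVMOpsPass no longer populates the ArithToLLVM and FuncToLLVM pattern sets, the cf.assert lowering emits llvm.cond_br/llvm.br directly instead of cf ops, and the conversion legality is adjusted accordingly (cf.assert is now explicitly illegal; func.func no longer is). Tests that relied on the pass converting func and arith ops now run -convert-func-to-llvm and -convert-arith-to-llvm explicitly and match the builtin.unrealized_conversion_cast ops that appear at the type-conversion boundaries.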
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 11363a0..87982a5 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -277,10 +277,10 @@ struct AssertOpToAssertfailLowering
Block *afterBlock =
rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
rewriter.setInsertionPointToEnd(beforeBlock);
- rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(), afterBlock,
- assertBlock);
+ rewriter.create<LLVM::CondBrOp>(loc, adaptor.getArg(), afterBlock,
+ assertBlock);
rewriter.setInsertionPointToEnd(assertBlock);
- rewriter.create<cf::BranchOp>(loc, afterBlock);
+ rewriter.create<LLVM::BrOp>(loc, afterBlock);
// Continue cf.assert lowering.
rewriter.setInsertionPoint(assertOp);
@@ -377,9 +377,7 @@ struct LowerGpuOpsToNVVMOpsPass
configureGpuToNVVMTypeConverter(converter);
RewritePatternSet llvmPatterns(m.getContext());
- arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
- populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
populateGpuToNVVMConversionPatterns(converter, llvmPatterns);
populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns);
@@ -396,15 +394,12 @@ struct LowerGpuOpsToNVVMOpsPass
} // namespace
void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
- target.addIllegalOp<func::FuncOp>();
- target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
- target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
- target.addIllegalDialect<gpu::GPUDialect>();
- target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op,
- LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FMAOp,
- LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op,
- LLVM::PowOp, LLVM::RoundEvenOp, LLVM::RoundOp,
- LLVM::SinOp, LLVM::SqrtOp>();
+ target.addLegalDialect<LLVM::LLVMDialect, NVVM::NVVMDialect>();
+ target.addIllegalOp<cf::AssertOp, LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp,
+ LLVM::Exp2Op, LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp,
+ LLVM::FMAOp, LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op,
+ LLVM::Log2Op, LLVM::PowOp, LLVM::RoundEvenOp,
+ LLVM::RoundOp, LLVM::SinOp, LLVM::SqrtOp>();
// TODO: Remove once we support replacing non-root ops.
target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
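
A minimal sketch of how a downstream pipeline might re-compose the conversions that the pass no longer runs itself. The helper name, the LowerToLLVMOptions setup, and the include set are illustrative assumptions, not part of the patch; the populate/configure entry points are the ones visible in the diff above.

#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Hypothetical helper (not from the patch): lower the body of a gpu.module
// to NVVM/LLVM, re-adding the arith and func patterns that the pass dropped.
static LogicalResult lowerGpuModuleToNVVM(Operation *gpuModule) {
  MLIRContext *ctx = gpuModule->getContext();
  LowerToLLVMOptions options(ctx);
  LLVMTypeConverter converter(ctx, options);
  configureGpuToNVVMTypeConverter(converter);

  RewritePatternSet patterns(ctx);
  // These two population calls were removed from LowerGpuOpsToNVVMOpsPass;
  // a pipeline that still wants them must add them explicitly.
  arith::populateArithToLLVMConversionPatterns(converter, patterns);
  populateFuncToLLVMConversionPatterns(converter, patterns);
  cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
  populateGpuToNVVMConversionPatterns(converter, patterns);

  LLVMConversionTarget target(*ctx);
  configureGpuToNVVMConversionLegality(target);
  // After this patch, func.func is no longer explicitly illegal, but it is
  // not legal either, so the re-added FuncToLLVM patterns still fire under
  // partial conversion.
  return applyPartialConversion(gpuModule, target, std::move(patterns));
}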
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-32b.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-32b.mlir
index 1a22ba6..014855b 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-32b.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-32b.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-gpu-to-nvvm='index-bitwidth=32' -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-func-to-llvm="index-bitwidth=32" -convert-gpu-to-nvvm="index-bitwidth=32" -convert-arith-to-llvm="index-bitwidth=32" -split-input-file | FileCheck %s
// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index de2a4ff..69ddb54 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1' -split-input-file | FileCheck %s
-// RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1 use-bare-ptr-memref-call-conv=1' -split-input-file | FileCheck %s --check-prefix=CHECK-BARE
+// RUN: mlir-opt %s -convert-func-to-llvm -convert-gpu-to-nvvm='has-redux=1' -convert-arith-to-llvm -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-func-to-llvm -convert-gpu-to-nvvm='has-redux=1 use-bare-ptr-memref-call-conv=1' -convert-arith-to-llvm -split-input-file | FileCheck %s --check-prefix=CHECK-BARE
// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
gpu.module @test_module_0 {
diff --git a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
index b479467..919eac1 100644
--- a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
@@ -1,12 +1,11 @@
-// RUN: mlir-opt --convert-gpu-to-nvvm --split-input-file %s | FileCheck %s
-// RUN: mlir-opt --convert-gpu-to-nvvm="index-bitwidth=32" --split-input-file %s | FileCheck --check-prefix=CHECK32 %s
+// RUN: mlir-opt -convert-func-to-llvm --convert-gpu-to-nvvm -convert-arith-to-llvm -reconcile-unrealized-casts --split-input-file %s | FileCheck %s
+// RUN: mlir-opt -convert-func-to-llvm="index-bitwidth=32" --convert-gpu-to-nvvm="index-bitwidth=32" -convert-arith-to-llvm="index-bitwidth=32" -reconcile-unrealized-casts --split-input-file %s | FileCheck --check-prefix=CHECK32 %s
gpu.module @test_module {
- // CHECK-LABEL: func @gpu_wmma_load_op() ->
- // CHECK-SAME: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
- // CHECK32-LABEL: func @gpu_wmma_load_op() ->
- func.func @gpu_wmma_load_op() -> (!gpu.mma_matrix<16x16xf16, "AOp">) {
+ // CHECK-LABEL: func @gpu_wmma_load_op()
+ // CHECK32-LABEL: func @gpu_wmma_load_op()
+ func.func @gpu_wmma_load_op() {
%wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
%i = arith.constant 16 : index
%j = arith.constant 16 : index
@@ -21,7 +20,7 @@ gpu.module @test_module {
// CHECK: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
// CHECK: %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]]
// CHECK-SAME: {eltype = #nvvm.mma_type<f16>, frag = #nvvm.mma_frag<a>, k = 16 : i32, layout = #nvvm.mma_layout<col>, m = 16 : i32, n = 16 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
- // CHECK: llvm.return %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK: unrealized_conversion_cast %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK32: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
// CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
@@ -33,8 +32,9 @@ gpu.module @test_module {
// CHECK32: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
// CHECK32: %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]]
// CHECK32-SAME: {eltype = #nvvm.mma_type<f16>, frag = #nvvm.mma_frag<a>, k = 16 : i32, layout = #nvvm.mma_layout<col>, m = 16 : i32, n = 16 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
- // CHECK32: llvm.return %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
- return %0 : !gpu.mma_matrix<16x16xf16, "AOp">
+ // CHECK32: unrealized_conversion_cast %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+
+ "test.user"(%0) : (!gpu.mma_matrix<16x16xf16, "AOp">) -> ()
}
}
@@ -42,10 +42,9 @@ gpu.module @test_module {
gpu.module @test_module {
- // CHECK-LABEL: func @gpu_wmma_int8_load_op() ->
- // CHECK-SAME: !llvm.struct<(i32, i32)>
- // CHECK32-LABEL: func @gpu_wmma_int8_load_op() ->
- func.func @gpu_wmma_int8_load_op() -> (!gpu.mma_matrix<16x16xsi8, "AOp">) {
+ // CHECK-LABEL: func @gpu_wmma_int8_load_op()
+ // CHECK32-LABEL: func @gpu_wmma_int8_load_op()
+ func.func @gpu_wmma_int8_load_op() {
%wg = memref.alloca() {alignment = 32} : memref<32x32xi8, 3>
%i = arith.constant 16 : index
%j = arith.constant 16 : index
@@ -60,7 +59,7 @@ gpu.module @test_module {
// CHECK: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
// CHECK: %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]]
// CHECK-SAME: {eltype = #nvvm.mma_type<s8>, frag = #nvvm.mma_frag<a>, k = 16 : i32, layout = #nvvm.mma_layout<col>, m = 16 : i32, n = 16 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
- // CHECK: llvm.return %[[FRAG]] : !llvm.struct<(i32, i32)>
+ // CHECK: unrealized_conversion_cast %[[FRAG]] : !llvm.struct<(i32, i32)>
// CHECK32: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
// CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
@@ -72,8 +71,9 @@ gpu.module @test_module {
// CHECK32: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
// CHECK32: %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]]
// CHECK32-SAME: {eltype = #nvvm.mma_type<s8>, frag = #nvvm.mma_frag<a>, k = 16 : i32, layout = #nvvm.mma_layout<col>, m = 16 : i32, n = 16 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
- // CHECK32: llvm.return %[[FRAG]] : !llvm.struct<(i32, i32)>
- return %0 : !gpu.mma_matrix<16x16xsi8, "AOp">
+ // CHECK32: unrealized_conversion_cast %[[FRAG]] : !llvm.struct<(i32, i32)>
+
+ "test.user"(%0) : (!gpu.mma_matrix<16x16xsi8, "AOp">) -> ()
}
}
@@ -81,16 +81,18 @@ gpu.module @test_module {
gpu.module @test_module {
- // CHECK-LABEL: func @gpu_wmma_store_op
- // CHECK-SAME: (%[[D:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
- // CHECK32-LABEL: func @gpu_wmma_store_op
- // CHECK32-SAME: (%[[D:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
- func.func @gpu_wmma_store_op(%arg0 : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
+ // CHECK-LABEL: func @gpu_wmma_store_op()
+ // CHECK32-LABEL: func @gpu_wmma_store_op()
+ func.func @gpu_wmma_store_op() -> () {
+ // CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64
+ // CHECK32: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
+ // CHECK: %[[D:.*]] = builtin.unrealized_conversion_cast %{{.*}} : {{.*}} to !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK32: %[[D:.*]] = builtin.unrealized_conversion_cast %{{.*}} : {{.*}} to !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ %arg0 = "test.producer"() : () -> (!gpu.mma_matrix<16x16xf16, "COp">)
%sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 3>
%i = arith.constant 16 : index
%j = arith.constant 16 : index
gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index, transpose} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3>
- // CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64
// CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
// CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
// CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
@@ -109,7 +111,6 @@ gpu.module @test_module {
// CHECK-SAME: {eltype = #nvvm.mma_type<f16>, k = 16 : i32, layout = #nvvm.mma_layout<col>, m = 16 : i32, n = 16 : i32} : !llvm.ptr<3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
// CHECK: llvm.return
- // CHECK32: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
// CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
// CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
// CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
@@ -135,9 +136,14 @@ gpu.module @test_module {
gpu.module @test_module {
- // CHECK-LABEL: func @gpu_wmma_mma_op
- // CHECK-SAME: (%[[A:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[B:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[C:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
- func.func @gpu_wmma_mma_op(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) -> (!gpu.mma_matrix<16x16xf16, "COp">) {
+ // CHECK-LABEL: func @gpu_wmma_mma_op()
+ func.func @gpu_wmma_mma_op() {
+ %A = "test.producer"() : () -> (!gpu.mma_matrix<16x16xf16, "AOp">)
+ // CHECK: %[[A:.*]] = builtin.unrealized_conversion_cast %{{.*}} : !gpu.mma_matrix<16x16xf16, "AOp">
+ %B = "test.producer"() : () -> (!gpu.mma_matrix<16x16xf16, "BOp">)
+ // CHECK: %[[B:.*]] = builtin.unrealized_conversion_cast %{{.*}} : !gpu.mma_matrix<16x16xf16, "BOp">
+ %C = "test.producer"() : () -> (!gpu.mma_matrix<16x16xf16, "COp">)
+ // CHECK: %[[C:.*]] = builtin.unrealized_conversion_cast %{{.*}} : !gpu.mma_matrix<16x16xf16, "COp">
%D = gpu.subgroup_mma_compute %A, %B, %C {a_transpose} : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: %[[A1:.*]] = llvm.extractvalue %[[A]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A2:.*]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
@@ -162,8 +168,8 @@ gpu.module @test_module {
// CHECK: %[[RES:.*]] = nvvm.wmma.mma %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[B8]], %[[C1]], %[[C2]], %[[C3]], %[[C4]]
// CHECK-SAME: {eltypeA = #nvvm.mma_type<f16>, eltypeB = #nvvm.mma_type<f16>, k = 16 : i32, layoutA = #nvvm.mma_layout<col>, layoutB = #nvvm.mma_layout<row>, m = 16 : i32, n = 16 : i32} : (
// CHECK-SAME: vector<2xf16>, {{.*}}) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
- // CHECK: llvm.return %[[RES]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
- return %D : !gpu.mma_matrix<16x16xf16, "COp">
+ // CHECK: unrealized_conversion_cast %[[RES]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ "test.user"(%D) : (!gpu.mma_matrix<16x16xf16, "COp">) -> ()
}
}
@@ -171,9 +177,14 @@ gpu.module @test_module {
gpu.module @test_module {
- // CHECK-LABEL: func @gpu_wmma_mma_int8_op
- // CHECK-SAME: (%[[A:.*]]: !llvm.struct<(i32, i32, i32, i32)>, %[[B:.*]]: !llvm.struct<(i32)>, %[[C:.*]]: !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>)
- func.func @gpu_wmma_mma_int8_op(%A : !gpu.mma_matrix<32x16xsi8, "AOp">, %B : !gpu.mma_matrix<16x8xsi8, "BOp">, %C : !gpu.mma_matrix<32x8xi32, "COp">) -> (!gpu.mma_matrix<32x8xi32, "COp">) {
+ // CHECK-LABEL: func @gpu_wmma_mma_int8_op()
+ func.func @gpu_wmma_mma_int8_op() {
+ %A = "test.producer"() : () -> (!gpu.mma_matrix<32x16xsi8, "AOp">)
+ // CHECK: %[[A:.*]] = builtin.unrealized_conversion_cast %{{.*}} : !gpu.mma_matrix<32x16xsi8, "AOp">
+ %B = "test.producer"() : () -> (!gpu.mma_matrix<16x8xsi8, "BOp">)
+ // CHECK: %[[B:.*]] = builtin.unrealized_conversion_cast %{{.*}} : !gpu.mma_matrix<16x8xsi8, "BOp">
+ %C = "test.producer"() : () -> (!gpu.mma_matrix<32x8xi32, "COp">)
+ // CHECK: %[[C:.*]] = builtin.unrealized_conversion_cast %{{.*}} : !gpu.mma_matrix<32x8xi32, "COp">
%D = gpu.subgroup_mma_compute %A, %B, %C {a_transpose} : !gpu.mma_matrix<32x16xsi8, "AOp">, !gpu.mma_matrix<16x8xsi8, "BOp"> -> !gpu.mma_matrix<32x8xi32, "COp">
// CHECK: %[[A1:.*]] = llvm.extractvalue %[[A]][0] : !llvm.struct<(i32, i32, i32, i32)>
// CHECK: %[[A2:.*]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(i32, i32, i32, i32)>
@@ -191,8 +202,8 @@ gpu.module @test_module {
// CHECK: %[[RES:.*]] = nvvm.wmma.mma %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[B1]], %[[C1]], %[[C2]], %[[C3]], %[[C4]], %[[C5]], %[[C6]], %[[C7]], %[[C8]]
// CHECK-SAME: {eltypeA = #nvvm.mma_type<s8>, eltypeB = #nvvm.mma_type<s32>, k = 16 : i32, layoutA = #nvvm.mma_layout<col>, layoutB = #nvvm.mma_layout<row>, m = 32 : i32, n = 8 : i32} : (
// CHECK-SAME: i32, {{.*}}) -> !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
- // CHECK: llvm.return %[[RES]] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
- return %D : !gpu.mma_matrix<32x8xi32, "COp">
+ // CHECK: unrealized_conversion_cast %[[RES]] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
+ "test.user"(%D) : (!gpu.mma_matrix<32x8xi32, "COp">) -> ()
}
}
@@ -275,11 +286,11 @@ gpu.module @test_module {
// CHECK: %[[M2:.+]] = llvm.insertvalue %[[V2]], %[[M1]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[M3:.+]] = llvm.insertvalue %[[V2]], %[[M2]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[M4:.+]] = llvm.insertvalue %[[V2]], %[[M3]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-// CHECK: llvm.return %[[M4]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
- func.func @gpu_wmma_constant_op() ->(!gpu.mma_matrix<16x16xf16, "COp">) {
+// CHECK: unrealized_conversion_cast %[[M4]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ func.func @gpu_wmma_constant_op() {
%cst = arith.constant 1.0 : f16
%C = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf16, "COp">
- return %C : !gpu.mma_matrix<16x16xf16, "COp">
+ "test.user"(%C) : (!gpu.mma_matrix<16x16xf16, "COp">) -> ()
}
}
@@ -340,10 +351,12 @@ gpu.module @test_module {
// CHECK: %[[C3:.*]] = llvm.select %[[CMP7]], %[[NAN]], %[[SEL3]] : vector<2xi1>, vector<2xf16>
// CHECK: %[[M5:.*]] = llvm.insertvalue %[[C3]], %[[M3]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-// CHECK: llvm.return %[[M5]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
- func.func @gpu_wmma_elementwise(%A : !gpu.mma_matrix<16x16xf16, "COp">, %B : !gpu.mma_matrix<16x16xf16, "COp">) ->(!gpu.mma_matrix<16x16xf16, "COp">) {
+// CHECK: unrealized_conversion_cast %[[M5]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ func.func @gpu_wmma_elementwise() {
+ %A = "test.producer"() : () -> (!gpu.mma_matrix<16x16xf16, "COp">)
+ %B = "test.producer"() : () -> (!gpu.mma_matrix<16x16xf16, "COp">)
%C = gpu.subgroup_mma_elementwise addf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
%D = gpu.subgroup_mma_elementwise maxf %C, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
- return %D : !gpu.mma_matrix<16x16xf16, "COp">
+ "test.user"(%D) : (!gpu.mma_matrix<16x16xf16, "COp">) -> ()
}
}
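
Note on the wmma test changes above: because the pass no longer converts func.func signatures, the lowered fragment struct types can no longer be observed through function arguments and llvm.return. The tests therefore route values through opaque test.producer/test.user ops and match the builtin.unrealized_conversion_cast operations that materialize at those boundaries instead.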
diff --git a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
index 810f7d2c..8a49df8 100644
--- a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
+++ b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-gpu-to-nvvm -cse -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -convert-func-to-llvm -convert-gpu-to-nvvm -convert-arith-to-llvm -cse -canonicalize | FileCheck %s
gpu.module @modules {
// CHECK: llvm.mlir.global internal @__dynamic_shmem__3() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>