Diffstat (limited to 'mlir/test/Dialect')

 mlir/test/Dialect/GPU/subgroup-redule-lowering.mlir                | 190
 mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir |  15
 mlir/test/Dialect/Linalg/transform-op-match.mlir                   |   5
 mlir/test/Dialect/Linalg/transform-op-pad.mlir                     |   3
 mlir/test/Dialect/Mesh/invalid.mlir                                |  96
 mlir/test/Dialect/Mesh/ops.mlir                                    |  49
 mlir/test/Dialect/Mesh/resharding-spmdization.mlir                 | 154
 mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir (renamed from mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib_2to4.mlir) | 28
 mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir                 |  72
 mlir/test/Dialect/Transform/ops-invalid.mlir                       |   8
 mlir/test/Dialect/Transform/test-interpreter.mlir                  |  96
 mlir/test/Dialect/Transform/test-loop-transforms.mlir              |   9
 mlir/test/Dialect/Vector/vector-transfer-flatten.mlir              |  15

 13 files changed, 650 insertions(+), 90 deletions(-)
diff --git a/mlir/test/Dialect/GPU/subgroup-redule-lowering.mlir b/mlir/test/Dialect/GPU/subgroup-redule-lowering.mlir
index b714607..f04a01f 100644
--- a/mlir/test/Dialect/GPU/subgroup-redule-lowering.mlir
+++ b/mlir/test/Dialect/GPU/subgroup-redule-lowering.mlir
@@ -1,71 +1,191 @@
-// RUN: mlir-opt --allow-unregistered-dialect --test-gpu-subgroup-reduce-lowering %s | FileCheck %s
+// RUN: mlir-opt --allow-unregistered-dialect \
+// RUN:   --test-gpu-subgroup-reduce-lowering %s \
+// RUN:   | FileCheck %s --check-prefix=CHECK-SUB
 
-// CHECK: gpu.module @kernels {
+// RUN: mlir-opt --allow-unregistered-dialect \
+// RUN:   --test-gpu-subgroup-reduce-lowering="expand-to-shuffles" %s \
+// RUN:   | FileCheck %s --check-prefix=CHECK-SHFL
+
+// CHECK-SUB: gpu.module @kernels {
+// CHECK-SHFL: gpu.module @kernels {
 gpu.module @kernels {
-  // CHECK-LABEL: gpu.func @kernel0(
-  // CHECK-SAME: %[[ARG0:.+]]: vector<5xf16>)
+  // CHECK-SUB-LABEL: gpu.func @kernel0(
+  // CHECK-SUB-SAME: %[[ARG0:.+]]: vector<5xf16>)
+  //
+  // CHECK-SHFL-LABEL: gpu.func @kernel0(
   gpu.func @kernel0(%arg0: vector<5xf16>) kernel {
-    // CHECK: %[[VZ:.+]] = arith.constant dense<0.0{{.*}}> : vector<5xf16>
-    // CHECK: %[[E0:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [2], strides = [1]} : vector<5xf16> to vector<2xf16>
-    // CHECK: %[[R0:.+]] = gpu.subgroup_reduce add %[[E0]] : (vector<2xf16>) -> vector<2xf16>
-    // CHECK: %[[V0:.+]] = vector.insert_strided_slice %[[R0]], %[[VZ]] {offsets = [0], strides = [1]} : vector<2xf16> into vector<5xf16>
-    // CHECK: %[[E1:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [2], sizes = [2], strides = [1]} : vector<5xf16> to vector<2xf16>
-    // CHECK: %[[R1:.+]] = gpu.subgroup_reduce add %[[E1]] : (vector<2xf16>) -> vector<2xf16>
-    // CHECK: %[[V1:.+]] = vector.insert_strided_slice %[[R1]], %[[V0]] {offsets = [2], strides = [1]} : vector<2xf16> into vector<5xf16>
-    // CHECK: %[[E2:.+]] = vector.extract %[[ARG0]][4] : f16 from vector<5xf16>
-    // CHECK: %[[R2:.+]] = gpu.subgroup_reduce add %[[E2]] : (f16) -> f16
-    // CHECK: %[[V2:.+]] = vector.insert %[[R2]], %[[V1]] [4] : f16 into vector<5xf16>
-    // CHECK: "test.consume"(%[[V2]]) : (vector<5xf16>) -> ()
+    // CHECK-SUB: %[[VZ:.+]] = arith.constant dense<0.0{{.*}}> : vector<5xf16>
+    // CHECK-SUB: %[[E0:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [2], strides = [1]} : vector<5xf16> to vector<2xf16>
+    // CHECK-SUB: %[[R0:.+]] = gpu.subgroup_reduce add %[[E0]] : (vector<2xf16>) -> vector<2xf16>
+    // CHECK-SUB: %[[V0:.+]] = vector.insert_strided_slice %[[R0]], %[[VZ]] {offsets = [0], strides = [1]} : vector<2xf16> into vector<5xf16>
+    // CHECK-SUB: %[[E1:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [2], sizes = [2], strides = [1]} : vector<5xf16> to vector<2xf16>
+    // CHECK-SUB: %[[R1:.+]] = gpu.subgroup_reduce add %[[E1]] : (vector<2xf16>) -> vector<2xf16>
+    // CHECK-SUB: %[[V1:.+]] = vector.insert_strided_slice %[[R1]], %[[V0]] {offsets = [2], strides = [1]} : vector<2xf16> into vector<5xf16>
+    // CHECK-SUB: %[[E2:.+]] = vector.extract %[[ARG0]][4] : f16 from vector<5xf16>
+    // CHECK-SUB: %[[R2:.+]] = gpu.subgroup_reduce add %[[E2]] : (f16) -> f16
+    // CHECK-SUB: %[[V2:.+]] = vector.insert %[[R2]], %[[V1]] [4] : f16 into vector<5xf16>
+    // CHECK-SUB: "test.consume"(%[[V2]]) : (vector<5xf16>) -> ()
     %sum0 = gpu.subgroup_reduce add %arg0 : (vector<5xf16>) -> (vector<5xf16>)
     "test.consume"(%sum0) : (vector<5xf16>) -> ()
-
-    // CHECK-COUNT-3: gpu.subgroup_reduce mul {{.+}} uniform
-    // CHECK: "test.consume"
+    // CHECK-SUB-COUNT-3: gpu.subgroup_reduce mul {{.+}} uniform
+    // CHECK-SUB: "test.consume"
     %sum1 = gpu.subgroup_reduce mul %arg0 uniform : (vector<5xf16>) -> (vector<5xf16>)
    "test.consume"(%sum1) : (vector<5xf16>) -> ()
-    // CHECK: gpu.return
+    // CHECK-SUB: gpu.return
     gpu.return
   }
 
-  // CHECK-LABEL: gpu.func @kernel1(
-  // CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>)
+  // CHECK-SUB-LABEL: gpu.func @kernel1(
+  // CHECK-SUB-SAME: %[[ARG0:.+]]: vector<1xf32>)
+  //
+  // CHECK-SHFL-LABEL: gpu.func @kernel1(
   gpu.func @kernel1(%arg0: vector<1xf32>) kernel {
-    // CHECK: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
-    // CHECK: %[[R0:.+]] = gpu.subgroup_reduce add %[[E0]] : (f32) -> f32
-    // CHECK: %[[V0:.+]] = vector.broadcast %[[R0]] : f32 to vector<1xf32>
-    // CHECK: "test.consume"(%[[V0]]) : (vector<1xf32>) -> ()
+    // CHECK-SUB: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
+    // CHECK-SUB: %[[R0:.+]] = gpu.subgroup_reduce add %[[E0]] : (f32) -> f32
+    // CHECK-SUB: %[[V0:.+]] = vector.broadcast %[[R0]] : f32 to vector<1xf32>
+    // CHECK-SUB: "test.consume"(%[[V0]]) : (vector<1xf32>) -> ()
     %sum0 = gpu.subgroup_reduce add %arg0 : (vector<1xf32>) -> (vector<1xf32>)
     "test.consume"(%sum0) : (vector<1xf32>) -> ()
-    // CHECK: gpu.subgroup_reduce add {{.+}} uniform : (f32) -> f32
-    // CHECK: "test.consume"
+    // CHECK-SUB: gpu.subgroup_reduce add {{.+}} uniform : (f32) -> f32
+    // CHECK-SUB: "test.consume"
     %sum1 = gpu.subgroup_reduce add %arg0 uniform : (vector<1xf32>) -> (vector<1xf32>)
     "test.consume"(%sum1) : (vector<1xf32>) -> ()
-    // CHECK: gpu.return
+    // CHECK-SUB: gpu.return
     gpu.return
   }
 
   // These vectors fit the native shuffle size and should not be broken down.
   //
-  // CHECK-LABEL: gpu.func @kernel2(
-  // CHECK-SAME: %[[ARG0:.+]]: vector<3xi8>, %[[ARG1:.+]]: vector<4xi8>)
+  // CHECK-SUB-LABEL: gpu.func @kernel2(
+  // CHECK-SUB-SAME: %[[ARG0:.+]]: vector<3xi8>, %[[ARG1:.+]]: vector<4xi8>)
+  //
+  // CHECK-SHFL-LABEL: gpu.func @kernel2(
   gpu.func @kernel2(%arg0: vector<3xi8>, %arg1: vector<4xi8>) kernel {
-    // CHECK: %[[R0:.+]] = gpu.subgroup_reduce add %[[ARG0]] : (vector<3xi8>) -> vector<3xi8>
-    // CHECK: "test.consume"(%[[R0]]) : (vector<3xi8>) -> ()
+    // CHECK-SUB: %[[R0:.+]] = gpu.subgroup_reduce add %[[ARG0]] : (vector<3xi8>) -> vector<3xi8>
+    // CHECK-SUB: "test.consume"(%[[R0]]) : (vector<3xi8>) -> ()
     %sum0 = gpu.subgroup_reduce add %arg0 : (vector<3xi8>) -> (vector<3xi8>)
     "test.consume"(%sum0) : (vector<3xi8>) -> ()
-    // CHECK: %[[R1:.+]] = gpu.subgroup_reduce add %[[ARG1]] : (vector<4xi8>) -> vector<4xi8>
-    // CHECK: "test.consume"(%[[R1]]) : (vector<4xi8>) -> ()
+    // CHECK-SUB: %[[R1:.+]] = gpu.subgroup_reduce add %[[ARG1]] : (vector<4xi8>) -> vector<4xi8>
+    // CHECK-SUB: "test.consume"(%[[R1]]) : (vector<4xi8>) -> ()
     %sum1 = gpu.subgroup_reduce add %arg1 : (vector<4xi8>) -> (vector<4xi8>)
     "test.consume"(%sum1) : (vector<4xi8>) -> ()
-    // CHECK: gpu.return
+    // CHECK-SUB: gpu.return
+    gpu.return
+  }
+
+  // CHECK-SHFL-LABEL: gpu.func @kernel3(
+  // CHECK-SHFL-SAME: %[[ARG0:.+]]: i32)
+  gpu.func @kernel3(%arg0: i32) kernel {
+    // CHECK-SHFL-DAG: %[[C1:.+]] = arith.constant 1 : i32
+    // CHECK-SHFL-DAG: %[[C2:.+]] = arith.constant 2 : i32
+    // CHECK-SHFL-DAG: %[[C4:.+]] = arith.constant 4 : i32
+    // CHECK-SHFL-DAG: %[[C8:.+]] = arith.constant 8 : i32
+    // CHECK-SHFL-DAG: %[[C16:.+]] = arith.constant 16 : i32
+    // CHECK-SHFL-DAG: %[[C32:.+]] = arith.constant 32 : i32
+
+    // CHECK-SHFL: %[[S0:.+]], %{{.+}} = gpu.shuffle xor %[[ARG0]], %[[C1]], %[[C32]] : i32
+    // CHECK-SHFL: %[[A0:.+]] = arith.addi %[[ARG0]], %[[S0]] : i32
+    // CHECK-SHFL: %[[S1:.+]], %{{.+}} = gpu.shuffle xor %[[A0]], %[[C2]], %[[C32]] : i32
+    // CHECK-SHFL: %[[A1:.+]] = arith.addi %[[A0]], %[[S1]] : i32
+    // CHECK-SHFL: %[[S2:.+]], %{{.+}} = gpu.shuffle xor %[[A1]], %[[C4]], %[[C32]] : i32
+    // CHECK-SHFL: %[[A2:.+]] = arith.addi %[[A1]], %[[S2]] : i32
+    // CHECK-SHFL: %[[S3:.+]], %{{.+}} = gpu.shuffle xor %[[A2]], %[[C8]], %[[C32]] : i32
+    // CHECK-SHFL: %[[A3:.+]] = arith.addi %[[A2]], %[[S3]] : i32
+    // CHECK-SHFL: %[[S4:.+]], %{{.+}} = gpu.shuffle xor %[[A3]], %[[C16]], %[[C32]] : i32
+    // CHECK-SHFL: %[[A4:.+]] = arith.addi %[[A3]], %[[S4]] : i32
+    // CHECK-SHFL: "test.consume"(%[[A4]]) : (i32) -> ()
+    %sum0 = gpu.subgroup_reduce add %arg0 : (i32) -> i32
+    "test.consume"(%sum0) : (i32) -> ()
+
+    // CHECK-SHFL: gpu.return
+    gpu.return
+  }
+
+  // CHECK-SHFL-LABEL: gpu.func @kernel4(
+  // CHECK-SHFL-SAME: %[[ARG0:.+]]: vector<2xf16>)
+  gpu.func @kernel4(%arg0: vector<2xf16>) kernel {
+    // CHECK-SHFL-DAG: %[[C1:.+]] = arith.constant 1 : i32
+    // CHECK-SHFL-DAG: %[[C2:.+]] = arith.constant 2 : i32
+    // CHECK-SHFL-DAG: %[[C4:.+]] = arith.constant 4 : i32
+    // CHECK-SHFL-DAG: %[[C8:.+]] = arith.constant 8 : i32
+    // CHECK-SHFL-DAG: %[[C16:.+]] = arith.constant 16 : i32
+    // CHECK-SHFL-DAG: %[[C32:.+]] = arith.constant 32 : i32
+
+    // CHECK-SHFL: %[[V0:.+]] = vector.bitcast %[[ARG0]] : vector<2xf16> to vector<1xi32>
+    // CHECK-SHFL: %[[I0:.+]] = vector.extract %[[V0]][0] : i32 from vector<1xi32>
+    // CHECK-SHFL: %[[S0:.+]], %{{.+}} = gpu.shuffle xor %[[I0]], %[[C1]], %[[C32]] : i32
+    // CHECK-SHFL: %[[BR0:.+]] = vector.broadcast %[[S0]] : i32 to vector<1xi32>
+    // CHECK-SHFL: %[[BC0:.+]] = vector.bitcast %[[BR0]] : vector<1xi32> to vector<2xf16>
+    // CHECK-SHFL: %[[ADD0:.+]] = arith.addf %[[ARG0]], %[[BC0]] : vector<2xf16>
+    // CHECK-SHFL: %[[BC1:.+]] = vector.bitcast %[[ADD0]] : vector<2xf16> to vector<1xi32>
+    // CHECK-SHFL: %[[I1:.+]] = vector.extract %[[BC1]][0] : i32 from vector<1xi32>
+    // CHECK-SHFL: gpu.shuffle xor %[[I1]], %[[C2]], %[[C32]] : i32
+    // CHECK-SHFL: arith.addf {{.+}} : vector<2xf16>
+    // CHECK-SHFL: gpu.shuffle xor %{{.+}}, %[[C4]], %[[C32]] : i32
+    // CHECK-SHFL: arith.addf {{.+}} : vector<2xf16>
+    // CHECK-SHFL: gpu.shuffle xor %{{.+}}, %[[C8]], %[[C32]] : i32
+    // CHECK-SHFL: arith.addf {{.+}} : vector<2xf16>
+    // CHECK-SHFL: %[[SL:.+]], %{{.+}} = gpu.shuffle xor %{{.+}}, %[[C16]], %[[C32]] : i32
+    // CHECK-SHFL: %[[BRL:.+]] = vector.broadcast %[[SL]] : i32 to vector<1xi32>
+    // CHECK-SHFL: %[[BCL:.+]] = vector.bitcast %[[BRL]] : vector<1xi32> to vector<2xf16>
+    // CHECK-SHFL: %[[ADDL:.+]] = arith.addf %{{.+}}, %[[BCL]] : vector<2xf16>
+    // CHECK-SHFL: "test.consume"(%[[ADDL]]) : (vector<2xf16>) -> ()
+    %sum0 = gpu.subgroup_reduce add %arg0 : (vector<2xf16>) -> (vector<2xf16>)
+    "test.consume"(%sum0) : (vector<2xf16>) -> ()
+
+    // CHECK-SHFL: gpu.return
+    gpu.return
+  }
+
+  // CHECK-SHFL-LABEL: gpu.func @kernel5(
+  // CHECK-SHFL-SAME: %[[ARG0:.+]]: i16)
+  gpu.func @kernel5(%arg0: i16) kernel {
+    // CHECK-SHFL: %[[E0:.+]] = arith.extui %[[ARG0]] : i16 to i32
+    // CHECK-SHFL: %[[S0:.+]], %{{.+}} = gpu.shuffle xor %[[E0]], {{.+}} : i32
+    // CHECK-SHFL: %[[T0:.+]] = arith.trunci %[[S0]] : i32 to i16
+    // CHECK-SHFL: %[[A0:.+]] = arith.addi %[[ARG0]], %[[T0]] : i16
+    // CHECK-SHFL: %[[E1:.+]] = arith.extui %[[A0]] : i16 to i32
+    // CHECK-SHFL: %{{.+}}, %{{.+}} = gpu.shuffle xor %[[E1]], {{.+}} : i32
+    // CHECK-SHFL-COUNT-3: gpu.shuffle xor
+    // CHECK-SHFL: arith.trunci {{.+}} : i32 to i16
+    // CHECK-SHFL: %[[AL:.+]] = arith.addi {{.+}} : i16
+    // CHECK-SHFL: "test.consume"(%[[AL]]) : (i16) -> ()
+    %sum0 = gpu.subgroup_reduce add %arg0 : (i16) -> i16
+    "test.consume"(%sum0) : (i16) -> ()
+
+    // CHECK-SHFL: gpu.return
+    gpu.return
+  }
+
+  // CHECK-SHFL-LABEL: gpu.func @kernel6(
+  // CHECK-SHFL-SAME: %[[ARG0:.+]]: vector<3xi8>)
+  gpu.func @kernel6(%arg0: vector<3xi8>) kernel {
+    // CHECK-SHFL: %[[CZ:.+]] = arith.constant dense<0> : vector<4xi8>
+    // CHECK-SHFL: %[[V0:.+]] = vector.insert_strided_slice %[[ARG0]], %[[CZ]] {offsets = [0], strides = [1]} : vector<3xi8> into vector<4xi8>
+    // CHECK-SHFL: %[[BC0:.+]] = vector.bitcast %[[V0]] : vector<4xi8> to vector<1xi32>
+    // CHECK-SHFL: %[[I0:.+]] = vector.extract %[[BC0]][0] : i32 from vector<1xi32>
+    // CHECK-SHFL: %[[S0:.+]], %{{.+}} = gpu.shuffle xor %[[I0]], {{.+}} : i32
+    // CHECK-SHFL: %[[BR0:.+]] = vector.broadcast %[[S0]] : i32 to vector<1xi32>
+    // CHECK-SHFL: %[[BC1:.+]] = vector.bitcast %[[BR0]] : vector<1xi32> to vector<4xi8>
+    // CHECK-SHFL: %[[ADD0:.+]] = arith.addi %[[V0]], %[[BC1]] : vector<4xi8>
+    // CHECK-SHFL: %[[BC2:.+]] = vector.bitcast %[[ADD0]] : vector<4xi8> to vector<1xi32>
+    // CHECK-SHFL: %[[I1:.+]] = vector.extract %[[BC2]][0] : i32 from vector<1xi32>
+    // CHECK-SHFL-COUNT-4: gpu.shuffle xor
+    // CHECK-SHFL: %[[ESS:.+]] = vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [3], strides = [1]} : vector<4xi8> to vector<3xi8>
+    // CHECK-SHFL: "test.consume"(%[[ESS]]) : (vector<3xi8>) -> ()
+    %sum0 = gpu.subgroup_reduce add %arg0 : (vector<3xi8>) -> (vector<3xi8>)
+    "test.consume"(%sum0) : (vector<3xi8>) -> ()
+
+    // CHECK-SHFL: gpu.return
     gpu.return
   }
 }
+
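The CHECK-SHFL lines above encode a classic log2(width) butterfly reduction: five xor-shuffle/accumulate rounds for a 32-lane subgroup, with non-i32 types extended or bitcast to i32 around each shuffle. As a minimal sketch of a single round (the op names come straight from the test, but the kernel itself is hypothetical):

gpu.module @sketch {
  gpu.func @one_butterfly_round(%x: i32) kernel {
    %offset = arith.constant 1 : i32
    %width = arith.constant 32 : i32
    // Swap values with the lane whose ID differs in bit 0 ...
    %y, %valid = gpu.shuffle xor %x, %offset, %width : i32
    // ... and accumulate; repeating with offsets 2, 4, 8, 16 leaves every
    // lane holding the full subgroup sum.
    %acc = arith.addi %x, %y : i32
    "test.consume"(%acc) : (i32) -> ()
    gpu.return
  }
}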
diff --git a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
index 49a52ba..aa15ccf 100644
--- a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
@@ -36,13 +36,15 @@ module attributes {transform.with_named_sequence} {
 
     // Ensure that one linalg.fill was generated.
     %fill_op = transform.select "linalg.fill" in %new : (!transform.any_op) -> !transform.any_op
+    %p = transform.num_associations %fill_op : (!transform.any_op) -> !transform.param<i64>
     // expected-remark @below{{1}}
-    transform.test_print_number_of_associated_payload_ir_ops %fill_op : !transform.any_op
+    transform.test_print_param %p : !transform.param<i64>
 
     // Ensure that one linalg.copy was generated.
     %mat = transform.select "bufferization.materialize_in_destination" in %new : (!transform.any_op) -> !transform.any_op
+    %p2 = transform.num_associations %mat : (!transform.any_op) -> !transform.param<i64>
     // expected-remark @below{{1}}
-    transform.test_print_number_of_associated_payload_ir_ops %mat : !transform.any_op
+    transform.test_print_param %p2 : !transform.param<i64>
 
     transform.yield
   }
 }
@@ -73,18 +75,21 @@ module attributes {transform.with_named_sequence} {
 
     // Ensure that one linalg.fill was generated.
     %fill_op = transform.select "linalg.fill" in %new : (!transform.any_op) -> !transform.any_op
+    %p = transform.num_associations %fill_op : (!transform.any_op) -> !transform.param<i64>
     // expected-remark @below{{1}}
-    transform.test_print_number_of_associated_payload_ir_ops %fill_op : !transform.any_op
+    transform.test_print_param %p : !transform.param<i64>
 
     // Ensure that one linalg.copy was generated.
     %linalg_copy = transform.select "linalg.copy" in %new : (!transform.any_op) -> !transform.any_op
+    %p2 = transform.num_associations %linalg_copy : (!transform.any_op) -> !transform.param<i64>
     // expected-remark @below{{1}}
-    transform.test_print_number_of_associated_payload_ir_ops %linalg_copy : !transform.any_op
+    transform.test_print_param %p2 : !transform.param<i64>
 
     // Ensure that one memref.alloca was generated.
     %alloca = transform.select "memref.alloca" in %new : (!transform.any_op) -> !transform.any_op
+    %p3 = transform.num_associations %alloca : (!transform.any_op) -> !transform.param<i64>
     // expected-remark @below{{1}}
-    transform.test_print_number_of_associated_payload_ir_ops %alloca : !transform.any_op
+    transform.test_print_param %p3 : !transform.param<i64>
 
     // Make sure that One-Shot Bufferize can bufferize the rest.
     %4 = transform.bufferization.one_shot_bufferize %arg1 : (!transform.any_op) -> !transform.any_op
diff --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir
index 15942db..db5b5f1 100644
--- a/mlir/test/Dialect/Linalg/transform-op-match.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir
@@ -134,8 +134,9 @@ module attributes {transform.with_named_sequence} {
         #linalg.iterator_type<parallel>,
         #linalg.iterator_type<reduction>]}
       in %arg1 : (!transform.any_op) -> !transform.any_op
-    // expected-remark @below {{0}}
-    transform.test_print_number_of_associated_payload_ir_ops %no_match : !transform.any_op
+    %p = transform.num_associations %no_match : (!transform.any_op) -> !transform.param<i64>
+    // expected-remark @below {{0}}
+    transform.test_print_param %p : !transform.param<i64>
     transform.yield
   }
 }
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index 6bca6c1..1f9d81a 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -41,8 +41,9 @@ module attributes {transform.with_named_sequence} {
       padding_dimensions=[0, 1, 2],
       pack_paddings=[1, 1, 0]
     } : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.op<"bufferization.materialize_in_destination">)
+    %p = transform.num_associations %copy_back : (!transform.op<"bufferization.materialize_in_destination">) -> !transform.param<i64>
     // expected-remark @below {{1}}
-    transform.test_print_number_of_associated_payload_ir_ops %copy_back : !transform.op<"bufferization.materialize_in_destination">
+    transform.test_print_param %p : !transform.param<i64>
     transform.yield
   }
 }
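Every hunk in these three Linalg test files applies the same mechanical migration: the all-in-one debug op test_print_number_of_associated_payload_ir_ops is split into transform.num_associations, which materializes the number of payload objects behind a handle as a !transform.param<i64>, followed by the generic transform.test_print_param. A minimal sketch of the new idiom, with %root standing in for any payload handle:

%fill = transform.select "linalg.fill" in %root : (!transform.any_op) -> !transform.any_op
// Reify the payload count as a parameter ...
%n = transform.num_associations %fill : (!transform.any_op) -> !transform.param<i64>
// ... and print it so FileCheck can match the remark.
// expected-remark @below {{1}}
transform.test_print_param %n : !transform.param<i64>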
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 03994f8..3ee578a 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -70,6 +70,102 @@ func.func @mesh_axis_negtive_in_partial(
 
 // -----
 
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+
+func.func @cluster_shape_mesh_axis_out_of_bounds() -> (index, index) {
+  // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
+  %0:2 = mesh.cluster_shape @mesh0 axes = [0, 2] : index, index
+  return %0#0, %0#1 : index, index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
+
+func.func @cluster_shape_duplicate_mesh_axis() -> (index, index, index) {
+  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
+  %0:3 = mesh.cluster_shape @mesh0 axes = [0, 2, 0] : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+
+func.func @cluster_shape_wrong_number_of_results() -> (index, index) {
+  // expected-error@+1 {{Unexpected number of results 2. Expected 1.}}
+  %0:2 = mesh.cluster_shape @mesh0 axes = [0] : index, index
+  return %0#0, %0#1 : index, index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
+
+func.func @cluster_shape_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
+  // expected-error@+1 {{Unexpected number of results 2. Expected 3.}}
+  %0:2 = mesh.cluster_shape @mesh0 : index, index
+  return %0#0, %0#1 : index, index
+}
+
+// -----
+
+func.func @cluster_shape_invalid_mesh_name() -> (index) {
+  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
+  %0 = mesh.cluster_shape @this_mesh_symbol_does_not_exist : index
+  return %0#0 : index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+
+func.func @process_index_mesh_axis_out_of_bounds() -> (index, index) {
+  // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
+  %0:2 = mesh.process_index on @mesh0 axes = [0, 2] : index, index
+  return %0#0, %0#1 : index, index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
+
+func.func @process_index_duplicate_mesh_axis() -> (index, index, index) {
+  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
+  %0:3 = mesh.process_index on @mesh0 axes = [0, 2, 0] : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+
+func.func @process_index_wrong_number_of_results() -> (index, index) {
+  // expected-error@+1 {{Unexpected number of results 2. Expected 1.}}
+  %0:2 = mesh.process_index on @mesh0 axes = [0] : index, index
+  return %0#0, %0#1 : index, index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
+
+func.func @process_index_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
+  // expected-error@+1 {{Unexpected number of results 2. Expected 3.}}
+  %0:2 = mesh.process_index on @mesh0 : index, index
+  return %0#0, %0#1 : index, index
+}
+
+// -----
+
+func.func @process_index_invalid_mesh_name() -> (index) {
+  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
+  %0 = mesh.process_index on @this_mesh_symbol_does_not_exist : index
+  return %0#0 : index
+}
+
+// -----
+
 func.func @all_reduce_invalid_mesh_symbol(
     %arg0 : tensor<4xf32>) -> tensor<4xf64> {
   // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 8f8e309..a7c3b3d 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -132,6 +132,55 @@ func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) ->
   return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32>
 }
 
+// CHECK-LABEL: func @cluster_shape
+func.func @cluster_shape() -> (index, index) {
+  // CHECK: %[[RES:.*]]:2 = mesh.cluster_shape @mesh0 axes = [0, 1] : index, index
+  %0:2 = mesh.cluster_shape @mesh0 axes = [0, 1] : index, index
+  // CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index
+  return %0#0, %0#1 : index, index
+}
+
+// CHECK-LABEL: func @cluster_shape_default_axes
+func.func @cluster_shape_default_axes() -> (index, index, index) {
+  // CHECK: %[[RES:.*]]:3 = mesh.cluster_shape @mesh0 : index, index, index
+  %0:3 = mesh.cluster_shape @mesh0 : index, index, index
+  // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// CHECK-LABEL: func @cluster_shape_empty_axes
+func.func @cluster_shape_empty_axes() -> (index, index, index) {
+  // CHECK: %[[RES:.*]]:3 = mesh.cluster_shape @mesh0 : index, index, index
+  %0:3 = mesh.cluster_shape @mesh0 axes = [] : index, index, index
+  // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// CHECK-LABEL: func @process_index
+func.func @process_index() -> (index, index) {
+  // CHECK: %[[RES:.*]]:2 = mesh.process_index on @mesh0 axes = [0, 1] : index, index
+  %0:2 = mesh.process_index on @mesh0 axes = [0, 1] : index, index
+  // CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index
+  return %0#0, %0#1 : index, index
+}
+
+// CHECK-LABEL: func @process_index_default_axes
+func.func @process_index_default_axes() -> (index, index, index) {
+  // CHECK: %[[RES:.*]]:3 = mesh.process_index on @mesh0 : index, index, index
+  %0:3 = mesh.process_index on @mesh0 : index, index, index
+  // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// CHECK-LABEL: func @process_index_empty_axes
+func.func @process_index_empty_axes() -> (index, index, index) {
+  // CHECK: %[[RES:.*]]:3 = mesh.process_index on @mesh0 : index, index, index
+  %0:3 = mesh.process_index on @mesh0 axes = [] : index, index, index
+  // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+
+
 // CHECK-LABEL: func @all_reduce
 func.func @all_reduce(
     // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
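The negative tests above pin down the verifier rules the two new ops share: axes must be in-bounds for the mesh rank and duplicate-free, and the result count must equal the number of requested axes (or the mesh rank when axes are omitted). A small valid-usage sketch against the same 2x4 mesh (hypothetical function; op syntax as in ops.mlir):

mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)

func.func @where_am_i() -> (index, index) {
  // Size of mesh axis 1 (here 4) ...
  %size = mesh.cluster_shape @mesh0 axes = [1] : index
  // ... and this process's coordinate along axis 0 (0 or 1).
  %idx = mesh.process_index on @mesh0 axes = [0] : index
  return %size, %idx : index, index
}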
diff --git a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
new file mode 100644
index 0000000..0ba0d76
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
@@ -0,0 +1,154 @@
+// RUN: mlir-opt -test-mesh-resharding-spmdization %s | FileCheck %s
+
+mesh.cluster @mesh_1d(rank = 1, dim_sizes = 2)
+mesh.cluster @mesh_1d_dynamic(rank = 1, dim_sizes = ?)
+
+// CHECK-LABEL: func @same_source_and_target_sharding
+func.func @same_source_and_target_sharding(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<2xf32>
+  %arg0: tensor<2xf32>
+) -> tensor<2xf32> {
+  %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xf32>
+  %1 = mesh.shard %0 to <@mesh_1d, [[]]> annotate_for_users : tensor<2xf32>
+  // CHECK: return %[[ARG]]
+  return %1 : tensor<2xf32>
+}
+
+// CHECK-LABEL: func @split_replicated_tensor_axis
+func.func @split_replicated_tensor_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<3x14xf32>
+  %arg0: tensor<3x14xf32>
+) -> tensor<3x14xf32> {
+  // CHECK: %[[ZERO:.*]] = arith.constant 0 : index
+  // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE:.*]] = arith.constant 14 : index
+  // CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_index on @mesh_1d axes = [0] : index
+  // CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.cluster_shape @mesh_1d axes = [0] : index
+  // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE:.*]] = arith.remui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
+  // CHECK: %[[RESULT_TENSOR_AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE]], %[[ZERO]] : index
+  // CHECK: cf.assert %[[RESULT_TENSOR_AXIS_SIZE_CHECK]]
+  // CHECK: %[[RESULT_TENSOR_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
+  // CHECK: %[[RESULT_TENSOR_AXIS_OFFSET:.*]] = arith.muli %[[RESULT_TENSOR_AXIS_SIZE]], %[[PROCESS_INDEX]] : index
+  // CHECK: %[[RESULT_TENSOR_SLICE:.*]] = tensor.extract_slice %[[ARG]][0, %[[RESULT_TENSOR_AXIS_OFFSET]]] [3, 7] [1, 1] : tensor<3x14xf32> to tensor<3x7xf32>
+  // CHECK: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[RESULT_TENSOR_SLICE]] : tensor<3x7xf32> to tensor<3x14xf32>
+  %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<3x14xf32>
+  %1 = mesh.shard %0 to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<3x14xf32>
+  // CHECK: return %[[RESULT]] : tensor<3x14xf32>
+  return %1 : tensor<3x14xf32>
+}
+
+// CHECK-LABEL: func @split_replicated_tensor_axis_dynamic
+func.func @split_replicated_tensor_axis_dynamic(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<?x3x?xf32>
+  %arg0: tensor<?x3x?xf32>
+) -> tensor<?x3x?xf32> {
+  // CHECK: %[[ZERO:.*]] = arith.constant 0 : index
+  // CHECK: %[[TWO:.*]] = arith.constant 2 : index
+  // CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_index on @mesh_1d_dynamic axes = [0] : index
+  // CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.cluster_shape @mesh_1d_dynamic axes = [0] : index
+  // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[ZERO]] : tensor<?x3x?xf32>
+  // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE:.*]] = arith.remui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
+  // CHECK: %[[RESULT_TENSOR_AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE]], %[[ZERO]] : index
+  // CHECK: cf.assert %[[RESULT_TENSOR_AXIS_SIZE_CHECK]]
+  // CHECK: %[[RESULT_TENSOR_SPLIT_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
+  // CHECK: %[[RESULT_TENSOR_SPLIT_AXIS_OFFSET:.*]] = arith.muli %[[RESULT_TENSOR_SPLIT_AXIS_SIZE]], %[[PROCESS_INDEX]] : index
+  // CHECK: %[[TENSOR_AXIS_2_SIZE:.*]] = tensor.dim %[[ARG]], %[[TWO]] : tensor<?x3x?xf32>
+  // CHECK: %[[RESULT_TENSOR_SLICE:.*]] = tensor.extract_slice %[[ARG]][%[[RESULT_TENSOR_SPLIT_AXIS_OFFSET]], 0, 0]
+  // CHECK-SAME: [%[[RESULT_TENSOR_SPLIT_AXIS_SIZE]], 3, %[[TENSOR_AXIS_2_SIZE]]] [1, 1, 1] : tensor<?x3x?xf32> to tensor<?x3x?xf32>
+  %0 = mesh.shard %arg0 to <@mesh_1d_dynamic, [[], [], []]> : tensor<?x3x?xf32>
+  %1 = mesh.shard %0 to <@mesh_1d_dynamic, [[0]]> annotate_for_users : tensor<?x3x?xf32>
+  // CHECK: return %[[RESULT_TENSOR_SLICE]] : tensor<?x3x?xf32>
+  return %1 : tensor<?x3x?xf32>
+}
+
+// CHECK-LABEL: func @move_split_axis
+func.func @move_split_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+  %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32>
+  // CHECK: %[[TARGET_SHARD:.*]] = mesh.all_to_all %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<5x14xf32> -> tensor<10x7xf32>
+  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x7xf32> to tensor<10x14xf32>
+  %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<10x14xf32>
+  %1 = mesh.shard %0 to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<10x14xf32>
+  // CHECK: return %[[RES]] : tensor<10x14xf32>
+  return %1 : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @move_split_axis_dynamic_mesh
+func.func @move_split_axis_dynamic_mesh(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+  %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32>
+  // CHECK: %[[ALL_TO_ALL:.*]] = mesh.all_to_all %[[SOURCE_SHARD]] on @mesh_1d_dynamic mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x?xf32>
+  // CHECK: %[[TARGET_SHARD:.*]] = tensor.cast %[[ALL_TO_ALL]] : tensor<?x?xf32> to tensor<10x?xf32>
+  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x?xf32> to tensor<10x14xf32>
+  %0 = mesh.shard %arg0 to <@mesh_1d_dynamic, [[0]]> : tensor<10x14xf32>
+  %1 = mesh.shard %0 to <@mesh_1d_dynamic, [[], [0]]> annotate_for_users : tensor<10x14xf32>
+  // CHECK: return %[[RES]] : tensor<10x14xf32>
+  return %1 : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @move_split_dynamic_axis
+func.func @move_split_dynamic_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32>
+  %arg0: tensor<?x14xf32>
+) -> tensor<?x14xf32> {
+  // CHECK: %[[TARGET_SHARD:.*]] = mesh.all_to_all %[[ARG]] on @mesh_1d mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x7xf32>
+  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<?x7xf32> to tensor<?x14xf32>
+  %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<?x14xf32>
+  %1 = mesh.shard %0 to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<?x14xf32>
+  // CHECK: return %[[RES]] : tensor<?x14xf32>
+  return %1 : tensor<?x14xf32>
+}
+
+// CHECK-LABEL: func @unshard_static_axis
+func.func @unshard_static_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+  %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32>
+  // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<5x14xf32> -> tensor<10x14xf32>
+  %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<10x14xf32>
+  %1 = mesh.shard %0 to <@mesh_1d, [[]]> annotate_for_users : tensor<10x14xf32>
+  // CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32>
+  return %1 : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @unshard_dynamic_axis
+func.func @unshard_dynamic_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32>
+  %arg0: tensor<?x14xf32>
+) -> tensor<?x14xf32> {
+  // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[ARG]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32>
+  %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<?x14xf32>
+  %1 = mesh.shard %0 to <@mesh_1d, [[]]> annotate_for_users : tensor<?x14xf32>
+  // CHECK: return %[[ALL_GATHER]] : tensor<?x14xf32>
+  return %1 : tensor<?x14xf32>
+}
+
+// CHECK-LABEL: func @unshard_static_axis_on_dynamic_mesh_axis
+func.func @unshard_static_axis_on_dynamic_mesh_axis(
+// CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+  %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32>
+  // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d_dynamic mesh_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32>
+  // CHECK: %[[RES:.*]] = tensor.cast %[[ALL_GATHER]] : tensor<?x14xf32> to tensor<10x14xf32>
+  %0 = mesh.shard %arg0 to <@mesh_1d_dynamic, [[0]]> : tensor<10x14xf32>
+  %1 = mesh.shard %0 to <@mesh_1d_dynamic, [[]]> annotate_for_users : tensor<10x14xf32>
+  // CHECK: return %[[RES]] : tensor<10x14xf32>
+  return %1 : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @partial_axis
+func.func @partial_axis(
+// CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+  %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+  // CHECK: %[[ALL_REDUCE:.*]] = mesh.all_reduce %[[ARG]] on @mesh_1d mesh_axes = [0] : tensor<10x14xf32> -> tensor<10x14xf32>
+  %0 = mesh.shard %arg0 to <@mesh_1d, [[]], partial = sum[0]> : tensor<10x14xf32>
+  %1 = mesh.shard %0 to <@mesh_1d, [[]]> annotate_for_users : tensor<10x14xf32>
+  // CHECK: %[[ALL_REDUCE]] : tensor<10x14xf32>
+  return %1 : tensor<10x14xf32>
+}
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib_2to4.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir
index f584977e..6fe7ec9 100644
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib_2to4.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir
@@ -1,5 +1,13 @@
 // RUN: mlir-opt %s --linalg-generalize-named-ops --sparse-gpu-codegen="num-threads=0" | FileCheck %s
 
+#NV_24 = #sparse_tensor.encoding<{
+  map = ( i, j ) ->
+  ( i            : dense,
+    j floordiv 4 : dense,
+    j mod 4      : block2_4
+  )
+}>
+
 // CHECK-LABEL: func.func @matmul(
 // CHECK-SAME: %[[VAL_0:.*0]]: tensor<?x?xf16>,
 // CHECK-SAME: %[[VAL_1:.*1]]: tensor<?x?xf16>,
@@ -51,18 +59,14 @@
 // CHECK: %[[VAL_55:.*]] = bufferization.to_tensor %[[VAL_19]] : memref<?x?xf16>
 // CHECK: return %[[VAL_55]] : tensor<?x?xf16>
 // CHECK: }
-
-#map = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
 module {
-  func.func @matmul(%arg0: tensor<?x?xf16>, %arg1: tensor<?x?xf16>, %arg2: tensor<?x?xf16>) -> tensor<?x?xf16> {
-    %0 = linalg.generic { DENSE24, indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<?x?xf16>, tensor<?x?xf16>) outs(%arg2 : tensor<?x?xf16>) {
-    ^bb0(%in: f16, %in_0: f16, %out: f16):
-      %1 = arith.mulf %in, %in_0 : f16
-      %2 = arith.addf %out, %1 : f16
-      linalg.yield %2 : f16
-    } -> tensor<?x?xf16>
-    return %0 : tensor<?x?xf16>
+  func.func @matmul(%Ad: tensor<?x?xf16>,
+                    %B: tensor<?x?xf16>,
+                    %Cin: tensor<?x?xf16>) -> tensor<?x?xf16> {
+    %A = sparse_tensor.convert %Ad : tensor<?x?xf16> to tensor<?x?xf16, #NV_24>
+    %C = linalg.matmul
+      ins(%A, %B: tensor<?x?xf16, #NV_24>, tensor<?x?xf16>)
+      outs(%Cin: tensor<?x?xf16>) -> tensor<?x?xf16>
+    return %C : tensor<?x?xf16>
   }
 }
diff --git a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
index bdfe18a..b78ab9b 100644
--- a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
@@ -56,3 +56,75 @@ func.func @single_first_inner_dim_packing(%arg0: tensor<256x5xf32>) -> tensor<8x
   %0 = tensor.pack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<256x5xf32> -> tensor<8x5x32xf32>
   return %0 : tensor<8x5x32xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @unpack_1d_to_collapse
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<8x32xf32>)
+// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<8x32xf32> into tensor<256xf32>
+// CHECK:         return %[[COLLAPSED]]
+func.func @unpack_1d_to_collapse(%arg0: tensor<8x32xf32>) -> tensor<256xf32> {
+  %empty = tensor.empty() : tensor<256xf32>
+  %0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<8x32xf32> -> tensor<256xf32>
+  return %0 : tensor<256xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpack_to_partial_slice
+// CHECK-NOT:     tensor.collapse
+// CHECK:         tensor.unpack
+func.func @unpack_to_partial_slice(%arg0: tensor<8x32xf32>) -> tensor<255xf32> {
+  %empty = tensor.empty() : tensor<255xf32>
+  %0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<8x32xf32> -> tensor<255xf32>
+  return %0 : tensor<255xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpack_dynamic
+// CHECK-NOT:     tensor.collapse
+// CHECK:         tensor.unpack
+func.func @unpack_dynamic(%arg0: tensor<?x32xf32>) -> tensor<?xf32> {
+  %c32 = arith.constant 32 : index
+  %c0 = arith.constant 0 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x32xf32>
+  %size = arith.muli %d0, %c32 : index
+  %empty = tensor.empty(%size) : tensor<?xf32>
+  %0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<?x32xf32> -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @single_last_inner_dim_unpacking(
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<5x8x32xf32>)
+// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<5x8x32xf32> into tensor<5x256xf32>
+// CHECK:         return %[[COLLAPSED]] : tensor<5x256xf32>
+func.func @single_last_inner_dim_unpacking(%arg0: tensor<5x8x32xf32>) -> tensor<5x256xf32> {
+  %empty = tensor.empty() : tensor<5x256xf32>
+  %0 = tensor.unpack %arg0 inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<5x8x32xf32> -> tensor<5x256xf32>
+  return %0 : tensor<5x256xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpacking_with_outer_dims_perm(
+// CHECK-NOT:     tensor.collpase_shape
+// CHECK:         tensor.unpack
+func.func @unpacking_with_outer_dims_perm(%arg0: tensor<8x5x32xf32>) -> tensor<5x256xf32> {
+  %empty = tensor.empty() : tensor<5x256xf32>
+  %0 = tensor.unpack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<8x5x32xf32> -> tensor<5x256xf32>
+  return %0 : tensor<5x256xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @single_first_inner_dim_unpacking(
+// CHECK-NOT:     tensor.collapse_shape
+// CHECK:         tensor.unpack
+func.func @single_first_inner_dim_unpacking(%arg0: tensor<8x5x32xf32>) -> tensor<256x5xf32> {
+  %empty = tensor.empty() : tensor<256x5xf32>
+  %0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<8x5x32xf32> -> tensor<256x5xf32>
+  return %0 : tensor<256x5xf32>
+}
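The new unpack tests are the mirror image of the pack cases earlier in the file: a tensor.unpack whose only inner tile sits on the innermost dimension, with static shapes, no outer_dims_perm, and no partial tile, is semantically just a reshape. A sketch of the rewrite the first test checks for (%tiled and %empty are placeholder names):

// Before: undo an 8x32 tiling.
%flat = tensor.unpack %tiled inner_dims_pos = [0] inner_tiles = [32]
    into %empty : tensor<8x32xf32> -> tensor<256xf32>
// After simplification: a metadata-only collapse.
%flat2 = tensor.collapse_shape %tiled [[0, 1]]
    : tensor<8x32xf32> into tensor<256xf32>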
diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
index 0964161..5123958 100644
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -696,3 +696,11 @@ transform.sequence failures(propagate) {
     transform.named_sequence @foo()
   } : !transform.any_op
 }
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  // expected-error @below {{expected the type of the parameter attribute ('i64') to match the parameter type ('i32')}}
+  transform.num_associations %arg0 : (!transform.any_op) -> !transform.param<i32>
+}
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index d9a1199..a39e6f9 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -575,8 +575,9 @@ transform.with_pdl_patterns {
     %0 = pdl_match @addi in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = pdl_match @addi in %arg1 : (!transform.any_op) -> !transform.any_op
     %2 = merge_handles deduplicate %0, %1 : !transform.any_op
+    %3 = num_associations %2 : (!transform.any_op) -> !transform.param<i64>
     // expected-remark @below {{1}}
-    test_print_number_of_associated_payload_ir_ops %2 : !transform.any_op
+    test_print_param %3 : !transform.param<i64>
   }
 }
 
@@ -676,11 +677,13 @@ module {
   ^bb0(%arg1: !transform.any_op):
     %0 = pdl_match @func in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = replicate num(%0) %arg1 : !transform.any_op, !transform.any_op
+    %p = num_associations %1 : (!transform.any_op) -> !transform.param<i64>
     // expected-remark @below {{2}}
-    test_print_number_of_associated_payload_ir_ops %1 : !transform.any_op
+    test_print_param %p : !transform.param<i64>
     %2 = replicate num(%0) %1 : !transform.any_op, !transform.any_op
+    %p2 = num_associations %2 : (!transform.any_op) -> !transform.param<i64>
     // expected-remark @below {{4}}
-    test_print_number_of_associated_payload_ir_ops %2 : !transform.any_op
+    test_print_param %p2 : !transform.param<i64>
   }
 }
 }
 
@@ -708,8 +711,9 @@ transform.with_pdl_patterns {
     %f = pdl_match @const in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.foreach %f : !transform.any_op {
     ^bb2(%arg2: !transform.any_op):
+      %p = transform.num_associations %arg2 : (!transform.any_op) -> !transform.param<i64>
       // expected-remark @below {{1}}
-      transform.test_print_number_of_associated_payload_ir_ops %arg2 : !transform.any_op
+      transform.test_print_param %p : !transform.param<i64>
       transform.test_print_remark_at_operand %arg2, "transform applied" : !transform.any_op
     }
   }
 
@@ -780,8 +784,9 @@ transform.with_pdl_patterns {
       transform.yield %g : !transform.any_op
     }
 
+    %p = transform.num_associations %results : (!transform.any_op) -> !transform.param<i64>
     // expected-remark @below {{3}}
-    transform.test_print_number_of_associated_payload_ir_ops %results : !transform.any_op
+    transform.test_print_param %p : !transform.param<i64>
     transform.test_print_remark_at_operand %results, "transform applied" : !transform.any_op
   }
 }
 
@@ -877,8 +882,9 @@ transform.sequence failures(propagate) {
 ^bb1(%fun: !transform.any_op):
   %muli = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
   %h:2 = split_handle %muli : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+  %p = transform.num_associations %h#0 : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below {{1}}
-  transform.test_print_number_of_associated_payload_ir_ops %h#0 : !transform.any_op
+  transform.test_print_param %p : !transform.param<i64>
   %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
   // expected-error @below {{expected to contain 3 payload ops but it contains 2 payload ops}}
   %h_2:3 = split_handle %muli_2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
@@ -896,13 +902,15 @@ transform.sequence failures(suppress) {
 ^bb1(%fun: !transform.any_op):
   %muli = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
   %h:2 = split_handle %muli : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+  %p = transform.num_associations %h#0 : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below {{1}}
-  transform.test_print_number_of_associated_payload_ir_ops %h#0 : !transform.any_op
+  transform.test_print_param %p : !transform.param<i64>
   %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
   // Silenceable failure and all handles are now empty.
   %h_2:3 = split_handle %muli_2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+  %p2 = transform.num_associations %h_2#0 : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below {{0}}
-  transform.test_print_number_of_associated_payload_ir_ops %h_2#0 : !transform.any_op
+  transform.test_print_param %p2 : !transform.param<i64>
 }
 
 // -----
 
@@ -918,12 +926,15 @@ transform.sequence failures(propagate) {
   %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
   // No error, last result handle is empty.
   %h:3 = split_handle %muli_2 {fail_on_payload_too_small = false} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+  %p = transform.num_associations %h#0 : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below {{1}}
-  transform.test_print_number_of_associated_payload_ir_ops %h#0 : !transform.any_op
+  transform.test_print_param %p : !transform.param<i64>
+  %p2 = transform.num_associations %h#1 : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below {{1}}
-  transform.test_print_number_of_associated_payload_ir_ops %h#1 : !transform.any_op
+  transform.test_print_param %p2 : !transform.param<i64>
+  %p3 = transform.num_associations %h#2 : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below {{0}}
-  transform.test_print_number_of_associated_payload_ir_ops %h#2 : !transform.any_op
+  transform.test_print_param %p3 : !transform.param<i64>
 }
 
 // -----
 
@@ -940,10 +951,12 @@ transform.sequence failures(propagate) {
 ^bb1(%fun: !transform.any_op):
   %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
   %h:2 = split_handle %muli_2 {overflow_result = 0} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+  %p = transform.num_associations %h#0 : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below {{3}}
-  transform.test_print_number_of_associated_payload_ir_ops %h#0 : !transform.any_op
+  transform.test_print_param %p : !transform.param<i64>
+  %p2 = transform.num_associations %h#1 : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below {{1}}
-  transform.test_print_number_of_associated_payload_ir_ops %h#1 : !transform.any_op
+  transform.test_print_param %p2 : !transform.param<i64>
 }
 
 // -----
 
@@ -1668,8 +1681,9 @@ transform.sequence failures(propagate) {
   // expected-remark @below {{2 iterations}}
   transform.test_tracked_rewrite %0 : (!transform.any_op) -> ()
   // One replacement op (test.drop_mapping) is dropped from the mapping.
+  %p = num_associations %0 : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below {{2}}
-  test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op
+  test_print_param %p : !transform.param<i64>
 }
 
 // -----
 
@@ -1684,20 +1698,24 @@ module {
     %2 = transform.param.constant 1 -> !transform.param<i64>
     %3 = transform.param.constant 2 -> !transform.param<i64>
     %4 = transform.merge_handles %1, %2 { deduplicate } : !transform.param<i64>
+    %p = num_associations %4 : (!transform.param<i64>) -> !transform.param<i64>
     // expected-remark @below {{1}}
-    test_print_number_of_associated_payload_ir_params %4 : !transform.param<i64>
+    test_print_param %p : !transform.param<i64>
 
     %5 = transform.merge_handles %1, %1 { deduplicate } : !transform.param<i64>
+    %p2 = num_associations %5 : (!transform.param<i64>) -> !transform.param<i64>
     // expected-remark @below {{1}}
-    test_print_number_of_associated_payload_ir_params %5 : !transform.param<i64>
+    test_print_param %p2 : !transform.param<i64>
 
     %6 = transform.merge_handles %1, %3 { deduplicate } : !transform.param<i64>
+    %p3 = num_associations %6 : (!transform.param<i64>) -> !transform.param<i64>
     // expected-remark @below {{2}}
-    test_print_number_of_associated_payload_ir_params %6 : !transform.param<i64>
+    test_print_param %p3 : !transform.param<i64>
 
     %7 = transform.merge_handles %1, %1, %2, %3 : !transform.param<i64>
+    %p4 = num_associations %7 : (!transform.param<i64>) -> !transform.param<i64>
     // expected-remark @below {{4}}
-    test_print_number_of_associated_payload_ir_params %7 : !transform.param<i64>
+    test_print_param %p4 : !transform.param<i64>
   }
 }
 
@@ -1712,21 +1730,25 @@ transform.sequence failures(propagate) {
   %3 = test_produce_value_handle_to_result %1, 1 : (!transform.any_op) -> !transform.any_value
 
   %4 = transform.merge_handles %2, %2 { deduplicate } : !transform.any_value
+  %p = num_associations %4 : (!transform.any_value) -> !transform.param<i64>
   // expected-remark @below {{1}}
-  test_print_number_of_associated_payload_ir_values %4 : !transform.any_value
+  test_print_param %p : !transform.param<i64>
 
   %5 = transform.merge_handles %2, %3 { deduplicate } : !transform.any_value
+  %p2 = num_associations %5 : (!transform.any_value) -> !transform.param<i64>
   // expected-remark @below {{2}}
-  test_print_number_of_associated_payload_ir_values %5 : !transform.any_value
+  test_print_param %p2 : !transform.param<i64>
 
   %6 = test_produce_value_handle_to_result %1, 0 : (!transform.any_op) -> !transform.any_value
   %7 = transform.merge_handles %2, %6 { deduplicate } : !transform.any_value
+  %p3 = num_associations %6 : (!transform.any_value) -> !transform.param<i64>
   // expected-remark @below {{1}}
-  test_print_number_of_associated_payload_ir_values %6 : !transform.any_value
+  test_print_param %p3 : !transform.param<i64>
 
   %8 = transform.merge_handles %2, %2, %3, %4 : !transform.any_value
+  %p4 = num_associations %8 : (!transform.any_value) -> !transform.param<i64>
   // expected-remark @below {{4}}
-  test_print_number_of_associated_payload_ir_values %8 : !transform.any_value
+  test_print_param %p4 : !transform.param<i64>
 }
 
 // -----
 
@@ -1820,31 +1842,37 @@ transform.sequence failures(propagate) {
 
   // There are 3 arith.constant ops.
   %all = transform.structured.match ops{["arith.constant"]} in %0 : (!transform.any_op) -> !transform.any_op
+  %p = num_associations %all : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below{{3}}
-  test_print_number_of_associated_payload_ir_ops %all : !transform.any_op
+  test_print_param %p : !transform.param<i64>
   // "deduplicate" has no effect because these are 3 different ops.
   %merged_before = transform.merge_handles deduplicate %all : !transform.any_op
+  %p2 = num_associations %merged_before : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below{{3}}
-  test_print_number_of_associated_payload_ir_ops %merged_before : !transform.any_op
+  test_print_param %p2 : !transform.param<i64>
   // Apply CSE.
   transform.apply_cse to %0 : !transform.any_op
   // The handle is still mapped to 3 arith.constant ops.
+  %p3 = num_associations %all : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below{{3}}
-  test_print_number_of_associated_payload_ir_ops %all : !transform.any_op
+  test_print_param %p3 : !transform.param<i64>
   // But they are all the same op.
  %merged_after = transform.merge_handles deduplicate %all : !transform.any_op
+  %p4 = num_associations %merged_after : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below{{1}}
-  test_print_number_of_associated_payload_ir_ops %merged_after : !transform.any_op
+  test_print_param %p4 : !transform.param<i64>
   // The other handles were also updated.
   test_print_remark_at_operand %elim_first, "eliminated 1" : !transform.any_op
+  %p5 = num_associations %elim_first : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below{{1}}
-  test_print_number_of_associated_payload_ir_ops %elim_first : !transform.any_op
+  test_print_param %p5 : !transform.param<i64>
   test_print_remark_at_operand %elim_second, "eliminated 2" : !transform.any_op
+  %p6 = num_associations %elim_second : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below{{1}}
-  test_print_number_of_associated_payload_ir_ops %elim_second : !transform.any_op
+  test_print_param %p6 : !transform.param<i64>
 }
 
 // -----
 
@@ -1907,14 +1935,16 @@ transform.sequence failures(propagate) {
   // Get immediate parent.
   %2 = transform.get_parent_op %0 : (!transform.any_op) -> !transform.any_op
   test_print_remark_at_operand %2, "direct parent" : !transform.any_op
+  %p = num_associations %2 : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below{{2}}
-  test_print_number_of_associated_payload_ir_ops %2 : !transform.any_op
+  test_print_param %p : !transform.param<i64>
 
   // Deduplicate results.
   %3 = transform.structured.match ops{["test.qux"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   %4 = transform.get_parent_op %3 {deduplicate} : (!transform.any_op) -> !transform.any_op
+  %p2 = num_associations %4 : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below{{1}}
-  test_print_number_of_associated_payload_ir_ops %4 : !transform.any_op
+  test_print_param %p2 : !transform.param<i64>
 }
 
@@ -2029,8 +2059,9 @@ transform.sequence failures(propagate) {
   // Match all ops inside the function (including the function itself).
   %func_op = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
   %0 = transform.structured.match in %func_op : (!transform.any_op) -> !transform.any_op
+  %p = num_associations %0 : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below{{5}}
-  test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op
+  test_print_param %p : !transform.param<i64>
 
   // Select "test.foo".
   %foo = transform.select "test.foo" in %0 : (!transform.any_op) -> !transform.any_op
@@ -2060,8 +2091,9 @@ transform.sequence failures(propagate) {
   %empty_op = transform.structured.match ops{["tensor.empty"]} in %func_op : (!transform.any_op) -> !transform.any_op
   transform.apply_dce to %func_op : !transform.any_op
 
+  %p = num_associations %empty_op : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below{{0}}
-  test_print_number_of_associated_payload_ir_ops %empty_op : !transform.any_op
+  test_print_param %p : !transform.param<i64>
 }
diff --git a/mlir/test/Dialect/Transform/test-loop-transforms.mlir b/mlir/test/Dialect/Transform/test-loop-transforms.mlir
index 4259627..c34f4ba 100644
--- a/mlir/test/Dialect/Transform/test-loop-transforms.mlir
+++ b/mlir/test/Dialect/Transform/test-loop-transforms.mlir
@@ -37,13 +37,16 @@ module attributes {transform.with_named_sequence} {
 
     // Make sure that the handles are still valid (and were updated in case of
     // the loop).
+    %p = transform.num_associations %0 : (!transform.any_op) -> !transform.param<i64>
     // expected-remark @below{{1}}
-    transform.test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op
+    transform.test_print_param %p : !transform.param<i64>
     transform.test_print_remark_at_operand %0, "new loop op" : !transform.any_op
+    %p2 = transform.num_associations %1 : (!transform.any_op) -> !transform.param<i64>
     // expected-remark @below{{1}}
-    transform.test_print_number_of_associated_payload_ir_ops %1 : !transform.any_op
+    transform.test_print_param %p2 : !transform.param<i64>
+    %p3 = transform.num_associations %2 : (!transform.any_op) -> !transform.param<i64>
     // expected-remark @below{{1}}
-    transform.test_print_number_of_associated_payload_ir_ops %2 : !transform.any_op
+    transform.test_print_param %p3 : !transform.param<i64>
 
     transform.yield
   }
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 3708d74..ae457ea 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -356,3 +356,18 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
 // CHECK:           %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
 // CHECK:           %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
 // CHECK:           return %[[VAL_4]] : vector<8xi32>
+
+// -----
+
+// This test is to make sure there is no crash for empty stride.
+func.func @stride_empty_test(%1: memref<i16>) -> vector<32x256xi16> {
+  %c0_i16 = arith.constant 0 : i16
+  %3 = vector.transfer_read %1[], %c0_i16 {permutation_map = affine_map<() -> (0, 0)>} : memref<i16>, vector<32x256xi16>
+  return %3 : vector<32x256xi16>
+
+  // CHECK-LABEL: func.func @stride_empty_test
+  // CHECK: %[[VAL:.*]] = arith.constant 0 : i16
+  // CHECK: %[[RET:.*]] = vector.transfer_read {{.*}} vector<32x256xi16>
+  // CHECK: return %[[RET]]
+  // CHECK-NOT: empty()
+}
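The new @stride_empty_test guards the flattening patterns against rank-0 sources: memref<i16> has an empty stride list, so any rewrite that inspects the innermost stride must bail out instead of crashing. For contrast, a sketch of the kind of IR flattening produces on a well-formed contiguous 2-D read (hypothetical shapes, not taken from the diff):

func.func @flatten_sketch(%m: memref<4x8xi16>) -> vector<4x8xi16> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0 : i16
  // Collapse the memref, read it as a flat 1-D vector, then restore the shape.
  %flat = memref.collapse_shape %m [[0, 1]] : memref<4x8xi16> into memref<32xi16>
  %v = vector.transfer_read %flat[%c0], %pad {in_bounds = [true]} : memref<32xi16>, vector<32xi16>
  %res = vector.shape_cast %v : vector<32xi16> to vector<4x8xi16>
  return %res : vector<4x8xi16>
}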