Diffstat (limited to 'mlir/test/Dialect')
-rw-r--r--  mlir/test/Dialect/GPU/subgroup-redule-lowering.mlir                 190
-rw-r--r--  mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir   15
-rw-r--r--  mlir/test/Dialect/Linalg/transform-op-match.mlir                      5
-rw-r--r--  mlir/test/Dialect/Linalg/transform-op-pad.mlir                        3
-rw-r--r--  mlir/test/Dialect/Mesh/invalid.mlir                                  96
-rw-r--r--  mlir/test/Dialect/Mesh/ops.mlir                                      49
-rw-r--r--  mlir/test/Dialect/Mesh/resharding-spmdization.mlir                  154
-rw-r--r--  mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir (renamed from mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib_2to4.mlir)  28
-rw-r--r--  mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir                   72
-rw-r--r--  mlir/test/Dialect/Transform/ops-invalid.mlir                          8
-rw-r--r--  mlir/test/Dialect/Transform/test-interpreter.mlir                    96
-rw-r--r--  mlir/test/Dialect/Transform/test-loop-transforms.mlir                 9
-rw-r--r--  mlir/test/Dialect/Vector/vector-transfer-flatten.mlir                15
13 files changed, 650 insertions(+), 90 deletions(-)
diff --git a/mlir/test/Dialect/GPU/subgroup-redule-lowering.mlir b/mlir/test/Dialect/GPU/subgroup-redule-lowering.mlir
index b714607..f04a01f 100644
--- a/mlir/test/Dialect/GPU/subgroup-redule-lowering.mlir
+++ b/mlir/test/Dialect/GPU/subgroup-redule-lowering.mlir
@@ -1,71 +1,191 @@
-// RUN: mlir-opt --allow-unregistered-dialect --test-gpu-subgroup-reduce-lowering %s | FileCheck %s
+// RUN: mlir-opt --allow-unregistered-dialect \
+// RUN: --test-gpu-subgroup-reduce-lowering %s \
+// RUN: | FileCheck %s --check-prefix=CHECK-SUB
-// CHECK: gpu.module @kernels {
+// RUN: mlir-opt --allow-unregistered-dialect \
+// RUN: --test-gpu-subgroup-reduce-lowering="expand-to-shuffles" %s \
+// RUN: | FileCheck %s --check-prefix=CHECK-SHFL
+
+// CHECK-SUB: gpu.module @kernels {
+// CHECK-SHFL: gpu.module @kernels {
gpu.module @kernels {
- // CHECK-LABEL: gpu.func @kernel0(
- // CHECK-SAME: %[[ARG0:.+]]: vector<5xf16>)
+ // CHECK-SUB-LABEL: gpu.func @kernel0(
+ // CHECK-SUB-SAME: %[[ARG0:.+]]: vector<5xf16>)
+ //
+ // CHECK-SHFL-LABEL: gpu.func @kernel0(
gpu.func @kernel0(%arg0: vector<5xf16>) kernel {
- // CHECK: %[[VZ:.+]] = arith.constant dense<0.0{{.*}}> : vector<5xf16>
- // CHECK: %[[E0:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [2], strides = [1]} : vector<5xf16> to vector<2xf16>
- // CHECK: %[[R0:.+]] = gpu.subgroup_reduce add %[[E0]] : (vector<2xf16>) -> vector<2xf16>
- // CHECK: %[[V0:.+]] = vector.insert_strided_slice %[[R0]], %[[VZ]] {offsets = [0], strides = [1]} : vector<2xf16> into vector<5xf16>
- // CHECK: %[[E1:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [2], sizes = [2], strides = [1]} : vector<5xf16> to vector<2xf16>
- // CHECK: %[[R1:.+]] = gpu.subgroup_reduce add %[[E1]] : (vector<2xf16>) -> vector<2xf16>
- // CHECK: %[[V1:.+]] = vector.insert_strided_slice %[[R1]], %[[V0]] {offsets = [2], strides = [1]} : vector<2xf16> into vector<5xf16>
- // CHECK: %[[E2:.+]] = vector.extract %[[ARG0]][4] : f16 from vector<5xf16>
- // CHECK: %[[R2:.+]] = gpu.subgroup_reduce add %[[E2]] : (f16) -> f16
- // CHECK: %[[V2:.+]] = vector.insert %[[R2]], %[[V1]] [4] : f16 into vector<5xf16>
- // CHECK: "test.consume"(%[[V2]]) : (vector<5xf16>) -> ()
+ // CHECK-SUB: %[[VZ:.+]] = arith.constant dense<0.0{{.*}}> : vector<5xf16>
+ // CHECK-SUB: %[[E0:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [2], strides = [1]} : vector<5xf16> to vector<2xf16>
+ // CHECK-SUB: %[[R0:.+]] = gpu.subgroup_reduce add %[[E0]] : (vector<2xf16>) -> vector<2xf16>
+ // CHECK-SUB: %[[V0:.+]] = vector.insert_strided_slice %[[R0]], %[[VZ]] {offsets = [0], strides = [1]} : vector<2xf16> into vector<5xf16>
+ // CHECK-SUB: %[[E1:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [2], sizes = [2], strides = [1]} : vector<5xf16> to vector<2xf16>
+ // CHECK-SUB: %[[R1:.+]] = gpu.subgroup_reduce add %[[E1]] : (vector<2xf16>) -> vector<2xf16>
+ // CHECK-SUB: %[[V1:.+]] = vector.insert_strided_slice %[[R1]], %[[V0]] {offsets = [2], strides = [1]} : vector<2xf16> into vector<5xf16>
+ // CHECK-SUB: %[[E2:.+]] = vector.extract %[[ARG0]][4] : f16 from vector<5xf16>
+ // CHECK-SUB: %[[R2:.+]] = gpu.subgroup_reduce add %[[E2]] : (f16) -> f16
+ // CHECK-SUB: %[[V2:.+]] = vector.insert %[[R2]], %[[V1]] [4] : f16 into vector<5xf16>
+ // CHECK-SUB: "test.consume"(%[[V2]]) : (vector<5xf16>) -> ()
%sum0 = gpu.subgroup_reduce add %arg0 : (vector<5xf16>) -> (vector<5xf16>)
"test.consume"(%sum0) : (vector<5xf16>) -> ()
-
- // CHECK-COUNT-3: gpu.subgroup_reduce mul {{.+}} uniform
- // CHECK: "test.consume"
+ // CHECK-SUB-COUNT-3: gpu.subgroup_reduce mul {{.+}} uniform
+ // CHECK-SUB: "test.consume"
%sum1 = gpu.subgroup_reduce mul %arg0 uniform : (vector<5xf16>) -> (vector<5xf16>)
"test.consume"(%sum1) : (vector<5xf16>) -> ()
- // CHECK: gpu.return
+ // CHECK-SUB: gpu.return
gpu.return
}
- // CHECK-LABEL: gpu.func @kernel1(
- // CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>)
+ // CHECK-SUB-LABEL: gpu.func @kernel1(
+ // CHECK-SUB-SAME: %[[ARG0:.+]]: vector<1xf32>)
+ //
+ // CHECK-SHFL-LABEL: gpu.func @kernel1(
gpu.func @kernel1(%arg0: vector<1xf32>) kernel {
- // CHECK: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
- // CHECK: %[[R0:.+]] = gpu.subgroup_reduce add %[[E0]] : (f32) -> f32
- // CHECK: %[[V0:.+]] = vector.broadcast %[[R0]] : f32 to vector<1xf32>
- // CHECK: "test.consume"(%[[V0]]) : (vector<1xf32>) -> ()
+ // CHECK-SUB: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
+ // CHECK-SUB: %[[R0:.+]] = gpu.subgroup_reduce add %[[E0]] : (f32) -> f32
+ // CHECK-SUB: %[[V0:.+]] = vector.broadcast %[[R0]] : f32 to vector<1xf32>
+ // CHECK-SUB: "test.consume"(%[[V0]]) : (vector<1xf32>) -> ()
%sum0 = gpu.subgroup_reduce add %arg0 : (vector<1xf32>) -> (vector<1xf32>)
"test.consume"(%sum0) : (vector<1xf32>) -> ()
- // CHECK: gpu.subgroup_reduce add {{.+}} uniform : (f32) -> f32
- // CHECK: "test.consume"
+ // CHECK-SUB: gpu.subgroup_reduce add {{.+}} uniform : (f32) -> f32
+ // CHECK-SUB: "test.consume"
%sum1 = gpu.subgroup_reduce add %arg0 uniform : (vector<1xf32>) -> (vector<1xf32>)
"test.consume"(%sum1) : (vector<1xf32>) -> ()
- // CHECK: gpu.return
+ // CHECK-SUB: gpu.return
gpu.return
}
// These vectors fit the native shuffle size and should not be broken down.
//
- // CHECK-LABEL: gpu.func @kernel2(
- // CHECK-SAME: %[[ARG0:.+]]: vector<3xi8>, %[[ARG1:.+]]: vector<4xi8>)
+ // CHECK-SUB-LABEL: gpu.func @kernel2(
+ // CHECK-SUB-SAME: %[[ARG0:.+]]: vector<3xi8>, %[[ARG1:.+]]: vector<4xi8>)
+ //
+ // CHECK-SHFL-LABEL: gpu.func @kernel2(
gpu.func @kernel2(%arg0: vector<3xi8>, %arg1: vector<4xi8>) kernel {
- // CHECK: %[[R0:.+]] = gpu.subgroup_reduce add %[[ARG0]] : (vector<3xi8>) -> vector<3xi8>
- // CHECK: "test.consume"(%[[R0]]) : (vector<3xi8>) -> ()
+ // CHECK-SUB: %[[R0:.+]] = gpu.subgroup_reduce add %[[ARG0]] : (vector<3xi8>) -> vector<3xi8>
+ // CHECK-SUB: "test.consume"(%[[R0]]) : (vector<3xi8>) -> ()
%sum0 = gpu.subgroup_reduce add %arg0 : (vector<3xi8>) -> (vector<3xi8>)
"test.consume"(%sum0) : (vector<3xi8>) -> ()
- // CHECK: %[[R1:.+]] = gpu.subgroup_reduce add %[[ARG1]] : (vector<4xi8>) -> vector<4xi8>
- // CHECK: "test.consume"(%[[R1]]) : (vector<4xi8>) -> ()
+ // CHECK-SUB: %[[R1:.+]] = gpu.subgroup_reduce add %[[ARG1]] : (vector<4xi8>) -> vector<4xi8>
+ // CHECK-SUB: "test.consume"(%[[R1]]) : (vector<4xi8>) -> ()
%sum1 = gpu.subgroup_reduce add %arg1 : (vector<4xi8>) -> (vector<4xi8>)
"test.consume"(%sum1) : (vector<4xi8>) -> ()
- // CHECK: gpu.return
+ // CHECK-SUB: gpu.return
+ gpu.return
+ }
+
+ // CHECK-SHFL-LABEL: gpu.func @kernel3(
+ // CHECK-SHFL-SAME: %[[ARG0:.+]]: i32)
+ gpu.func @kernel3(%arg0: i32) kernel {
+ // CHECK-SHFL-DAG: %[[C1:.+]] = arith.constant 1 : i32
+ // CHECK-SHFL-DAG: %[[C2:.+]] = arith.constant 2 : i32
+ // CHECK-SHFL-DAG: %[[C4:.+]] = arith.constant 4 : i32
+ // CHECK-SHFL-DAG: %[[C8:.+]] = arith.constant 8 : i32
+ // CHECK-SHFL-DAG: %[[C16:.+]] = arith.constant 16 : i32
+ // CHECK-SHFL-DAG: %[[C32:.+]] = arith.constant 32 : i32
+
+ // CHECK-SHFL: %[[S0:.+]], %{{.+}} = gpu.shuffle xor %[[ARG0]], %[[C1]], %[[C32]] : i32
+ // CHECK-SHFL: %[[A0:.+]] = arith.addi %[[ARG0]], %[[S0]] : i32
+ // CHECK-SHFL: %[[S1:.+]], %{{.+}} = gpu.shuffle xor %[[A0]], %[[C2]], %[[C32]] : i32
+ // CHECK-SHFL: %[[A1:.+]] = arith.addi %[[A0]], %[[S1]] : i32
+ // CHECK-SHFL: %[[S2:.+]], %{{.+}} = gpu.shuffle xor %[[A1]], %[[C4]], %[[C32]] : i32
+ // CHECK-SHFL: %[[A2:.+]] = arith.addi %[[A1]], %[[S2]] : i32
+ // CHECK-SHFL: %[[S3:.+]], %{{.+}} = gpu.shuffle xor %[[A2]], %[[C8]], %[[C32]] : i32
+ // CHECK-SHFL: %[[A3:.+]] = arith.addi %[[A2]], %[[S3]] : i32
+ // CHECK-SHFL: %[[S4:.+]], %{{.+}} = gpu.shuffle xor %[[A3]], %[[C16]], %[[C32]] : i32
+ // CHECK-SHFL: %[[A4:.+]] = arith.addi %[[A3]], %[[S4]] : i32
+ // CHECK-SHFL: "test.consume"(%[[A4]]) : (i32) -> ()
+ %sum0 = gpu.subgroup_reduce add %arg0 : (i32) -> i32
+ "test.consume"(%sum0) : (i32) -> ()
+
+ // CHECK-SHFL: gpu.return
+ gpu.return
+ }
+
+ // CHECK-SHFL-LABEL: gpu.func @kernel4(
+ // CHECK-SHFL-SAME: %[[ARG0:.+]]: vector<2xf16>)
+ gpu.func @kernel4(%arg0: vector<2xf16>) kernel {
+ // CHECK-SHFL-DAG: %[[C1:.+]] = arith.constant 1 : i32
+ // CHECK-SHFL-DAG: %[[C2:.+]] = arith.constant 2 : i32
+ // CHECK-SHFL-DAG: %[[C4:.+]] = arith.constant 4 : i32
+ // CHECK-SHFL-DAG: %[[C8:.+]] = arith.constant 8 : i32
+ // CHECK-SHFL-DAG: %[[C16:.+]] = arith.constant 16 : i32
+ // CHECK-SHFL-DAG: %[[C32:.+]] = arith.constant 32 : i32
+
+ // CHECK-SHFL: %[[V0:.+]] = vector.bitcast %[[ARG0]] : vector<2xf16> to vector<1xi32>
+ // CHECK-SHFL: %[[I0:.+]] = vector.extract %[[V0]][0] : i32 from vector<1xi32>
+ // CHECK-SHFL: %[[S0:.+]], %{{.+}} = gpu.shuffle xor %[[I0]], %[[C1]], %[[C32]] : i32
+ // CHECK-SHFL: %[[BR0:.+]] = vector.broadcast %[[S0]] : i32 to vector<1xi32>
+ // CHECK-SHFL: %[[BC0:.+]] = vector.bitcast %[[BR0]] : vector<1xi32> to vector<2xf16>
+ // CHECK-SHFL: %[[ADD0:.+]] = arith.addf %[[ARG0]], %[[BC0]] : vector<2xf16>
+ // CHECK-SHFL: %[[BC1:.+]] = vector.bitcast %[[ADD0]] : vector<2xf16> to vector<1xi32>
+ // CHECK-SHFL: %[[I1:.+]] = vector.extract %[[BC1]][0] : i32 from vector<1xi32>
+ // CHECK-SHFL: gpu.shuffle xor %[[I1]], %[[C2]], %[[C32]] : i32
+ // CHECK-SHFL: arith.addf {{.+}} : vector<2xf16>
+ // CHECK-SHFL: gpu.shuffle xor %{{.+}}, %[[C4]], %[[C32]] : i32
+ // CHECK-SHFL: arith.addf {{.+}} : vector<2xf16>
+ // CHECK-SHFL: gpu.shuffle xor %{{.+}}, %[[C8]], %[[C32]] : i32
+ // CHECK-SHFL: arith.addf {{.+}} : vector<2xf16>
+ // CHECK-SHFL: %[[SL:.+]], %{{.+}} = gpu.shuffle xor %{{.+}}, %[[C16]], %[[C32]] : i32
+ // CHECK-SHFL: %[[BRL:.+]] = vector.broadcast %[[SL]] : i32 to vector<1xi32>
+ // CHECK-SHFL: %[[BCL:.+]] = vector.bitcast %[[BRL]] : vector<1xi32> to vector<2xf16>
+ // CHECK-SHFL: %[[ADDL:.+]] = arith.addf %{{.+}}, %[[BCL]] : vector<2xf16>
+ // CHECK-SHFL: "test.consume"(%[[ADDL]]) : (vector<2xf16>) -> ()
+ %sum0 = gpu.subgroup_reduce add %arg0 : (vector<2xf16>) -> (vector<2xf16>)
+ "test.consume"(%sum0) : (vector<2xf16>) -> ()
+
+ // CHECK-SHFL: gpu.return
+ gpu.return
+ }
+
+ // CHECK-SHFL-LABEL: gpu.func @kernel5(
+ // CHECK-SHFL-SAME: %[[ARG0:.+]]: i16)
+ gpu.func @kernel5(%arg0: i16) kernel {
+ // CHECK-SHFL: %[[E0:.+]] = arith.extui %[[ARG0]] : i16 to i32
+ // CHECK-SHFL: %[[S0:.+]], %{{.+}} = gpu.shuffle xor %[[E0]], {{.+}} : i32
+ // CHECK-SHFL: %[[T0:.+]] = arith.trunci %[[S0]] : i32 to i16
+ // CHECK-SHFL: %[[A0:.+]] = arith.addi %[[ARG0]], %[[T0]] : i16
+ // CHECK-SHFL: %[[E1:.+]] = arith.extui %[[A0]] : i16 to i32
+ // CHECK-SHFL: %{{.+}}, %{{.+}} = gpu.shuffle xor %[[E1]], {{.+}} : i32
+ // CHECK-SHFL-COUNT-3: gpu.shuffle xor
+ // CHECK-SHFL: arith.trunci {{.+}} : i32 to i16
+ // CHECK-SHFL: %[[AL:.+]] = arith.addi {{.+}} : i16
+ // CHECK-SHFL: "test.consume"(%[[AL]]) : (i16) -> ()
+ %sum0 = gpu.subgroup_reduce add %arg0 : (i16) -> i16
+ "test.consume"(%sum0) : (i16) -> ()
+
+ // CHECK-SHFL: gpu.return
+ gpu.return
+ }
+
+ // CHECK-SHFL-LABEL: gpu.func @kernel6(
+ // CHECK-SHFL-SAME: %[[ARG0:.+]]: vector<3xi8>)
+ gpu.func @kernel6(%arg0: vector<3xi8>) kernel {
+ // CHECK-SHFL: %[[CZ:.+]] = arith.constant dense<0> : vector<4xi8>
+ // CHECK-SHFL: %[[V0:.+]] = vector.insert_strided_slice %[[ARG0]], %[[CZ]] {offsets = [0], strides = [1]} : vector<3xi8> into vector<4xi8>
+ // CHECK-SHFL: %[[BC0:.+]] = vector.bitcast %[[V0]] : vector<4xi8> to vector<1xi32>
+ // CHECK-SHFL: %[[I0:.+]] = vector.extract %[[BC0]][0] : i32 from vector<1xi32>
+ // CHECK-SHFL: %[[S0:.+]], %{{.+}} = gpu.shuffle xor %[[I0]], {{.+}} : i32
+ // CHECK-SHFL: %[[BR0:.+]] = vector.broadcast %[[S0]] : i32 to vector<1xi32>
+ // CHECK-SHFL: %[[BC1:.+]] = vector.bitcast %[[BR0]] : vector<1xi32> to vector<4xi8>
+ // CHECK-SHFL: %[[ADD0:.+]] = arith.addi %[[V0]], %[[BC1]] : vector<4xi8>
+ // CHECK-SHFL: %[[BC2:.+]] = vector.bitcast %[[ADD0]] : vector<4xi8> to vector<1xi32>
+ // CHECK-SHFL: %[[I1:.+]] = vector.extract %[[BC2]][0] : i32 from vector<1xi32>
+ // CHECK-SHFL-COUNT-4: gpu.shuffle xor
+ // CHECK-SHFL: %[[ESS:.+]] = vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [3], strides = [1]} : vector<4xi8> to vector<3xi8>
+ // CHECK-SHFL: "test.consume"(%[[ESS]]) : (vector<3xi8>) -> ()
+ %sum0 = gpu.subgroup_reduce add %arg0 : (vector<3xi8>) -> (vector<3xi8>)
+ "test.consume"(%sum0) : (vector<3xi8>) -> ()
+
+ // CHECK-SHFL: gpu.return
gpu.return
}
}
+
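Taken together, the two RUN configurations exercise two independent rewrites: the CHECK-SUB prefix covers breaking oversized vector reductions into native-width pieces, while CHECK-SHFL covers expanding each reduction into a log2(subgroup size) xor-shuffle butterfly, as kernel3 through kernel6 show. A minimal sketch of an input that triggers only the shuffle expansion, assuming the same pass flags as in the RUN lines above (the unregistered test.consume op is why --allow-unregistered-dialect is required):

  gpu.module @example {
    gpu.func @scalar_reduce(%arg0: f32) kernel {
      // With expand-to-shuffles and a subgroup size of 32, this one op becomes
      // five gpu.shuffle xor / arith.addf pairs (offsets 1, 2, 4, 8, 16).
      %sum = gpu.subgroup_reduce add %arg0 : (f32) -> f32
      "test.consume"(%sum) : (f32) -> ()
      gpu.return
    }
  }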
diff --git a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
index 49a52ba..aa15ccf 100644
--- a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
@@ -36,13 +36,15 @@ module attributes {transform.with_named_sequence} {
// Ensure that one linalg.fill was generated.
%fill_op = transform.select "linalg.fill" in %new : (!transform.any_op) -> !transform.any_op
+ %p = transform.num_associations %fill_op : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- transform.test_print_number_of_associated_payload_ir_ops %fill_op : !transform.any_op
+ transform.test_print_param %p : !transform.param<i64>
// Ensure that one linalg.copy was generated.
%mat = transform.select "bufferization.materialize_in_destination" in %new : (!transform.any_op) -> !transform.any_op
+ %p2 = transform.num_associations %mat : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- transform.test_print_number_of_associated_payload_ir_ops %mat : !transform.any_op
+ transform.test_print_param %p2 : !transform.param<i64>
transform.yield
}
}
@@ -73,18 +75,21 @@ module attributes {transform.with_named_sequence} {
// Ensure that one linalg.fill was generated.
%fill_op = transform.select "linalg.fill" in %new : (!transform.any_op) -> !transform.any_op
+ %p = transform.num_associations %fill_op : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- transform.test_print_number_of_associated_payload_ir_ops %fill_op : !transform.any_op
+ transform.test_print_param %p : !transform.param<i64>
// Ensure that one linalg.copy was generated.
%linalg_copy = transform.select "linalg.copy" in %new : (!transform.any_op) -> !transform.any_op
+ %p2 = transform.num_associations %linalg_copy : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- transform.test_print_number_of_associated_payload_ir_ops %linalg_copy : !transform.any_op
+ transform.test_print_param %p2 : !transform.param<i64>
// Ensure that one memref.alloca was generated.
%alloca = transform.select "memref.alloca" in %new : (!transform.any_op) -> !transform.any_op
+ %p3 = transform.num_associations %alloca : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- transform.test_print_number_of_associated_payload_ir_ops %alloca : !transform.any_op
+ transform.test_print_param %p3 : !transform.param<i64>
// Make sure that One-Shot Bufferize can bufferize the rest.
%4 = transform.bufferization.one_shot_bufferize %arg1 : (!transform.any_op) -> !transform.any_op
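The change in this file, repeated across the transform tests further down, replaces the test-only counting printer with a two-step idiom: transform.num_associations materializes the number of payload objects behind a handle as an i64 parameter, and transform.test_print_param emits it as a remark. Side by side, using the ops exactly as they appear in this diff (%h stands in for any transform handle):

  // Before: a dedicated test op printed the count directly.
  transform.test_print_number_of_associated_payload_ir_ops %h : !transform.any_op
  // After: count the associations into a parameter, then print the parameter.
  %n = transform.num_associations %h : (!transform.any_op) -> !transform.param<i64>
  transform.test_print_param %n : !transform.param<i64>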
diff --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir
index 15942db..db5b5f1 100644
--- a/mlir/test/Dialect/Linalg/transform-op-match.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir
@@ -134,8 +134,9 @@ module attributes {transform.with_named_sequence} {
#linalg.iterator_type<parallel>,
#linalg.iterator_type<reduction>]}
in %arg1 : (!transform.any_op) -> !transform.any_op
- // expected-remark @below {{0}}
- transform.test_print_number_of_associated_payload_ir_ops %no_match : !transform.any_op
+ %p = transform.num_associations %no_match : (!transform.any_op) -> !transform.param<i64>
+ // expected-remark @below {{0}}
+ transform.test_print_param %p : !transform.param<i64>
transform.yield
}
}
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index 6bca6c1..1f9d81a 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -41,8 +41,9 @@ module attributes {transform.with_named_sequence} {
padding_dimensions=[0, 1, 2],
pack_paddings=[1, 1, 0]
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.op<"bufferization.materialize_in_destination">)
+ %p = transform.num_associations %copy_back : (!transform.op<"bufferization.materialize_in_destination">) -> !transform.param<i64>
// expected-remark @below {{1}}
- transform.test_print_number_of_associated_payload_ir_ops %copy_back : !transform.op<"bufferization.materialize_in_destination">
+ transform.test_print_param %p : !transform.param<i64>
transform.yield
}
}
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 03994f8..3ee578a 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -70,6 +70,102 @@ func.func @mesh_axis_negtive_in_partial(
// -----
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+
+func.func @cluster_shape_mesh_axis_out_of_bounds() -> (index, index) {
+ // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
+ %0:2 = mesh.cluster_shape @mesh0 axes = [0, 2] : index, index
+ return %0#0, %0#1 : index, index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
+
+func.func @cluster_shape_duplicate_mesh_axis() -> (index, index, index) {
+ // expected-error@+1 {{Mesh axes contains duplicate elements.}}
+ %0:3 = mesh.cluster_shape @mesh0 axes = [0, 2, 0] : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+
+func.func @cluster_shape_wrong_number_of_results() -> (index, index) {
+ // expected-error@+1 {{Unexpected number of results 2. Expected 1.}}
+ %0:2 = mesh.cluster_shape @mesh0 axes = [0] : index, index
+ return %0#0, %0#1 : index, index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
+
+func.func @cluster_shape_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
+ // expected-error@+1 {{Unexpected number of results 2. Expected 3.}}
+ %0:2 = mesh.cluster_shape @mesh0 : index, index
+ return %0#0, %0#1 : index, index
+}
+
+// -----
+
+func.func @cluster_shape_invalid_mesh_name() -> (index) {
+ // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
+ %0 = mesh.cluster_shape @this_mesh_symbol_does_not_exist : index
+ return %0#0 : index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+
+func.func @process_index_mesh_axis_out_of_bounds() -> (index, index) {
+ // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
+ %0:2 = mesh.process_index on @mesh0 axes = [0, 2] : index, index
+ return %0#0, %0#1 : index, index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
+
+func.func @process_index_duplicate_mesh_axis() -> (index, index, index) {
+ // expected-error@+1 {{Mesh axes contains duplicate elements.}}
+ %0:3 = mesh.process_index on @mesh0 axes = [0, 2, 0] : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+
+func.func @process_index_wrong_number_of_results() -> (index, index) {
+ // expected-error@+1 {{Unexpected number of results 2. Expected 1.}}
+ %0:2 = mesh.process_index on @mesh0 axes = [0] : index, index
+ return %0#0, %0#1 : index, index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
+
+func.func @process_index_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
+ // expected-error@+1 {{Unexpected number of results 2. Expected 3.}}
+ %0:2 = mesh.process_index on @mesh0 : index, index
+ return %0#0, %0#1 : index, index
+}
+
+// -----
+
+func.func @process_index_invalid_mesh_name() -> (index) {
+ // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
+ %0 = mesh.process_index on @this_mesh_symbol_does_not_exist : index
+ return %0#0 : index
+}
+
+// -----
+
func.func @all_reduce_invalid_mesh_symbol(
%arg0 : tensor<4xf32>) -> tensor<4xf64> {
// expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
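The result-count diagnostics above encode the contract of both new ops: with an explicit axes list, mesh.cluster_shape and mesh.process_index return one index per listed axis; with the list absent or empty, they return one index per mesh dimension. A sketch of the two valid shapes of use, assuming a rank-3 cluster like those declared above:

  mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
  // One result per listed axis:
  %s = mesh.cluster_shape @mesh0 axes = [1] : index
  // No axes given: one result per mesh dimension, three here.
  %a:3 = mesh.process_index on @mesh0 : index, index, index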
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 8f8e309..a7c3b3d 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -132,6 +132,55 @@ func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) ->
return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32>
}
+// CHECK-LABEL: func @cluster_shape
+func.func @cluster_shape() -> (index, index) {
+ // CHECK: %[[RES:.*]]:2 = mesh.cluster_shape @mesh0 axes = [0, 1] : index, index
+ %0:2 = mesh.cluster_shape @mesh0 axes = [0, 1] : index, index
+ // CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index
+ return %0#0, %0#1 : index, index
+}
+
+// CHECK-LABEL: func @cluster_shape_default_axes
+func.func @cluster_shape_default_axes() -> (index, index, index) {
+ // CHECK: %[[RES:.*]]:3 = mesh.cluster_shape @mesh0 : index, index, index
+ %0:3 = mesh.cluster_shape @mesh0 : index, index, index
+ // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// CHECK-LABEL: func @cluster_shape_empty_axes
+func.func @cluster_shape_empty_axes() -> (index, index, index) {
+ // CHECK: %[[RES:.*]]:3 = mesh.cluster_shape @mesh0 : index, index, index
+ %0:3 = mesh.cluster_shape @mesh0 axes = [] : index, index, index
+ // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// CHECK-LABEL: func @process_index
+func.func @process_index() -> (index, index) {
+ // CHECK: %[[RES:.*]]:2 = mesh.process_index on @mesh0 axes = [0, 1] : index, index
+ %0:2 = mesh.process_index on @mesh0 axes = [0, 1] : index, index
+ // CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index
+ return %0#0, %0#1 : index, index
+}
+
+// CHECK-LABEL: func @process_index_default_axes
+func.func @process_index_default_axes() -> (index, index, index) {
+ // CHECK: %[[RES:.*]]:3 = mesh.process_index on @mesh0 : index, index, index
+ %0:3 = mesh.process_index on @mesh0 : index, index, index
+ // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// CHECK-LABEL: func @process_index_empty_axes
+func.func @process_index_empty_axes() -> (index, index, index) {
+ // CHECK: %[[RES:.*]]:3 = mesh.process_index on @mesh0 : index, index, index
+ %0:3 = mesh.process_index on @mesh0 axes = [] : index, index, index
+ // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+
// CHECK-LABEL: func @all_reduce
func.func @all_reduce(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
diff --git a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
new file mode 100644
index 0000000..0ba0d76
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
@@ -0,0 +1,154 @@
+// RUN: mlir-opt -test-mesh-resharding-spmdization %s | FileCheck %s
+
+mesh.cluster @mesh_1d(rank = 1, dim_sizes = 2)
+mesh.cluster @mesh_1d_dynamic(rank = 1, dim_sizes = ?)
+
+// CHECK-LABEL: func @same_source_and_target_sharding
+func.func @same_source_and_target_sharding(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<2xf32>
+ %arg0: tensor<2xf32>
+) -> tensor<2xf32> {
+ %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xf32>
+ %1 = mesh.shard %0 to <@mesh_1d, [[]]> annotate_for_users : tensor<2xf32>
+ // CHECK: return %[[ARG]]
+ return %1 : tensor<2xf32>
+}
+
+// CHECK-LABEL: func @split_replicated_tensor_axis
+func.func @split_replicated_tensor_axis(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<3x14xf32>
+ %arg0: tensor<3x14xf32>
+) -> tensor<3x14xf32> {
+ // CHECK: %[[ZERO:.*]] = arith.constant 0 : index
+ // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE:.*]] = arith.constant 14 : index
+ // CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_index on @mesh_1d axes = [0] : index
+ // CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.cluster_shape @mesh_1d axes = [0] : index
+ // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE:.*]] = arith.remui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
+ // CHECK: %[[RESULT_TENSOR_AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE]], %[[ZERO]] : index
+ // CHECK: cf.assert %[[RESULT_TENSOR_AXIS_SIZE_CHECK]]
+ // CHECK: %[[RESULT_TENSOR_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
+ // CHECK: %[[RESULT_TENSOR_AXIS_OFFSET:.*]] = arith.muli %[[RESULT_TENSOR_AXIS_SIZE]], %[[PROCESS_INDEX]] : index
+ // CHECK: %[[RESULT_TENSOR_SLICE:.*]] = tensor.extract_slice %[[ARG]][0, %[[RESULT_TENSOR_AXIS_OFFSET]]] [3, 7] [1, 1] : tensor<3x14xf32> to tensor<3x7xf32>
+ // CHECK: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[RESULT_TENSOR_SLICE]] : tensor<3x7xf32> to tensor<3x14xf32>
+ %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<3x14xf32>
+ %1 = mesh.shard %0 to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<3x14xf32>
+ // CHECK: return %[[RESULT]] : tensor<3x14xf32>
+ return %1 : tensor<3x14xf32>
+}
+
+// CHECK-LABEL: func @split_replicated_tensor_axis_dynamic
+func.func @split_replicated_tensor_axis_dynamic(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<?x3x?xf32>
+ %arg0: tensor<?x3x?xf32>
+) -> tensor<?x3x?xf32> {
+ // CHECK: %[[ZERO:.*]] = arith.constant 0 : index
+ // CHECK: %[[TWO:.*]] = arith.constant 2 : index
+ // CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_index on @mesh_1d_dynamic axes = [0] : index
+ // CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.cluster_shape @mesh_1d_dynamic axes = [0] : index
+ // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[ZERO]] : tensor<?x3x?xf32>
+ // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE:.*]] = arith.remui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
+ // CHECK: %[[RESULT_TENSOR_AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE]], %[[ZERO]] : index
+ // CHECK: cf.assert %[[RESULT_TENSOR_AXIS_SIZE_CHECK]]
+ // CHECK: %[[RESULT_TENSOR_SPLIT_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
+ // CHECK: %[[RESULT_TENSOR_SPLIT_AXIS_OFFSET:.*]] = arith.muli %[[RESULT_TENSOR_SPLIT_AXIS_SIZE]], %[[PROCESS_INDEX]] : index
+ // CHECK: %[[TENSOR_AXIS_2_SIZE:.*]] = tensor.dim %[[ARG]], %[[TWO]] : tensor<?x3x?xf32>
+ // CHECK: %[[RESULT_TENSOR_SLICE:.*]] = tensor.extract_slice %[[ARG]][%[[RESULT_TENSOR_SPLIT_AXIS_OFFSET]], 0, 0]
+ // CHECK-SAME: [%[[RESULT_TENSOR_SPLIT_AXIS_SIZE]], 3, %[[TENSOR_AXIS_2_SIZE]]] [1, 1, 1] : tensor<?x3x?xf32> to tensor<?x3x?xf32>
+ %0 = mesh.shard %arg0 to <@mesh_1d_dynamic, [[], [], []]> : tensor<?x3x?xf32>
+ %1 = mesh.shard %0 to <@mesh_1d_dynamic, [[0]]> annotate_for_users : tensor<?x3x?xf32>
+ // CHECK: return %[[RESULT_TENSOR_SLICE]] : tensor<?x3x?xf32>
+ return %1 : tensor<?x3x?xf32>
+}
+
+// CHECK-LABEL: func @move_split_axis
+func.func @move_split_axis(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+ %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+ // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32>
+ // CHECK: %[[TARGET_SHARD:.*]] = mesh.all_to_all %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<5x14xf32> -> tensor<10x7xf32>
+ // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x7xf32> to tensor<10x14xf32>
+ %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<10x14xf32>
+ %1 = mesh.shard %0 to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<10x14xf32>
+ // CHECK: return %[[RES]] : tensor<10x14xf32>
+ return %1 : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @move_split_axis_dynamic_mesh
+func.func @move_split_axis_dynamic_mesh(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+ %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+ // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32>
+ // CHECK: %[[ALL_TO_ALL:.*]] = mesh.all_to_all %[[SOURCE_SHARD]] on @mesh_1d_dynamic mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x?xf32>
+ // CHECK: %[[TARGET_SHARD:.*]] = tensor.cast %[[ALL_TO_ALL]] : tensor<?x?xf32> to tensor<10x?xf32>
+ // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x?xf32> to tensor<10x14xf32>
+ %0 = mesh.shard %arg0 to <@mesh_1d_dynamic, [[0]]> : tensor<10x14xf32>
+ %1 = mesh.shard %0 to <@mesh_1d_dynamic, [[], [0]]> annotate_for_users : tensor<10x14xf32>
+ // CHECK: return %[[RES]] : tensor<10x14xf32>
+ return %1 : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @move_split_dynamic_axis
+func.func @move_split_dynamic_axis(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32>
+ %arg0: tensor<?x14xf32>
+) -> tensor<?x14xf32> {
+ // CHECK: %[[TARGET_SHARD:.*]] = mesh.all_to_all %[[ARG]] on @mesh_1d mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x7xf32>
+ // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<?x7xf32> to tensor<?x14xf32>
+ %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<?x14xf32>
+ %1 = mesh.shard %0 to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<?x14xf32>
+ // CHECK: return %[[RES]] : tensor<?x14xf32>
+ return %1 : tensor<?x14xf32>
+}
+
+// CHECK-LABEL: func @unshard_static_axis
+func.func @unshard_static_axis(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+ %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+ // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32>
+ // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<5x14xf32> -> tensor<10x14xf32>
+ %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<10x14xf32>
+ %1 = mesh.shard %0 to <@mesh_1d, [[]]> annotate_for_users : tensor<10x14xf32>
+ // CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32>
+ return %1 : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @unshard_dynamic_axis
+func.func @unshard_dynamic_axis(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32>
+ %arg0: tensor<?x14xf32>
+) -> tensor<?x14xf32> {
+ // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[ARG]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32>
+ %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<?x14xf32>
+ %1 = mesh.shard %0 to <@mesh_1d, [[]]> annotate_for_users : tensor<?x14xf32>
+ // CHECK: return %[[ALL_GATHER]] : tensor<?x14xf32>
+ return %1 : tensor<?x14xf32>
+}
+
+// CHECK-LABEL: func @unshard_static_axis_on_dynamic_mesh_axis
+func.func @unshard_static_axis_on_dynamic_mesh_axis(
+// CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+ %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+ // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32>
+ // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d_dynamic mesh_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32>
+ // CHECK: %[[RES:.*]] = tensor.cast %[[ALL_GATHER]] : tensor<?x14xf32> to tensor<10x14xf32>
+ %0 = mesh.shard %arg0 to <@mesh_1d_dynamic, [[0]]> : tensor<10x14xf32>
+ %1 = mesh.shard %0 to <@mesh_1d_dynamic, [[]]> annotate_for_users : tensor<10x14xf32>
+ // CHECK: return %[[RES]] : tensor<10x14xf32>
+ return %1 : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @partial_axis
+func.func @partial_axis(
+// CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+ %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+ // CHECK: %[[ALL_REDUCE:.*]] = mesh.all_reduce %[[ARG]] on @mesh_1d mesh_axes = [0] : tensor<10x14xf32> -> tensor<10x14xf32>
+ %0 = mesh.shard %arg0 to <@mesh_1d, [[]], partial = sum[0]> : tensor<10x14xf32>
+ %1 = mesh.shard %0 to <@mesh_1d, [[]]> annotate_for_users : tensor<10x14xf32>
+  // CHECK: return %[[ALL_REDUCE]] : tensor<10x14xf32>
+ return %1 : tensor<10x14xf32>
+}
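Every case above follows the same protocol: the first mesh.shard records the sharding the producer yields, and the paired mesh.shard ... annotate_for_users records the sharding the consumer expects; the pass then materializes whatever converts one into the other (a local tensor.extract_slice, all_to_all, all_gather, or all_reduce). Distilled to its core, assuming the @mesh_1d cluster declared at the top of the file:

  // Producer side: tensor axis 0 sharded over mesh axis 0.
  %s = mesh.shard %t to <@mesh_1d, [[0]]> : tensor<10x14xf32>
  // Consumer side: fully replicated.
  %r = mesh.shard %s to <@mesh_1d, [[]]> annotate_for_users : tensor<10x14xf32>
  // -test-mesh-resharding-spmdization turns this pair into a mesh.all_gather.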
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib_2to4.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir
index f584977e..6fe7ec9 100644
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib_2to4.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir
@@ -1,5 +1,13 @@
// RUN: mlir-opt %s --linalg-generalize-named-ops --sparse-gpu-codegen="num-threads=0" | FileCheck %s
+#NV_24 = #sparse_tensor.encoding<{
+ map = ( i, j ) ->
+ ( i : dense,
+ j floordiv 4 : dense,
+ j mod 4 : block2_4
+ )
+}>
+
// CHECK-LABEL: func.func @matmul(
// CHECK-SAME: %[[VAL_0:.*0]]: tensor<?x?xf16>,
// CHECK-SAME: %[[VAL_1:.*1]]: tensor<?x?xf16>,
@@ -51,18 +59,14 @@
// CHECK: %[[VAL_55:.*]] = bufferization.to_tensor %[[VAL_19]] : memref<?x?xf16>
// CHECK: return %[[VAL_55]] : tensor<?x?xf16>
// CHECK: }
-
-#map = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
module {
- func.func @matmul(%arg0: tensor<?x?xf16>, %arg1: tensor<?x?xf16>, %arg2: tensor<?x?xf16>) -> tensor<?x?xf16> {
- %0 = linalg.generic { DENSE24, indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<?x?xf16>, tensor<?x?xf16>) outs(%arg2 : tensor<?x?xf16>) {
- ^bb0(%in: f16, %in_0: f16, %out: f16):
- %1 = arith.mulf %in, %in_0 : f16
- %2 = arith.addf %out, %1 : f16
- linalg.yield %2 : f16
- } -> tensor<?x?xf16>
- return %0 : tensor<?x?xf16>
+ func.func @matmul(%Ad: tensor<?x?xf16>,
+ %B: tensor<?x?xf16>,
+ %Cin: tensor<?x?xf16>) -> tensor<?x?xf16> {
+ %A = sparse_tensor.convert %Ad : tensor<?x?xf16> to tensor<?x?xf16, #NV_24>
+ %C = linalg.matmul
+ ins(%A, %B: tensor<?x?xf16, #NV_24>, tensor<?x?xf16>)
+ outs(%Cin: tensor<?x?xf16>) -> tensor<?x?xf16>
+ return %C : tensor<?x?xf16>
}
}
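The #NV_24 encoding introduced above models NVIDIA's 2:4 structured sparsity. Restating the same attribute with a comment per level (annotation only, no new syntax):

  #NV_24 = #sparse_tensor.encoding<{
    map = ( i, j ) ->
    ( i            : dense,    // rows are stored dense
      j floordiv 4 : dense,    // columns are grouped into blocks of four
      j mod 4      : block2_4  // at most two of every four entries per block
    )
  }>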
diff --git a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
index bdfe18a..b78ab9b 100644
--- a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
@@ -56,3 +56,75 @@ func.func @single_first_inner_dim_packing(%arg0: tensor<256x5xf32>) -> tensor<8x
%0 = tensor.pack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<256x5xf32> -> tensor<8x5x32xf32>
return %0 : tensor<8x5x32xf32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @unpack_1d_to_collapse
+// CHECK-SAME: %[[ARG0:.+]]: tensor<8x32xf32>)
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<8x32xf32> into tensor<256xf32>
+// CHECK: return %[[COLLAPSED]]
+func.func @unpack_1d_to_collapse(%arg0: tensor<8x32xf32>) -> tensor<256xf32> {
+ %empty = tensor.empty() : tensor<256xf32>
+ %0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<8x32xf32> -> tensor<256xf32>
+ return %0 : tensor<256xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpack_to_partial_slice
+// CHECK-NOT: tensor.collapse
+// CHECK: tensor.unpack
+func.func @unpack_to_partial_slice(%arg0: tensor<8x32xf32>) -> tensor<255xf32> {
+ %empty = tensor.empty() : tensor<255xf32>
+ %0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<8x32xf32> -> tensor<255xf32>
+ return %0 : tensor<255xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpack_dynamic
+// CHECK-NOT: tensor.collapse
+// CHECK: tensor.unpack
+func.func @unpack_dynamic(%arg0: tensor<?x32xf32>) -> tensor<?xf32> {
+ %c32 = arith.constant 32 : index
+ %c0 = arith.constant 0 : index
+ %d0 = tensor.dim %arg0, %c0 : tensor<?x32xf32>
+ %size = arith.muli %d0, %c32 : index
+ %empty = tensor.empty(%size) : tensor<?xf32>
+ %0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<?x32xf32> -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @single_last_inner_dim_unpacking(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<5x8x32xf32>)
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<5x8x32xf32> into tensor<5x256xf32>
+// CHECK: return %[[COLLAPSED]] : tensor<5x256xf32>
+func.func @single_last_inner_dim_unpacking(%arg0: tensor<5x8x32xf32>) -> tensor<5x256xf32> {
+ %empty = tensor.empty() : tensor<5x256xf32>
+ %0 = tensor.unpack %arg0 inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<5x8x32xf32> -> tensor<5x256xf32>
+ return %0 : tensor<5x256xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpacking_with_outer_dims_perm(
+// CHECK-NOT: tensor.collapse_shape
+// CHECK: tensor.unpack
+func.func @unpacking_with_outer_dims_perm(%arg0: tensor<8x5x32xf32>) -> tensor<5x256xf32> {
+ %empty = tensor.empty() : tensor<5x256xf32>
+ %0 = tensor.unpack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<8x5x32xf32> -> tensor<5x256xf32>
+ return %0 : tensor<5x256xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @single_first_inner_dim_unpacking(
+// CHECK-NOT: tensor.collapse_shape
+// CHECK: tensor.unpack
+func.func @single_first_inner_dim_unpacking(%arg0: tensor<8x5x32xf32>) -> tensor<256x5xf32> {
+ %empty = tensor.empty() : tensor<256x5xf32>
+ %0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<8x5x32xf32> -> tensor<256x5xf32>
+ return %0 : tensor<256x5xf32>
+}
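Together these cases delimit when tensor.unpack simplifies to tensor.collapse_shape: the tile must unpack into the innermost dimension, there must be no outer_dims_perm, the shapes must be static, and no partial tile may remain (255 is not a multiple of 32, hence no fold above). The equivalence the first CHECK lines spell out, in sketch form:

  // With full tiles and an innermost inner_dims_pos, this unpack ...
  %u = tensor.unpack %t inner_dims_pos = [0] inner_tiles = [32] into %e : tensor<8x32xf32> -> tensor<256xf32>
  // ... folds to a pure shape collapse:
  %c = tensor.collapse_shape %t [[0, 1]] : tensor<8x32xf32> into tensor<256xf32>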
diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
index 0964161..5123958 100644
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -696,3 +696,11 @@ transform.sequence failures(propagate) {
transform.named_sequence @foo()
} : !transform.any_op
}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+ // expected-error @below {{expected the type of the parameter attribute ('i64') to match the parameter type ('i32')}}
+ transform.num_associations %arg0 : (!transform.any_op) -> !transform.param<i32>
+}
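The diagnostic above pins down the result typing: as the message suggests, the op builds an i64 parameter attribute, so the result must be !transform.param<i64> (requesting i32 is what triggers the error). The well-typed counterpart, as a sketch:

  transform.sequence failures(propagate) {
  ^bb0(%arg0: !transform.any_op):
    %n = transform.num_associations %arg0 : (!transform.any_op) -> !transform.param<i64>
  }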
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index d9a1199..a39e6f9 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -575,8 +575,9 @@ transform.with_pdl_patterns {
%0 = pdl_match @addi in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = pdl_match @addi in %arg1 : (!transform.any_op) -> !transform.any_op
%2 = merge_handles deduplicate %0, %1 : !transform.any_op
+ %3 = num_associations %2 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{1}}
- test_print_number_of_associated_payload_ir_ops %2 : !transform.any_op
+ test_print_param %3 : !transform.param<i64>
}
}
@@ -676,11 +677,13 @@ module {
^bb0(%arg1: !transform.any_op):
%0 = pdl_match @func in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = replicate num(%0) %arg1 : !transform.any_op, !transform.any_op
+ %p = num_associations %1 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{2}}
- test_print_number_of_associated_payload_ir_ops %1 : !transform.any_op
+ test_print_param %p : !transform.param<i64>
%2 = replicate num(%0) %1 : !transform.any_op, !transform.any_op
+ %p2 = num_associations %2 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{4}}
- test_print_number_of_associated_payload_ir_ops %2 : !transform.any_op
+ test_print_param %p2 : !transform.param<i64>
}
}
}
@@ -708,8 +711,9 @@ transform.with_pdl_patterns {
%f = pdl_match @const in %arg1 : (!transform.any_op) -> !transform.any_op
transform.foreach %f : !transform.any_op {
^bb2(%arg2: !transform.any_op):
+ %p = transform.num_associations %arg2 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{1}}
- transform.test_print_number_of_associated_payload_ir_ops %arg2 : !transform.any_op
+ transform.test_print_param %p : !transform.param<i64>
transform.test_print_remark_at_operand %arg2, "transform applied" : !transform.any_op
}
}
@@ -780,8 +784,9 @@ transform.with_pdl_patterns {
transform.yield %g : !transform.any_op
}
+ %p = transform.num_associations %results : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{3}}
- transform.test_print_number_of_associated_payload_ir_ops %results : !transform.any_op
+ transform.test_print_param %p : !transform.param<i64>
transform.test_print_remark_at_operand %results, "transform applied" : !transform.any_op
}
}
@@ -877,8 +882,9 @@ transform.sequence failures(propagate) {
^bb1(%fun: !transform.any_op):
%muli = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
%h:2 = split_handle %muli : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %p = transform.num_associations %h#0 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{1}}
- transform.test_print_number_of_associated_payload_ir_ops %h#0 : !transform.any_op
+ transform.test_print_param %p : !transform.param<i64>
%muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
// expected-error @below {{expected to contain 3 payload ops but it contains 2 payload ops}}
%h_2:3 = split_handle %muli_2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
@@ -896,13 +902,15 @@ transform.sequence failures(suppress) {
^bb1(%fun: !transform.any_op):
%muli = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
%h:2 = split_handle %muli : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %p = transform.num_associations %h#0 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{1}}
- transform.test_print_number_of_associated_payload_ir_ops %h#0 : !transform.any_op
+ transform.test_print_param %p : !transform.param<i64>
%muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
// Silenceable failure and all handles are now empty.
%h_2:3 = split_handle %muli_2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %p2 = transform.num_associations %h_2#0 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{0}}
- transform.test_print_number_of_associated_payload_ir_ops %h_2#0 : !transform.any_op
+ transform.test_print_param %p2 : !transform.param<i64>
}
// -----
@@ -918,12 +926,15 @@ transform.sequence failures(propagate) {
%muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
// No error, last result handle is empty.
%h:3 = split_handle %muli_2 {fail_on_payload_too_small = false} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %p = transform.num_associations %h#0 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{1}}
- transform.test_print_number_of_associated_payload_ir_ops %h#0 : !transform.any_op
+ transform.test_print_param %p : !transform.param<i64>
+ %p2 = transform.num_associations %h#1 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{1}}
- transform.test_print_number_of_associated_payload_ir_ops %h#1 : !transform.any_op
+ transform.test_print_param %p2 : !transform.param<i64>
+ %p3 = transform.num_associations %h#2 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{0}}
- transform.test_print_number_of_associated_payload_ir_ops %h#2 : !transform.any_op
+ transform.test_print_param %p3 : !transform.param<i64>
}
// -----
@@ -940,10 +951,12 @@ transform.sequence failures(propagate) {
^bb1(%fun: !transform.any_op):
%muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
%h:2 = split_handle %muli_2 {overflow_result = 0} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %p = transform.num_associations %h#0 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{3}}
- transform.test_print_number_of_associated_payload_ir_ops %h#0 : !transform.any_op
+ transform.test_print_param %p : !transform.param<i64>
+ %p2 = transform.num_associations %h#1 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{1}}
- transform.test_print_number_of_associated_payload_ir_ops %h#1 : !transform.any_op
+ transform.test_print_param %p2 : !transform.param<i64>
}
// -----
@@ -1668,8 +1681,9 @@ transform.sequence failures(propagate) {
// expected-remark @below {{2 iterations}}
transform.test_tracked_rewrite %0 : (!transform.any_op) -> ()
// One replacement op (test.drop_mapping) is dropped from the mapping.
+ %p = num_associations %0 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{2}}
- test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op
+ test_print_param %p : !transform.param<i64>
}
// -----
@@ -1684,20 +1698,24 @@ module {
%2 = transform.param.constant 1 -> !transform.param<i64>
%3 = transform.param.constant 2 -> !transform.param<i64>
%4 = transform.merge_handles %1, %2 { deduplicate } : !transform.param<i64>
+ %p = num_associations %4 : (!transform.param<i64>) -> !transform.param<i64>
// expected-remark @below {{1}}
- test_print_number_of_associated_payload_ir_params %4 : !transform.param<i64>
+ test_print_param %p : !transform.param<i64>
%5 = transform.merge_handles %1, %1 { deduplicate } : !transform.param<i64>
+ %p2 = num_associations %5 : (!transform.param<i64>) -> !transform.param<i64>
// expected-remark @below {{1}}
- test_print_number_of_associated_payload_ir_params %5 : !transform.param<i64>
+ test_print_param %p2 : !transform.param<i64>
%6 = transform.merge_handles %1, %3 { deduplicate } : !transform.param<i64>
+ %p3 = num_associations %6 : (!transform.param<i64>) -> !transform.param<i64>
// expected-remark @below {{2}}
- test_print_number_of_associated_payload_ir_params %6 : !transform.param<i64>
+ test_print_param %p3 : !transform.param<i64>
%7 = transform.merge_handles %1, %1, %2, %3 : !transform.param<i64>
+ %p4 = num_associations %7 : (!transform.param<i64>) -> !transform.param<i64>
// expected-remark @below {{4}}
- test_print_number_of_associated_payload_ir_params %7 : !transform.param<i64>
+ test_print_param %p4 : !transform.param<i64>
}
}
@@ -1712,21 +1730,25 @@ transform.sequence failures(propagate) {
%3 = test_produce_value_handle_to_result %1, 1 : (!transform.any_op) -> !transform.any_value
%4 = transform.merge_handles %2, %2 { deduplicate } : !transform.any_value
+ %p = num_associations %4 : (!transform.any_value) -> !transform.param<i64>
// expected-remark @below {{1}}
- test_print_number_of_associated_payload_ir_values %4 : !transform.any_value
+ test_print_param %p : !transform.param<i64>
%5 = transform.merge_handles %2, %3 { deduplicate } : !transform.any_value
+ %p2 = num_associations %5 : (!transform.any_value) -> !transform.param<i64>
// expected-remark @below {{2}}
- test_print_number_of_associated_payload_ir_values %5 : !transform.any_value
+ test_print_param %p2 : !transform.param<i64>
%6 = test_produce_value_handle_to_result %1, 0 : (!transform.any_op) -> !transform.any_value
%7 = transform.merge_handles %2, %6 { deduplicate } : !transform.any_value
+ %p3 = num_associations %6 : (!transform.any_value) -> !transform.param<i64>
// expected-remark @below {{1}}
- test_print_number_of_associated_payload_ir_values %6 : !transform.any_value
+ test_print_param %p3 : !transform.param<i64>
%8 = transform.merge_handles %2, %2, %3, %4 : !transform.any_value
+ %p4 = num_associations %8 : (!transform.any_value) -> !transform.param<i64>
// expected-remark @below {{4}}
- test_print_number_of_associated_payload_ir_values %8 : !transform.any_value
+ test_print_param %p4 : !transform.param<i64>
}
// -----
@@ -1820,31 +1842,37 @@ transform.sequence failures(propagate) {
// There are 3 arith.constant ops.
%all = transform.structured.match ops{["arith.constant"]} in %0 : (!transform.any_op) -> !transform.any_op
+ %p = num_associations %all : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{3}}
- test_print_number_of_associated_payload_ir_ops %all : !transform.any_op
+ test_print_param %p : !transform.param<i64>
// "deduplicate" has no effect because these are 3 different ops.
%merged_before = transform.merge_handles deduplicate %all : !transform.any_op
+ %p2 = num_associations %merged_before : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{3}}
- test_print_number_of_associated_payload_ir_ops %merged_before : !transform.any_op
+ test_print_param %p2 : !transform.param<i64>
// Apply CSE.
transform.apply_cse to %0 : !transform.any_op
// The handle is still mapped to 3 arith.constant ops.
+ %p3 = num_associations %all : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{3}}
- test_print_number_of_associated_payload_ir_ops %all : !transform.any_op
+ test_print_param %p3 : !transform.param<i64>
// But they are all the same op.
%merged_after = transform.merge_handles deduplicate %all : !transform.any_op
+ %p4 = num_associations %merged_after : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- test_print_number_of_associated_payload_ir_ops %merged_after : !transform.any_op
+ test_print_param %p4 : !transform.param<i64>
// The other handles were also updated.
test_print_remark_at_operand %elim_first, "eliminated 1" : !transform.any_op
+ %p5 = num_associations %elim_first : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- test_print_number_of_associated_payload_ir_ops %elim_first : !transform.any_op
+ test_print_param %p5 : !transform.param<i64>
test_print_remark_at_operand %elim_second, "eliminated 2" : !transform.any_op
+ %p6 = num_associations %elim_second : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- test_print_number_of_associated_payload_ir_ops %elim_second : !transform.any_op
+ test_print_param %p6 : !transform.param<i64>
}
// -----
@@ -1907,14 +1935,16 @@ transform.sequence failures(propagate) {
// Get immediate parent.
%2 = transform.get_parent_op %0 : (!transform.any_op) -> !transform.any_op
test_print_remark_at_operand %2, "direct parent" : !transform.any_op
+ %p = num_associations %2 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{2}}
- test_print_number_of_associated_payload_ir_ops %2 : !transform.any_op
+ test_print_param %p : !transform.param<i64>
// Deduplicate results.
%3 = transform.structured.match ops{["test.qux"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%4 = transform.get_parent_op %3 {deduplicate} : (!transform.any_op) -> !transform.any_op
+ %p2 = num_associations %4 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- test_print_number_of_associated_payload_ir_ops %4 : !transform.any_op
+ test_print_param %p2 : !transform.param<i64>
}
@@ -2029,8 +2059,9 @@ transform.sequence failures(propagate) {
// Match all ops inside the function (including the function itself).
%func_op = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%0 = transform.structured.match in %func_op : (!transform.any_op) -> !transform.any_op
+ %p = num_associations %0 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{5}}
- test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op
+ test_print_param %p : !transform.param<i64>
// Select "test.foo".
%foo = transform.select "test.foo" in %0 : (!transform.any_op) -> !transform.any_op
@@ -2060,8 +2091,9 @@ transform.sequence failures(propagate) {
%empty_op = transform.structured.match ops{["tensor.empty"]} in %func_op : (!transform.any_op) -> !transform.any_op
transform.apply_dce to %func_op : !transform.any_op
+ %p = num_associations %empty_op : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{0}}
- test_print_number_of_associated_payload_ir_ops %empty_op : !transform.any_op
+ test_print_param %p : !transform.param<i64>
}
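One detail worth noting across the hunks above: the single transform.num_associations op subsumes all three of the old test printers, accepting op handles, value handles, and parameters alike and always yielding a !transform.param<i64>. The three forms, exactly as this file uses them (handle names are placeholders):

  %n0 = num_associations %ops    : (!transform.any_op)     -> !transform.param<i64>
  %n1 = num_associations %vals   : (!transform.any_value)  -> !transform.param<i64>
  %n2 = num_associations %params : (!transform.param<i64>) -> !transform.param<i64>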
diff --git a/mlir/test/Dialect/Transform/test-loop-transforms.mlir b/mlir/test/Dialect/Transform/test-loop-transforms.mlir
index 4259627..c34f4ba 100644
--- a/mlir/test/Dialect/Transform/test-loop-transforms.mlir
+++ b/mlir/test/Dialect/Transform/test-loop-transforms.mlir
@@ -37,13 +37,16 @@ module attributes {transform.with_named_sequence} {
// Make sure that the handles are still valid (and were updated in case of
// the loop).
+ %p = transform.num_associations %0 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- transform.test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op
+ transform.test_print_param %p : !transform.param<i64>
transform.test_print_remark_at_operand %0, "new loop op" : !transform.any_op
+ %p2 = transform.num_associations %1 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- transform.test_print_number_of_associated_payload_ir_ops %1 : !transform.any_op
+ transform.test_print_param %p2 : !transform.param<i64>
+ %p3 = transform.num_associations %2 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- transform.test_print_number_of_associated_payload_ir_ops %2 : !transform.any_op
+ transform.test_print_param %p3 : !transform.param<i64>
transform.yield
}
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 3708d74..ae457ea 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -356,3 +356,18 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
// CHECK: return %[[VAL_4]] : vector<8xi32>
+
+// -----
+
+// Make sure there is no crash when the source memref has an empty stride list (0-d memref).
+func.func @stride_empty_test(%1: memref<i16>) -> vector<32x256xi16> {
+ %c0_i16 = arith.constant 0 : i16
+ %3 = vector.transfer_read %1[], %c0_i16 {permutation_map = affine_map<() -> (0, 0)>} : memref<i16>, vector<32x256xi16>
+ return %3 : vector<32x256xi16>
+
+ // CHECK-LABEL: func.func @stride_empty_test
+ // CHECK: %[[VAL:.*]] = arith.constant 0 : i16
+ // CHECK: %[[RET:.*]] = vector.transfer_read {{.*}} vector<32x256xi16>
+ // CHECK: return %[[RET]]
+ // CHECK-NOT: empty()
+}