aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/Dialect
diff options
context:
space:
mode:
authorRolf Morel <rolf.morel@huawei.com>2024-06-14 17:02:47 +0200
committerGitHub <noreply@github.com>2024-06-14 17:02:47 +0200
commitd462bf687548a5630f60a8afaa66120df8319e88 (patch)
tree28d1a15d2c6c8ed2c6edd3c9bde4d1aaf7aff986 /mlir/test/Dialect
parent0a57a20aa506c5a5a8b0a8eb45446d0747493d7c (diff)
downloadllvm-d462bf687548a5630f60a8afaa66120df8319e88.zip
llvm-d462bf687548a5630f60a8afaa66120df8319e88.tar.gz
llvm-d462bf687548a5630f60a8afaa66120df8319e88.tar.bz2
[mlir][Transform] Extend transform.foreach to take multiple arguments (#93705)
Changes transform.foreach's interface to take multiple arguments, e.g. transform.foreach %ops1, %ops2, %params : ... { ^bb0(%op1, %op2, %param): BODY } The semantics are that the payloads for these handles get iterated over as if the payloads have been zipped-up together - BODY gets executed once for each such tuple. The documentation explains that this implementation requires that the payloads have the same length. This change also enables the target argument(s) to be any op/value/param handle. The added test cases demonstrate some use cases for this change.
Diffstat (limited to 'mlir/test/Dialect')
-rw-r--r--mlir/test/Dialect/Linalg/multisize-tiling-full.mlir21
-rw-r--r--mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir73
-rw-r--r--mlir/test/Dialect/Transform/ops-invalid.mlir49
-rw-r--r--mlir/test/Dialect/Transform/ops.mlir22
-rw-r--r--mlir/test/Dialect/Transform/test-interpreter.mlir85
5 files changed, 238 insertions, 12 deletions
diff --git a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
index 15b24b5..51332ff 100644
--- a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
+++ b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
@@ -6,15 +6,17 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1:3 = transform.structured.multitile_sizes %0 { dimension = 0, target_size = 3} : (!transform.any_op) -> !transform.any_op
- %t:3 = transform.structured.multitile_sizes %0 { dimension = 1, target_size = 10} : (!transform.any_op) -> !transform.any_op
%2:2 = transform.structured.split %0 after %1#2 { dimension = 0 } : !transform.any_op, !transform.any_op
%3:2 = transform.structured.tile_using_for %2#0 tile_sizes [%1#0] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
%4:2 = transform.structured.tile_using_for %2#1 tile_sizes [%1#1] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
%5 = transform.merge_handles %3#0, %4#0 : !transform.any_op
- %tt:3 = transform.replicate num(%5) %t#0, %t#1, %t#2 : !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op
- %6:2 = transform.structured.split %5 after %tt#2 { dimension = 1 } : !transform.any_op, !transform.any_op
- transform.structured.tile_using_for %6#0 tile_sizes [0, %tt#0] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
- transform.structured.tile_using_for %6#1 tile_sizes [0, %tt#1] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.foreach %5 : !transform.any_op {
+ ^bb0(%inner_linalg: !transform.any_op):
+ %low, %high, %split_point = transform.structured.multitile_sizes %inner_linalg { dimension = 1, target_size = 10} : (!transform.any_op) -> !transform.any_op
+ %inner_linalg_low, %inner_linalg_high = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.any_op
+ transform.structured.tile_using_for %inner_linalg_low tile_sizes [0, %low] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.structured.tile_using_for %inner_linalg_high tile_sizes [0, %high] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ }
transform.yield
}
}
@@ -114,9 +116,12 @@ module attributes {transform.with_named_sequence} {
%4:2 = transform.structured.tile_using_for %2#1 tile_sizes [%1#1] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
%5 = transform.merge_handles %3#0, %4#0 : !transform.any_op
%tt:3 = transform.replicate num(%5) %t#0, %t#1, %t#2 : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>
- %6:2 = transform.structured.split %5 after %tt#2 { dimension = 1 } : !transform.any_op, !transform.param<i64>
- transform.structured.tile_using_for %6#0 tile_sizes [0, %tt#0] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
- transform.structured.tile_using_for %6#1 tile_sizes [0, %tt#1] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
+ transform.foreach %5, %tt#0, %tt#1, %tt#2 : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64> {
+ ^bb0(%inner_linalg: !transform.any_op, %low: !transform.param<i64>, %high: !transform.param<i64>, %split_point: !transform.param<i64>):
+ %inner_linalg_low, %inner_linalg_high = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.param<i64>
+ transform.structured.tile_using_for %inner_linalg_low tile_sizes [0, %low] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
+ transform.structured.tile_using_for %inner_linalg_high tile_sizes [0, %high] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
+ }
transform.yield
}
}
diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
index 0f51b1c..54dd2bd 100644
--- a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
+++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
@@ -328,3 +328,76 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+// -----
+
+// CHECK: func.func @foreach_loop_pair_fuse([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
+func.func @foreach_loop_pair_fuse(%arg1: tensor<128xf32>, %arg2: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>, tensor<128xf32>, tensor<128xf32>) {
+ // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+ // CHECK-DAG: [[C16:%.*]] = arith.constant 16 : index
+ // CHECK-DAG: [[C128:%.*]] = arith.constant 128 : index
+ // CHECK-DAG: [[ZERO:%.*]] = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %c128 = arith.constant 128 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ // CHECK: [[RST:%.*]]:2 = scf.for [[IV:%.*]] = [[C0]] to [[C128]] step [[C16]] iter_args([[IB0:%.*]] = [[B]], [[IB1:%.*]] = [[B]]) {{.*}}
+ %1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) {
+ // CHECK-DAG: [[ASLICE:%.*]] = vector.transfer_read [[A]][[[IV]]], [[ZERO]]
+ // CHECK-DAG: [[SLICE0:%.*]] = vector.transfer_read [[IB0]][[[IV]]], [[ZERO]]
+ // CHECK: [[OUT1:%.*]] = arith.addf [[SLICE0]], [[ASLICE]]
+ // CHECK-NEXT: [[WRT0:%.*]] = vector.transfer_write [[OUT1]], [[IB0]][[[IV]]]
+ %2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+ %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+ %5 = arith.addf %3, %2 : vector<16xf32>
+ %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+ scf.yield %6 : tensor<128xf32>
+ } {target_loops}
+ %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) {
+ // CHECK-DAG: [[SLICE1:%.*]] = vector.transfer_read [[IB1]][[[IV]]], [[ZERO]]
+ // CHECK: [[OUT2:%.*]] = arith.addf [[SLICE1]], [[ASLICE]]
+ // CHECK-NEXT: [[WRT1:%.*]] = vector.transfer_write [[OUT2]], [[IB1]][[[IV]]]
+ %dup2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+ %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+ %dup5 = arith.addf %dup3, %dup2 : vector<16xf32>
+ %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+ // CHECK: scf.yield [[WRT0]], [[WRT1]] : {{.*}}
+ scf.yield %dup6 : tensor<128xf32>
+ } {source_loops}
+ %2 = scf.for %arg3 = %c0 to %c128 step %c32 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) {
+ // CHECK-DAG: [[ASLICE:%.*]] = vector.transfer_read [[A]][[[IV]]], [[ZERO]]
+ // CHECK-DAG: [[SLICE0:%.*]] = vector.transfer_read [[IB0]][[[IV]]], [[ZERO]]
+ // CHECK: [[OUT1:%.*]] = arith.addf [[SLICE0]], [[ASLICE]]
+ // CHECK-NEXT: [[WRT0:%.*]] = vector.transfer_write [[OUT1]], [[IB0]][[[IV]]]
+ %2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<32xf32>
+ %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<32xf32>
+ %5 = arith.addf %3, %2 : vector<32xf32>
+ %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<32xf32>, tensor<128xf32>
+ scf.yield %6 : tensor<128xf32>
+ } {target_loops}
+ %dup2 = scf.for %arg3 = %c0 to %c128 step %c32 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) {
+ // CHECK-DAG: [[SLICE1:%.*]] = vector.transfer_read [[IB1]][[[IV]]], [[ZERO]]
+ // CHECK: [[OUT2:%.*]] = arith.addf [[SLICE1]], [[ASLICE]]
+ // CHECK-NEXT: [[WRT1:%.*]] = vector.transfer_write [[OUT2]], [[IB1]][[[IV]]]
+ %dup2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<32xf32>
+ %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<32xf32>
+ %dup5 = arith.addf %dup3, %dup2 : vector<32xf32>
+ %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<32xf32>, tensor<128xf32>
+ // CHECK: scf.yield [[WRT0]], [[WRT1]] : {{.*}}
+ scf.yield %dup6 : tensor<128xf32>
+ } {source_loops}
+ return %1, %dup1, %2, %dup2 : tensor<128xf32>, tensor<128xf32>, tensor<128xf32>, tensor<128xf32>
+}
+
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %target_loops = transform.structured.match ops{["scf.for"]} attributes {target_loops} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %source_loops = transform.structured.match ops{["scf.for"]} attributes {source_loops} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.foreach %target_loops, %source_loops : !transform.any_op, !transform.any_op {
+ ^bb0(%target_loop: !transform.any_op, %source_loop: !transform.any_op):
+ %fused = transform.loop.fuse_sibling %target_loop into %source_loop : (!transform.any_op,!transform.any_op) -> !transform.any_op
+ }
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
index 30a68cc..71a260f 100644
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -279,6 +279,55 @@ transform.sequence failures(propagate) {
// -----
+transform.sequence failures(propagate) {
+ ^bb0(%root: !transform.any_op):
+ %op = test_produce_self_handle_or_forward_operand : () -> !transform.any_op
+ // expected-error @below {{op expects the same number of targets as the body has block arguments}}
+ transform.foreach %op : !transform.any_op -> !transform.any_op, !transform.any_value {
+ ^bb1(%op_arg: !transform.any_op, %val_arg: !transform.any_value):
+ transform.yield %op_arg, %val_arg : !transform.any_op, !transform.any_value
+ }
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+ ^bb0(%root: !transform.any_op):
+ %op = test_produce_self_handle_or_forward_operand : () -> !transform.any_op
+ // expected-error @below {{op expects co-indexed targets and the body's block arguments to have the same op/value/param type}}
+ transform.foreach %op : !transform.any_op -> !transform.any_value {
+ ^bb1(%val_arg: !transform.any_value):
+ transform.yield %val_arg : !transform.any_value
+ }
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+ ^bb0(%root: !transform.any_op):
+ %op = test_produce_self_handle_or_forward_operand : () -> !transform.any_op
+ // expected-error @below {{op expects the same number of results as the yield terminator has operands}}
+ transform.foreach %op : !transform.any_op -> !transform.any_op, !transform.any_op {
+ ^bb1(%arg_op: !transform.any_op):
+ transform.yield %arg_op : !transform.any_op
+ }
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+ ^bb0(%root: !transform.any_op):
+ %op = test_produce_self_handle_or_forward_operand : () -> !transform.any_op
+ %val = transform.test_produce_value_handle_to_self_operand %op : (!transform.any_op) -> !transform.any_value
+ // expected-error @below {{expects co-indexed results and yield operands to have the same op/value/param type}}
+ transform.foreach %op, %val : !transform.any_op, !transform.any_value -> !transform.any_op, !transform.any_value {
+ ^bb1(%op_arg: !transform.any_op, %val_arg: !transform.any_value):
+ transform.yield %val_arg, %op_arg : !transform.any_value, !transform.any_op
+ }
+}
+
+// -----
+
transform.sequence failures(suppress) {
^bb0(%arg0: !transform.any_op):
// expected-error @below {{TransformOpInterface requires memory effects on operands to be specified}}
diff --git a/mlir/test/Dialect/Transform/ops.mlir b/mlir/test/Dialect/Transform/ops.mlir
index b03a9f4..e9baffd 100644
--- a/mlir/test/Dialect/Transform/ops.mlir
+++ b/mlir/test/Dialect/Transform/ops.mlir
@@ -68,11 +68,25 @@ transform.sequence failures(propagate) {
}
// CHECK: transform.sequence
-// CHECK: foreach
transform.sequence failures(propagate) {
-^bb0(%arg0: !transform.any_op):
- transform.foreach %arg0 : !transform.any_op {
- ^bb1(%arg1: !transform.any_op):
+^bb0(%op0: !transform.any_op, %val0: !transform.any_value, %par0: !transform.any_param):
+ // CHECK: foreach %{{.*}} : !transform.any_op
+ transform.foreach %op0 : !transform.any_op {
+ ^bb1(%op1: !transform.any_op):
+ }
+ // CHECK: foreach %{{.*}} : !transform.any_op, !transform.any_value, !transform.any_param
+ transform.foreach %op0, %val0, %par0 : !transform.any_op, !transform.any_value, !transform.any_param {
+ ^bb1(%op1: !transform.any_op, %val1: !transform.any_value, %par1: !transform.any_param):
+ }
+ // CHECK: foreach %{{.*}} : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_op
+ transform.foreach %op0, %val0, %par0 : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_op {
+ ^bb1(%op1: !transform.any_op, %val1: !transform.any_value, %par1: !transform.any_param):
+ transform.yield %op1 : !transform.any_op
+ }
+ // CHECK: foreach %{{.*}} : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_param, !transform.any_value
+ transform.foreach %op0, %val0, %par0 : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_param, !transform.any_value {
+ ^bb1(%op1: !transform.any_op, %val1: !transform.any_value, %par1: !transform.any_param):
+ transform.yield %par1, %val1 : !transform.any_param, !transform.any_value
}
}
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index b6850e2..4fe2dbe 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -830,6 +830,91 @@ module attributes {transform.with_named_sequence} {
// -----
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %results, %types = transform.foreach %0 : !transform.any_op -> !transform.any_value, !transform.any_param {
+ ^bb0(%op0 : !transform.any_op):
+ %result = transform.get_result %op0[0] : (!transform.any_op) -> !transform.any_value
+ %type = transform.get_type elemental %result : (!transform.any_value) -> !transform.any_param
+ transform.yield %result, %type : !transform.any_value, !transform.any_param
+ }
+ transform.debug.emit_remark_at %results, "result selected" : !transform.any_value
+ transform.debug.emit_param_as_remark %types, "elemental types" at %0 : !transform.any_param, !transform.any_op
+
+ transform.yield
+ }
+}
+
+func.func @payload(%lhs: tensor<10x20xf16>,
+ %rhs: tensor<20x15xf32>) -> (tensor<10x15xf64>, tensor<10x15xf32>) {
+ %cst64 = arith.constant 0.0 : f64
+ %empty64 = tensor.empty() : tensor<10x15xf64>
+ %fill64 = linalg.fill ins(%cst64 : f64) outs(%empty64 : tensor<10x15xf64>) -> tensor<10x15xf64>
+ // expected-remark @below {{result selected}}
+ // expected-note @below {{value handle points to an op result #0}}
+ // expected-remark @below {{elemental types f64, f32}}
+ %result64 = linalg.matmul ins(%lhs, %rhs: tensor<10x20xf16>, tensor<20x15xf32>)
+ outs(%fill64: tensor<10x15xf64>) -> tensor<10x15xf64>
+
+ %cst32 = arith.constant 0.0 : f32
+ %empty32 = tensor.empty() : tensor<10x15xf32>
+ %fill32 = linalg.fill ins(%cst32 : f32) outs(%empty32 : tensor<10x15xf32>) -> tensor<10x15xf32>
+ // expected-remark @below {{result selected}}
+ // expected-note @below {{value handle points to an op result #0}}
+ // expected-remark @below {{elemental types f64, f32}}
+ %result32 = linalg.matmul ins(%lhs, %rhs: tensor<10x20xf16>, tensor<20x15xf32>)
+ outs(%fill32: tensor<10x15xf32>) -> tensor<10x15xf32>
+
+ return %result64, %result32 : tensor<10x15xf64>, tensor<10x15xf32>
+
+}
+
+// -----
+
+func.func @two_const_ops() {
+ %0 = arith.constant 0 : index
+ %1 = arith.constant 1 : index
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+ %two_ops = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %one_param = transform.param.constant 1 : i32 -> !transform.test_dialect_param
+ // expected-error @below {{prior targets' payload size (2) differs from payload size (1) of target}}
+ transform.foreach %two_ops, %one_param : !transform.any_op, !transform.test_dialect_param {
+ ^bb2(%op: !transform.any_op, %param: !transform.test_dialect_param):
+ }
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @one_const_op() {
+ %0 = arith.constant 0 : index
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+ %one_op = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %one_val = transform.test_produce_value_handle_to_self_operand %one_op : (!transform.any_op) -> !transform.any_value
+ %param_one = transform.param.constant 1 : i32 -> !transform.test_dialect_param
+ %param_two = transform.param.constant 2 : i32 -> !transform.test_dialect_param
+ %two_params = transform.merge_handles %param_one, %param_two : !transform.test_dialect_param
+
+ // expected-error @below {{prior targets' payload size (1) differs from payload size (2) of target}}
+ transform.foreach %one_val, %one_op, %two_params : !transform.any_value, !transform.any_op, !transform.test_dialect_param {
+ ^bb2(%val: !transform.any_value, %op: !transform.any_op, %param: !transform.test_dialect_param):
+ }
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK-LABEL: func @consume_in_foreach()
// CHECK-NEXT: return
func.func @consume_in_foreach() {