diff options
author | Rolf Morel <rolf.morel@huawei.com> | 2024-06-14 17:02:47 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-14 17:02:47 +0200 |
commit | d462bf687548a5630f60a8afaa66120df8319e88 (patch) | |
tree | 28d1a15d2c6c8ed2c6edd3c9bde4d1aaf7aff986 /mlir/test/Dialect | |
parent | 0a57a20aa506c5a5a8b0a8eb45446d0747493d7c (diff) | |
download | llvm-d462bf687548a5630f60a8afaa66120df8319e88.zip llvm-d462bf687548a5630f60a8afaa66120df8319e88.tar.gz llvm-d462bf687548a5630f60a8afaa66120df8319e88.tar.bz2 |
[mlir][Transform] Extend transform.foreach to take multiple arguments (#93705)
Changes transform.foreach's interface to take multiple arguments, e.g.
transform.foreach %ops1, %ops2, %params : ... { ^bb0(%op1, %op2,
%param): BODY } The semantics are that the payloads for these handles
get iterated over as if the payloads have been zipped-up together - BODY
gets executed once for each such tuple. The documentation explains that
this implementation requires that the payloads have the same length.
This change also enables the target argument(s) to be any op/value/param
handle.
The added test cases demonstrate some use cases for this change.
Diffstat (limited to 'mlir/test/Dialect')
-rw-r--r-- | mlir/test/Dialect/Linalg/multisize-tiling-full.mlir | 21 | ||||
-rw-r--r-- | mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir | 73 | ||||
-rw-r--r-- | mlir/test/Dialect/Transform/ops-invalid.mlir | 49 | ||||
-rw-r--r-- | mlir/test/Dialect/Transform/ops.mlir | 22 | ||||
-rw-r--r-- | mlir/test/Dialect/Transform/test-interpreter.mlir | 85 |
5 files changed, 238 insertions, 12 deletions
diff --git a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir index 15b24b5..51332ff 100644 --- a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir +++ b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir @@ -6,15 +6,17 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1:3 = transform.structured.multitile_sizes %0 { dimension = 0, target_size = 3} : (!transform.any_op) -> !transform.any_op - %t:3 = transform.structured.multitile_sizes %0 { dimension = 1, target_size = 10} : (!transform.any_op) -> !transform.any_op %2:2 = transform.structured.split %0 after %1#2 { dimension = 0 } : !transform.any_op, !transform.any_op %3:2 = transform.structured.tile_using_for %2#0 tile_sizes [%1#0] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) %4:2 = transform.structured.tile_using_for %2#1 tile_sizes [%1#1] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) %5 = transform.merge_handles %3#0, %4#0 : !transform.any_op - %tt:3 = transform.replicate num(%5) %t#0, %t#1, %t#2 : !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op - %6:2 = transform.structured.split %5 after %tt#2 { dimension = 1 } : !transform.any_op, !transform.any_op - transform.structured.tile_using_for %6#0 tile_sizes [0, %tt#0] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.structured.tile_using_for %6#1 tile_sizes [0, %tt#1] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.foreach %5 : !transform.any_op { + ^bb0(%inner_linalg: !transform.any_op): + %low, %high, %split_point = transform.structured.multitile_sizes %inner_linalg { dimension = 1, target_size = 10} : (!transform.any_op) -> !transform.any_op + %inner_linalg_low, %inner_linalg_high = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.any_op + transform.structured.tile_using_for %inner_linalg_low tile_sizes [0, %low] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.structured.tile_using_for %inner_linalg_high tile_sizes [0, %high] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + } transform.yield } } @@ -114,9 +116,12 @@ module attributes {transform.with_named_sequence} { %4:2 = transform.structured.tile_using_for %2#1 tile_sizes [%1#1] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op) %5 = transform.merge_handles %3#0, %4#0 : !transform.any_op %tt:3 = transform.replicate num(%5) %t#0, %t#1, %t#2 : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64> - %6:2 = transform.structured.split %5 after %tt#2 { dimension = 1 } : !transform.any_op, !transform.param<i64> - transform.structured.tile_using_for %6#0 tile_sizes [0, %tt#0] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op) - transform.structured.tile_using_for %6#1 tile_sizes [0, %tt#1] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op) + transform.foreach %5, %tt#0, %tt#1, %tt#2 : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64> { + ^bb0(%inner_linalg: !transform.any_op, %low: !transform.param<i64>, %high: !transform.param<i64>, %split_point: !transform.param<i64>): + %inner_linalg_low, %inner_linalg_high = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.param<i64> + transform.structured.tile_using_for %inner_linalg_low tile_sizes [0, %low] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op) + transform.structured.tile_using_for %inner_linalg_high tile_sizes [0, %high] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op) + } transform.yield } } diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir index 0f51b1c..54dd2bd 100644 --- a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir +++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir @@ -328,3 +328,76 @@ module attributes {transform.with_named_sequence} { transform.yield } } +// ----- + +// CHECK: func.func @foreach_loop_pair_fuse([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}} +func.func @foreach_loop_pair_fuse(%arg1: tensor<128xf32>, %arg2: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>, tensor<128xf32>, tensor<128xf32>) { + // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index + // CHECK-DAG: [[C16:%.*]] = arith.constant 16 : index + // CHECK-DAG: [[C128:%.*]] = arith.constant 128 : index + // CHECK-DAG: [[ZERO:%.*]] = arith.constant 0.000000e+00 : f32 + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %cst = arith.constant 0.000000e+00 : f32 + // CHECK: [[RST:%.*]]:2 = scf.for [[IV:%.*]] = [[C0]] to [[C128]] step [[C16]] iter_args([[IB0:%.*]] = [[B]], [[IB1:%.*]] = [[B]]) {{.*}} + %1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) { + // CHECK-DAG: [[ASLICE:%.*]] = vector.transfer_read [[A]][[[IV]]], [[ZERO]] + // CHECK-DAG: [[SLICE0:%.*]] = vector.transfer_read [[IB0]][[[IV]]], [[ZERO]] + // CHECK: [[OUT1:%.*]] = arith.addf [[SLICE0]], [[ASLICE]] + // CHECK-NEXT: [[WRT0:%.*]] = vector.transfer_write [[OUT1]], [[IB0]][[[IV]]] + %2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32> + %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32> + %5 = arith.addf %3, %2 : vector<16xf32> + %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32> + scf.yield %6 : tensor<128xf32> + } {target_loops} + %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) { + // CHECK-DAG: [[SLICE1:%.*]] = vector.transfer_read [[IB1]][[[IV]]], [[ZERO]] + // CHECK: [[OUT2:%.*]] = arith.addf [[SLICE1]], [[ASLICE]] + // CHECK-NEXT: [[WRT1:%.*]] = vector.transfer_write [[OUT2]], [[IB1]][[[IV]]] + %dup2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32> + %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32> + %dup5 = arith.addf %dup3, %dup2 : vector<16xf32> + %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32> + // CHECK: scf.yield [[WRT0]], [[WRT1]] : {{.*}} + scf.yield %dup6 : tensor<128xf32> + } {source_loops} + %2 = scf.for %arg3 = %c0 to %c128 step %c32 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) { + // CHECK-DAG: [[ASLICE:%.*]] = vector.transfer_read [[A]][[[IV]]], [[ZERO]] + // CHECK-DAG: [[SLICE0:%.*]] = vector.transfer_read [[IB0]][[[IV]]], [[ZERO]] + // CHECK: [[OUT1:%.*]] = arith.addf [[SLICE0]], [[ASLICE]] + // CHECK-NEXT: [[WRT0:%.*]] = vector.transfer_write [[OUT1]], [[IB0]][[[IV]]] + %2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<32xf32> + %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<32xf32> + %5 = arith.addf %3, %2 : vector<32xf32> + %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<32xf32>, tensor<128xf32> + scf.yield %6 : tensor<128xf32> + } {target_loops} + %dup2 = scf.for %arg3 = %c0 to %c128 step %c32 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) { + // CHECK-DAG: [[SLICE1:%.*]] = vector.transfer_read [[IB1]][[[IV]]], [[ZERO]] + // CHECK: [[OUT2:%.*]] = arith.addf [[SLICE1]], [[ASLICE]] + // CHECK-NEXT: [[WRT1:%.*]] = vector.transfer_write [[OUT2]], [[IB1]][[[IV]]] + %dup2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<32xf32> + %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<32xf32> + %dup5 = arith.addf %dup3, %dup2 : vector<32xf32> + %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<32xf32>, tensor<128xf32> + // CHECK: scf.yield [[WRT0]], [[WRT1]] : {{.*}} + scf.yield %dup6 : tensor<128xf32> + } {source_loops} + return %1, %dup1, %2, %dup2 : tensor<128xf32>, tensor<128xf32>, tensor<128xf32>, tensor<128xf32> +} + + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %target_loops = transform.structured.match ops{["scf.for"]} attributes {target_loops} in %arg0 : (!transform.any_op) -> !transform.any_op + %source_loops = transform.structured.match ops{["scf.for"]} attributes {source_loops} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.foreach %target_loops, %source_loops : !transform.any_op, !transform.any_op { + ^bb0(%target_loop: !transform.any_op, %source_loop: !transform.any_op): + %fused = transform.loop.fuse_sibling %target_loop into %source_loop : (!transform.any_op,!transform.any_op) -> !transform.any_op + } + transform.yield + } +} diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir index 30a68cc..71a260f 100644 --- a/mlir/test/Dialect/Transform/ops-invalid.mlir +++ b/mlir/test/Dialect/Transform/ops-invalid.mlir @@ -279,6 +279,55 @@ transform.sequence failures(propagate) { // ----- +transform.sequence failures(propagate) { + ^bb0(%root: !transform.any_op): + %op = test_produce_self_handle_or_forward_operand : () -> !transform.any_op + // expected-error @below {{op expects the same number of targets as the body has block arguments}} + transform.foreach %op : !transform.any_op -> !transform.any_op, !transform.any_value { + ^bb1(%op_arg: !transform.any_op, %val_arg: !transform.any_value): + transform.yield %op_arg, %val_arg : !transform.any_op, !transform.any_value + } +} + +// ----- + +transform.sequence failures(propagate) { + ^bb0(%root: !transform.any_op): + %op = test_produce_self_handle_or_forward_operand : () -> !transform.any_op + // expected-error @below {{op expects co-indexed targets and the body's block arguments to have the same op/value/param type}} + transform.foreach %op : !transform.any_op -> !transform.any_value { + ^bb1(%val_arg: !transform.any_value): + transform.yield %val_arg : !transform.any_value + } +} + +// ----- + +transform.sequence failures(propagate) { + ^bb0(%root: !transform.any_op): + %op = test_produce_self_handle_or_forward_operand : () -> !transform.any_op + // expected-error @below {{op expects the same number of results as the yield terminator has operands}} + transform.foreach %op : !transform.any_op -> !transform.any_op, !transform.any_op { + ^bb1(%arg_op: !transform.any_op): + transform.yield %arg_op : !transform.any_op + } +} + +// ----- + +transform.sequence failures(propagate) { + ^bb0(%root: !transform.any_op): + %op = test_produce_self_handle_or_forward_operand : () -> !transform.any_op + %val = transform.test_produce_value_handle_to_self_operand %op : (!transform.any_op) -> !transform.any_value + // expected-error @below {{expects co-indexed results and yield operands to have the same op/value/param type}} + transform.foreach %op, %val : !transform.any_op, !transform.any_value -> !transform.any_op, !transform.any_value { + ^bb1(%op_arg: !transform.any_op, %val_arg: !transform.any_value): + transform.yield %val_arg, %op_arg : !transform.any_value, !transform.any_op + } +} + +// ----- + transform.sequence failures(suppress) { ^bb0(%arg0: !transform.any_op): // expected-error @below {{TransformOpInterface requires memory effects on operands to be specified}} diff --git a/mlir/test/Dialect/Transform/ops.mlir b/mlir/test/Dialect/Transform/ops.mlir index b03a9f4..e9baffd 100644 --- a/mlir/test/Dialect/Transform/ops.mlir +++ b/mlir/test/Dialect/Transform/ops.mlir @@ -68,11 +68,25 @@ transform.sequence failures(propagate) { } // CHECK: transform.sequence -// CHECK: foreach transform.sequence failures(propagate) { -^bb0(%arg0: !transform.any_op): - transform.foreach %arg0 : !transform.any_op { - ^bb1(%arg1: !transform.any_op): +^bb0(%op0: !transform.any_op, %val0: !transform.any_value, %par0: !transform.any_param): + // CHECK: foreach %{{.*}} : !transform.any_op + transform.foreach %op0 : !transform.any_op { + ^bb1(%op1: !transform.any_op): + } + // CHECK: foreach %{{.*}} : !transform.any_op, !transform.any_value, !transform.any_param + transform.foreach %op0, %val0, %par0 : !transform.any_op, !transform.any_value, !transform.any_param { + ^bb1(%op1: !transform.any_op, %val1: !transform.any_value, %par1: !transform.any_param): + } + // CHECK: foreach %{{.*}} : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_op + transform.foreach %op0, %val0, %par0 : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_op { + ^bb1(%op1: !transform.any_op, %val1: !transform.any_value, %par1: !transform.any_param): + transform.yield %op1 : !transform.any_op + } + // CHECK: foreach %{{.*}} : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_param, !transform.any_value + transform.foreach %op0, %val0, %par0 : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_param, !transform.any_value { + ^bb1(%op1: !transform.any_op, %val1: !transform.any_value, %par1: !transform.any_param): + transform.yield %par1, %val1 : !transform.any_param, !transform.any_value } } diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir index b6850e2..4fe2dbe 100644 --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -830,6 +830,91 @@ module attributes {transform.with_named_sequence} { // ----- +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %results, %types = transform.foreach %0 : !transform.any_op -> !transform.any_value, !transform.any_param { + ^bb0(%op0 : !transform.any_op): + %result = transform.get_result %op0[0] : (!transform.any_op) -> !transform.any_value + %type = transform.get_type elemental %result : (!transform.any_value) -> !transform.any_param + transform.yield %result, %type : !transform.any_value, !transform.any_param + } + transform.debug.emit_remark_at %results, "result selected" : !transform.any_value + transform.debug.emit_param_as_remark %types, "elemental types" at %0 : !transform.any_param, !transform.any_op + + transform.yield + } +} + +func.func @payload(%lhs: tensor<10x20xf16>, + %rhs: tensor<20x15xf32>) -> (tensor<10x15xf64>, tensor<10x15xf32>) { + %cst64 = arith.constant 0.0 : f64 + %empty64 = tensor.empty() : tensor<10x15xf64> + %fill64 = linalg.fill ins(%cst64 : f64) outs(%empty64 : tensor<10x15xf64>) -> tensor<10x15xf64> + // expected-remark @below {{result selected}} + // expected-note @below {{value handle points to an op result #0}} + // expected-remark @below {{elemental types f64, f32}} + %result64 = linalg.matmul ins(%lhs, %rhs: tensor<10x20xf16>, tensor<20x15xf32>) + outs(%fill64: tensor<10x15xf64>) -> tensor<10x15xf64> + + %cst32 = arith.constant 0.0 : f32 + %empty32 = tensor.empty() : tensor<10x15xf32> + %fill32 = linalg.fill ins(%cst32 : f32) outs(%empty32 : tensor<10x15xf32>) -> tensor<10x15xf32> + // expected-remark @below {{result selected}} + // expected-note @below {{value handle points to an op result #0}} + // expected-remark @below {{elemental types f64, f32}} + %result32 = linalg.matmul ins(%lhs, %rhs: tensor<10x20xf16>, tensor<20x15xf32>) + outs(%fill32: tensor<10x15xf32>) -> tensor<10x15xf32> + + return %result64, %result32 : tensor<10x15xf64>, tensor<10x15xf32> + +} + +// ----- + +func.func @two_const_ops() { + %0 = arith.constant 0 : index + %1 = arith.constant 1 : index + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op) { + %two_ops = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %one_param = transform.param.constant 1 : i32 -> !transform.test_dialect_param + // expected-error @below {{prior targets' payload size (2) differs from payload size (1) of target}} + transform.foreach %two_ops, %one_param : !transform.any_op, !transform.test_dialect_param { + ^bb2(%op: !transform.any_op, %param: !transform.test_dialect_param): + } + transform.yield + } +} + +// ----- + +func.func @one_const_op() { + %0 = arith.constant 0 : index + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op) { + %one_op = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %one_val = transform.test_produce_value_handle_to_self_operand %one_op : (!transform.any_op) -> !transform.any_value + %param_one = transform.param.constant 1 : i32 -> !transform.test_dialect_param + %param_two = transform.param.constant 2 : i32 -> !transform.test_dialect_param + %two_params = transform.merge_handles %param_one, %param_two : !transform.test_dialect_param + + // expected-error @below {{prior targets' payload size (1) differs from payload size (2) of target}} + transform.foreach %one_val, %one_op, %two_params : !transform.any_value, !transform.any_op, !transform.test_dialect_param { + ^bb2(%val: !transform.any_value, %op: !transform.any_op, %param: !transform.test_dialect_param): + } + transform.yield + } +} + +// ----- + // CHECK-LABEL: func @consume_in_foreach() // CHECK-NEXT: return func.func @consume_in_foreach() { |