aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/Dialect/Transform/IR/TransformOps.td57
-rw-r--r--mlir/lib/Dialect/Transform/IR/TransformOps.cpp120
-rw-r--r--mlir/test/Dialect/Linalg/multisize-tiling-full.mlir21
-rw-r--r--mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir73
-rw-r--r--mlir/test/Dialect/Transform/ops-invalid.mlir49
-rw-r--r--mlir/test/Dialect/Transform/ops.mlir22
-rw-r--r--mlir/test/Dialect/Transform/test-interpreter.mlir85
7 files changed, 354 insertions, 73 deletions
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 77048a2..3bb297c 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -614,43 +614,48 @@ def ForeachOp : TransformDialectOp<"foreach",
"getSuccessorRegions", "getEntrySuccessorOperands"]>,
SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">
]> {
- let summary = "Executes the body for each payload op";
+ let summary = "Executes the body for each element of the payload";
let description = [{
- This op has exactly one region with exactly one block ("body"). The body is
- executed for each payload op that is associated to the target operand in an
- unbatched fashion. I.e., the block argument ("iteration variable") is always
- mapped to exactly one payload op.
-
- This op always reads the target handle. Furthermore, it consumes the handle
- if there is a transform op in the body that consumes the iteration variable.
- This op does not return anything.
-
- The transformations inside the body are applied in order of their
- appearance. During application, if any transformation in the sequence fails,
- the entire sequence fails immediately leaving the payload IR in potentially
- invalid state, i.e., this operation offers no transformation rollback
- capabilities.
-
- This op generates as many handles as the terminating YieldOp has operands.
- For each result, the payload ops of the corresponding YieldOp operand are
- merged and mapped to the same resulting handle.
+ Execute the op's body - its single region block - exactly once per
+ element of the payload associated to a target handle. The body's
+ transformations are applied in order of appearance until reaching the
+ (implicit) YieldOp terminator.
+
+ Each iteration gets executed by co-indexing the payloads of the arguments
+ and mapping the body's arguments to these tuples, as though iterating over
+ the zipped together `targets`. As such, in each iteration, the size of the
+ payload of each of the body's block arguments is exactly one.
+
+ This op always reads the target handles. Furthermore, it consumes a handle
+ if there is a transform op in the body that consumes the corresponding
+ block argument. Handles can point to ops, values, or parameters.
+
+ #### Return Modes
+
+ This op produces as many result handles as the body's terminating YieldOp
+ has operands. For each result, the payloads of the corresponding YieldOp
+ operand are merged and mapped to the same resulting handle.
+
+ If the target handles do not associate payloads of the same size, a
+ silencable failure will be generated.
+
+ During application, if any transformation in the sequence fails, the entire
+ sequence fails immediately with the same failure, leaving the payload IR in
+ a potentially invalid state, i.e., this operation offers no transformation
+ rollback capabilities.
}];
- let arguments = (ins TransformHandleTypeInterface:$target);
- let results = (outs Variadic<TransformHandleTypeInterface>:$results);
+ let arguments = (ins Variadic<Transform_AnyHandleOrParamType>:$targets);
+ let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
let regions = (region SizedRegion<1>:$body);
let assemblyFormat =
- "$target `:` type($target) (`->` type($results)^)? $body attr-dict";
+ "$targets `:` type($targets) (`->` type($results)^)? $body attr-dict";
let hasVerifier = 1;
let extraClassDeclaration = [{
/// Allow the dialect prefix to be omitted.
static StringRef getDefaultDialect() { return "transform"; }
- BlockArgument getIterationVariable() {
- return getBody().front().getArgument(0);
- }
-
transform::YieldOp getYieldOp();
}];
}
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 247759e..1a7ec03 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1391,46 +1391,83 @@ DiagnosedSilenceableFailure
transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
- SmallVector<SmallVector<Operation *>> resultOps(getNumResults(), {});
- // Store payload ops in a vector because ops may be removed from the mapping
- // by the TrackingRewriter while the iteration is in progress.
- SmallVector<Operation *> targets =
- llvm::to_vector(state.getPayloadOps(getTarget()));
- for (Operation *op : targets) {
+ // We store the payloads before executing the body as ops may be removed from
+ // the mapping by the TrackingRewriter while iteration is in progress.
+ SmallVector<SmallVector<MappedValue>> payloads;
+ detail::prepareValueMappings(payloads, getTargets(), state);
+ size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
+
+ // As we will be "zipping" over them, check all payloads have the same size.
+ for (size_t argIdx = 1; argIdx < payloads.size(); argIdx++) {
+ if (payloads[argIdx].size() != numIterations) {
+ return emitSilenceableError()
+ << "prior targets' payload size (" << numIterations
+ << ") differs from payload size (" << payloads[argIdx].size()
+ << ") of target " << getTargets()[argIdx];
+ }
+ }
+
+ // Start iterating, indexing into payloads to obtain the right arguments to
+ // call the body with - each slice of payloads at the same argument index
+ // corresponding to a tuple to use as the body's block arguments.
+ ArrayRef<BlockArgument> blockArguments = getBody().front().getArguments();
+ SmallVector<SmallVector<MappedValue>> zippedResults(getNumResults(), {});
+ for (size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) {
auto scope = state.make_region_scope(getBody());
- if (failed(state.mapBlockArguments(getIterationVariable(), {op})))
- return DiagnosedSilenceableFailure::definiteFailure();
+ // Set up arguments to the region's block.
+ for (auto &&[argIdx, blockArg] : llvm::enumerate(blockArguments)) {
+ MappedValue argument = payloads[argIdx][iterIdx];
+ // Note that each blockArg's handle gets associated with just a single
+ // element from the corresponding target's payload.
+ if (failed(state.mapBlockArgument(blockArg, {argument})))
+ return DiagnosedSilenceableFailure::definiteFailure();
+ }
// Execute loop body.
for (Operation &transform : getBody().front().without_terminator()) {
DiagnosedSilenceableFailure result = state.applyTransform(
- cast<transform::TransformOpInterface>(transform));
+ llvm::cast<transform::TransformOpInterface>(transform));
if (!result.succeeded())
return result;
}
- // Append yielded payload ops to result list (if any).
- for (unsigned i = 0; i < getNumResults(); ++i) {
- auto yieldedOps = state.getPayloadOps(getYieldOp().getOperand(i));
- resultOps[i].append(yieldedOps.begin(), yieldedOps.end());
- }
- }
-
- for (unsigned i = 0; i < getNumResults(); ++i)
- results.set(llvm::cast<OpResult>(getResult(i)), resultOps[i]);
+ // Append yielded payloads to corresponding results from prior iterations.
+ OperandRange yieldOperands = getYieldOp().getOperands();
+ for (auto &&[result, yieldOperand, resTuple] :
+ llvm::zip_equal(getResults(), yieldOperands, zippedResults))
+ // NB: each iteration we add any number of ops/vals/params to a result.
+ if (isa<TransformHandleTypeInterface>(result.getType()))
+ llvm::append_range(resTuple, state.getPayloadOps(yieldOperand));
+ else if (isa<TransformValueHandleTypeInterface>(result.getType()))
+ llvm::append_range(resTuple, state.getPayloadValues(yieldOperand));
+ else if (isa<TransformParamTypeInterface>(result.getType()))
+ llvm::append_range(resTuple, state.getParams(yieldOperand));
+ else
+ assert(false && "unhandled handle type");
+ }
+
+ // Associate the accumulated result payloads to the op's actual results.
+ for (auto &&[result, resPayload] : zip_equal(getResults(), zippedResults))
+ results.setMappedValues(llvm::cast<OpResult>(result), resPayload);
return DiagnosedSilenceableFailure::success();
}
void transform::ForeachOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- BlockArgument iterVar = getIterationVariable();
- if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
- return isHandleConsumed(iterVar, cast<TransformOpInterface>(&op));
- })) {
- consumesHandle(getTarget(), effects);
- } else {
- onlyReadsHandle(getTarget(), effects);
+ // NB: this `zip` should be `zip_equal` - while this op's verifier catches
+ // arity errors, this method might get called before/in absence of `verify()`.
+ for (auto &&[target, blockArg] :
+ llvm::zip(getTargets(), getBody().front().getArguments())) {
+ BlockArgument blockArgument = blockArg;
+ if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
+ return isHandleConsumed(blockArgument,
+ cast<TransformOpInterface>(&op));
+ })) {
+ consumesHandle(target, effects);
+ } else {
+ onlyReadsHandle(target, effects);
+ }
}
if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
@@ -1463,8 +1500,8 @@ void transform::ForeachOp::getSuccessorRegions(
OperandRange
transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
- // The iteration variable op handle is mapped to a subset (one op to be
- // precise) of the payload ops of the ForeachOp operand.
+ // Each block argument handle is mapped to a subset (one op to be precise)
+ // of the payload of the corresponding `targets` operand of ForeachOp.
assert(point == getBody() && "unexpected region index");
return getOperation()->getOperands();
}
@@ -1474,14 +1511,27 @@ transform::YieldOp transform::ForeachOp::getYieldOp() {
}
LogicalResult transform::ForeachOp::verify() {
- auto yieldOp = getYieldOp();
- if (getNumResults() != yieldOp.getNumOperands())
- return emitOpError() << "expects the same number of results as the "
- "terminator has operands";
- for (Value v : yieldOp.getOperands())
- if (!llvm::isa<TransformHandleTypeInterface>(v.getType()))
- return yieldOp->emitOpError("expects operands to have types implementing "
- "TransformHandleTypeInterface");
+ for (auto [targetOpt, bodyArgOpt] :
+ llvm::zip_longest(getTargets(), getBody().front().getArguments())) {
+ if (!targetOpt || !bodyArgOpt)
+ return emitOpError() << "expects the same number of targets as the body "
+ "has block arguments";
+ if (targetOpt.value().getType() != bodyArgOpt.value().getType())
+ return emitOpError(
+ "expects co-indexed targets and the body's "
+ "block arguments to have the same op/value/param type");
+ }
+
+ for (auto [resultOpt, yieldOperandOpt] :
+ llvm::zip_longest(getResults(), getYieldOp().getOperands())) {
+ if (!resultOpt || !yieldOperandOpt)
+ return emitOpError() << "expects the same number of results as the "
+ "yield terminator has operands";
+ if (resultOpt.value().getType() != yieldOperandOpt.value().getType())
+ return emitOpError("expects co-indexed results and yield "
+ "operands to have the same op/value/param type");
+ }
+
return success();
}
diff --git a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
index 15b24b5..51332ff 100644
--- a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
+++ b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
@@ -6,15 +6,17 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1:3 = transform.structured.multitile_sizes %0 { dimension = 0, target_size = 3} : (!transform.any_op) -> !transform.any_op
- %t:3 = transform.structured.multitile_sizes %0 { dimension = 1, target_size = 10} : (!transform.any_op) -> !transform.any_op
%2:2 = transform.structured.split %0 after %1#2 { dimension = 0 } : !transform.any_op, !transform.any_op
%3:2 = transform.structured.tile_using_for %2#0 tile_sizes [%1#0] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
%4:2 = transform.structured.tile_using_for %2#1 tile_sizes [%1#1] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
%5 = transform.merge_handles %3#0, %4#0 : !transform.any_op
- %tt:3 = transform.replicate num(%5) %t#0, %t#1, %t#2 : !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op
- %6:2 = transform.structured.split %5 after %tt#2 { dimension = 1 } : !transform.any_op, !transform.any_op
- transform.structured.tile_using_for %6#0 tile_sizes [0, %tt#0] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
- transform.structured.tile_using_for %6#1 tile_sizes [0, %tt#1] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.foreach %5 : !transform.any_op {
+ ^bb0(%inner_linalg: !transform.any_op):
+ %low, %high, %split_point = transform.structured.multitile_sizes %inner_linalg { dimension = 1, target_size = 10} : (!transform.any_op) -> !transform.any_op
+ %inner_linalg_low, %inner_linalg_high = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.any_op
+ transform.structured.tile_using_for %inner_linalg_low tile_sizes [0, %low] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.structured.tile_using_for %inner_linalg_high tile_sizes [0, %high] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ }
transform.yield
}
}
@@ -114,9 +116,12 @@ module attributes {transform.with_named_sequence} {
%4:2 = transform.structured.tile_using_for %2#1 tile_sizes [%1#1] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
%5 = transform.merge_handles %3#0, %4#0 : !transform.any_op
%tt:3 = transform.replicate num(%5) %t#0, %t#1, %t#2 : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>
- %6:2 = transform.structured.split %5 after %tt#2 { dimension = 1 } : !transform.any_op, !transform.param<i64>
- transform.structured.tile_using_for %6#0 tile_sizes [0, %tt#0] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
- transform.structured.tile_using_for %6#1 tile_sizes [0, %tt#1] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
+ transform.foreach %5, %tt#0, %tt#1, %tt#2 : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64> {
+ ^bb0(%inner_linalg: !transform.any_op, %low: !transform.param<i64>, %high: !transform.param<i64>, %split_point: !transform.param<i64>):
+ %inner_linalg_low, %inner_linalg_high = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.param<i64>
+ transform.structured.tile_using_for %inner_linalg_low tile_sizes [0, %low] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
+ transform.structured.tile_using_for %inner_linalg_high tile_sizes [0, %high] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
+ }
transform.yield
}
}
diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
index 0f51b1c..54dd2bd 100644
--- a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
+++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
@@ -328,3 +328,76 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+// -----
+
+// CHECK: func.func @foreach_loop_pair_fuse([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
+func.func @foreach_loop_pair_fuse(%arg1: tensor<128xf32>, %arg2: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>, tensor<128xf32>, tensor<128xf32>) {
+ // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+ // CHECK-DAG: [[C16:%.*]] = arith.constant 16 : index
+ // CHECK-DAG: [[C128:%.*]] = arith.constant 128 : index
+ // CHECK-DAG: [[ZERO:%.*]] = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %c128 = arith.constant 128 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ // CHECK: [[RST:%.*]]:2 = scf.for [[IV:%.*]] = [[C0]] to [[C128]] step [[C16]] iter_args([[IB0:%.*]] = [[B]], [[IB1:%.*]] = [[B]]) {{.*}}
+ %1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) {
+ // CHECK-DAG: [[ASLICE:%.*]] = vector.transfer_read [[A]][[[IV]]], [[ZERO]]
+ // CHECK-DAG: [[SLICE0:%.*]] = vector.transfer_read [[IB0]][[[IV]]], [[ZERO]]
+ // CHECK: [[OUT1:%.*]] = arith.addf [[SLICE0]], [[ASLICE]]
+ // CHECK-NEXT: [[WRT0:%.*]] = vector.transfer_write [[OUT1]], [[IB0]][[[IV]]]
+ %2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+ %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+ %5 = arith.addf %3, %2 : vector<16xf32>
+ %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+ scf.yield %6 : tensor<128xf32>
+ } {target_loops}
+ %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) {
+ // CHECK-DAG: [[SLICE1:%.*]] = vector.transfer_read [[IB1]][[[IV]]], [[ZERO]]
+ // CHECK: [[OUT2:%.*]] = arith.addf [[SLICE1]], [[ASLICE]]
+ // CHECK-NEXT: [[WRT1:%.*]] = vector.transfer_write [[OUT2]], [[IB1]][[[IV]]]
+ %dup2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+ %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+ %dup5 = arith.addf %dup3, %dup2 : vector<16xf32>
+ %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+ // CHECK: scf.yield [[WRT0]], [[WRT1]] : {{.*}}
+ scf.yield %dup6 : tensor<128xf32>
+ } {source_loops}
+ %2 = scf.for %arg3 = %c0 to %c128 step %c32 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) {
+ // CHECK-DAG: [[ASLICE:%.*]] = vector.transfer_read [[A]][[[IV]]], [[ZERO]]
+ // CHECK-DAG: [[SLICE0:%.*]] = vector.transfer_read [[IB0]][[[IV]]], [[ZERO]]
+ // CHECK: [[OUT1:%.*]] = arith.addf [[SLICE0]], [[ASLICE]]
+ // CHECK-NEXT: [[WRT0:%.*]] = vector.transfer_write [[OUT1]], [[IB0]][[[IV]]]
+ %2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<32xf32>
+ %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<32xf32>
+ %5 = arith.addf %3, %2 : vector<32xf32>
+ %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<32xf32>, tensor<128xf32>
+ scf.yield %6 : tensor<128xf32>
+ } {target_loops}
+ %dup2 = scf.for %arg3 = %c0 to %c128 step %c32 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) {
+ // CHECK-DAG: [[SLICE1:%.*]] = vector.transfer_read [[IB1]][[[IV]]], [[ZERO]]
+ // CHECK: [[OUT2:%.*]] = arith.addf [[SLICE1]], [[ASLICE]]
+ // CHECK-NEXT: [[WRT1:%.*]] = vector.transfer_write [[OUT2]], [[IB1]][[[IV]]]
+ %dup2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<32xf32>
+ %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<32xf32>
+ %dup5 = arith.addf %dup3, %dup2 : vector<32xf32>
+ %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<32xf32>, tensor<128xf32>
+ // CHECK: scf.yield [[WRT0]], [[WRT1]] : {{.*}}
+ scf.yield %dup6 : tensor<128xf32>
+ } {source_loops}
+ return %1, %dup1, %2, %dup2 : tensor<128xf32>, tensor<128xf32>, tensor<128xf32>, tensor<128xf32>
+}
+
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %target_loops = transform.structured.match ops{["scf.for"]} attributes {target_loops} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %source_loops = transform.structured.match ops{["scf.for"]} attributes {source_loops} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.foreach %target_loops, %source_loops : !transform.any_op, !transform.any_op {
+ ^bb0(%target_loop: !transform.any_op, %source_loop: !transform.any_op):
+ %fused = transform.loop.fuse_sibling %target_loop into %source_loop : (!transform.any_op,!transform.any_op) -> !transform.any_op
+ }
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
index 30a68cc..71a260f 100644
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -279,6 +279,55 @@ transform.sequence failures(propagate) {
// -----
+transform.sequence failures(propagate) {
+ ^bb0(%root: !transform.any_op):
+ %op = test_produce_self_handle_or_forward_operand : () -> !transform.any_op
+ // expected-error @below {{op expects the same number of targets as the body has block arguments}}
+ transform.foreach %op : !transform.any_op -> !transform.any_op, !transform.any_value {
+ ^bb1(%op_arg: !transform.any_op, %val_arg: !transform.any_value):
+ transform.yield %op_arg, %val_arg : !transform.any_op, !transform.any_value
+ }
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+ ^bb0(%root: !transform.any_op):
+ %op = test_produce_self_handle_or_forward_operand : () -> !transform.any_op
+ // expected-error @below {{op expects co-indexed targets and the body's block arguments to have the same op/value/param type}}
+ transform.foreach %op : !transform.any_op -> !transform.any_value {
+ ^bb1(%val_arg: !transform.any_value):
+ transform.yield %val_arg : !transform.any_value
+ }
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+ ^bb0(%root: !transform.any_op):
+ %op = test_produce_self_handle_or_forward_operand : () -> !transform.any_op
+ // expected-error @below {{op expects the same number of results as the yield terminator has operands}}
+ transform.foreach %op : !transform.any_op -> !transform.any_op, !transform.any_op {
+ ^bb1(%arg_op: !transform.any_op):
+ transform.yield %arg_op : !transform.any_op
+ }
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+ ^bb0(%root: !transform.any_op):
+ %op = test_produce_self_handle_or_forward_operand : () -> !transform.any_op
+ %val = transform.test_produce_value_handle_to_self_operand %op : (!transform.any_op) -> !transform.any_value
+ // expected-error @below {{expects co-indexed results and yield operands to have the same op/value/param type}}
+ transform.foreach %op, %val : !transform.any_op, !transform.any_value -> !transform.any_op, !transform.any_value {
+ ^bb1(%op_arg: !transform.any_op, %val_arg: !transform.any_value):
+ transform.yield %val_arg, %op_arg : !transform.any_value, !transform.any_op
+ }
+}
+
+// -----
+
transform.sequence failures(suppress) {
^bb0(%arg0: !transform.any_op):
// expected-error @below {{TransformOpInterface requires memory effects on operands to be specified}}
diff --git a/mlir/test/Dialect/Transform/ops.mlir b/mlir/test/Dialect/Transform/ops.mlir
index b03a9f4..e9baffd 100644
--- a/mlir/test/Dialect/Transform/ops.mlir
+++ b/mlir/test/Dialect/Transform/ops.mlir
@@ -68,11 +68,25 @@ transform.sequence failures(propagate) {
}
// CHECK: transform.sequence
-// CHECK: foreach
transform.sequence failures(propagate) {
-^bb0(%arg0: !transform.any_op):
- transform.foreach %arg0 : !transform.any_op {
- ^bb1(%arg1: !transform.any_op):
+^bb0(%op0: !transform.any_op, %val0: !transform.any_value, %par0: !transform.any_param):
+ // CHECK: foreach %{{.*}} : !transform.any_op
+ transform.foreach %op0 : !transform.any_op {
+ ^bb1(%op1: !transform.any_op):
+ }
+ // CHECK: foreach %{{.*}} : !transform.any_op, !transform.any_value, !transform.any_param
+ transform.foreach %op0, %val0, %par0 : !transform.any_op, !transform.any_value, !transform.any_param {
+ ^bb1(%op1: !transform.any_op, %val1: !transform.any_value, %par1: !transform.any_param):
+ }
+ // CHECK: foreach %{{.*}} : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_op
+ transform.foreach %op0, %val0, %par0 : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_op {
+ ^bb1(%op1: !transform.any_op, %val1: !transform.any_value, %par1: !transform.any_param):
+ transform.yield %op1 : !transform.any_op
+ }
+ // CHECK: foreach %{{.*}} : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_param, !transform.any_value
+ transform.foreach %op0, %val0, %par0 : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_param, !transform.any_value {
+ ^bb1(%op1: !transform.any_op, %val1: !transform.any_value, %par1: !transform.any_param):
+ transform.yield %par1, %val1 : !transform.any_param, !transform.any_value
}
}
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index b6850e2..4fe2dbe 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -830,6 +830,91 @@ module attributes {transform.with_named_sequence} {
// -----
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %results, %types = transform.foreach %0 : !transform.any_op -> !transform.any_value, !transform.any_param {
+ ^bb0(%op0 : !transform.any_op):
+ %result = transform.get_result %op0[0] : (!transform.any_op) -> !transform.any_value
+ %type = transform.get_type elemental %result : (!transform.any_value) -> !transform.any_param
+ transform.yield %result, %type : !transform.any_value, !transform.any_param
+ }
+ transform.debug.emit_remark_at %results, "result selected" : !transform.any_value
+ transform.debug.emit_param_as_remark %types, "elemental types" at %0 : !transform.any_param, !transform.any_op
+
+ transform.yield
+ }
+}
+
+func.func @payload(%lhs: tensor<10x20xf16>,
+ %rhs: tensor<20x15xf32>) -> (tensor<10x15xf64>, tensor<10x15xf32>) {
+ %cst64 = arith.constant 0.0 : f64
+ %empty64 = tensor.empty() : tensor<10x15xf64>
+ %fill64 = linalg.fill ins(%cst64 : f64) outs(%empty64 : tensor<10x15xf64>) -> tensor<10x15xf64>
+ // expected-remark @below {{result selected}}
+ // expected-note @below {{value handle points to an op result #0}}
+ // expected-remark @below {{elemental types f64, f32}}
+ %result64 = linalg.matmul ins(%lhs, %rhs: tensor<10x20xf16>, tensor<20x15xf32>)
+ outs(%fill64: tensor<10x15xf64>) -> tensor<10x15xf64>
+
+ %cst32 = arith.constant 0.0 : f32
+ %empty32 = tensor.empty() : tensor<10x15xf32>
+ %fill32 = linalg.fill ins(%cst32 : f32) outs(%empty32 : tensor<10x15xf32>) -> tensor<10x15xf32>
+ // expected-remark @below {{result selected}}
+ // expected-note @below {{value handle points to an op result #0}}
+ // expected-remark @below {{elemental types f64, f32}}
+ %result32 = linalg.matmul ins(%lhs, %rhs: tensor<10x20xf16>, tensor<20x15xf32>)
+ outs(%fill32: tensor<10x15xf32>) -> tensor<10x15xf32>
+
+ return %result64, %result32 : tensor<10x15xf64>, tensor<10x15xf32>
+
+}
+
+// -----
+
+func.func @two_const_ops() {
+ %0 = arith.constant 0 : index
+ %1 = arith.constant 1 : index
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+ %two_ops = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %one_param = transform.param.constant 1 : i32 -> !transform.test_dialect_param
+ // expected-error @below {{prior targets' payload size (2) differs from payload size (1) of target}}
+ transform.foreach %two_ops, %one_param : !transform.any_op, !transform.test_dialect_param {
+ ^bb2(%op: !transform.any_op, %param: !transform.test_dialect_param):
+ }
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @one_const_op() {
+ %0 = arith.constant 0 : index
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+ %one_op = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %one_val = transform.test_produce_value_handle_to_self_operand %one_op : (!transform.any_op) -> !transform.any_value
+ %param_one = transform.param.constant 1 : i32 -> !transform.test_dialect_param
+ %param_two = transform.param.constant 2 : i32 -> !transform.test_dialect_param
+ %two_params = transform.merge_handles %param_one, %param_two : !transform.test_dialect_param
+
+ // expected-error @below {{prior targets' payload size (1) differs from payload size (2) of target}}
+ transform.foreach %one_val, %one_op, %two_params : !transform.any_value, !transform.any_op, !transform.test_dialect_param {
+ ^bb2(%val: !transform.any_value, %op: !transform.any_op, %param: !transform.test_dialect_param):
+ }
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK-LABEL: func @consume_in_foreach()
// CHECK-NEXT: return
func.func @consume_in_foreach() {