diff options
-rw-r--r-- | mlir/include/mlir/Dialect/Transform/IR/TransformOps.td | 57 | ||||
-rw-r--r-- | mlir/lib/Dialect/Transform/IR/TransformOps.cpp | 120 | ||||
-rw-r--r-- | mlir/test/Dialect/Linalg/multisize-tiling-full.mlir | 21 | ||||
-rw-r--r-- | mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir | 73 | ||||
-rw-r--r-- | mlir/test/Dialect/Transform/ops-invalid.mlir | 49 | ||||
-rw-r--r-- | mlir/test/Dialect/Transform/ops.mlir | 22 | ||||
-rw-r--r-- | mlir/test/Dialect/Transform/test-interpreter.mlir | 85 |
7 files changed, 354 insertions, 73 deletions
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td index 77048a2..3bb297c 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -614,43 +614,48 @@ def ForeachOp : TransformDialectOp<"foreach", "getSuccessorRegions", "getEntrySuccessorOperands"]>, SingleBlockImplicitTerminator<"::mlir::transform::YieldOp"> ]> { - let summary = "Executes the body for each payload op"; + let summary = "Executes the body for each element of the payload"; let description = [{ - This op has exactly one region with exactly one block ("body"). The body is - executed for each payload op that is associated to the target operand in an - unbatched fashion. I.e., the block argument ("iteration variable") is always - mapped to exactly one payload op. - - This op always reads the target handle. Furthermore, it consumes the handle - if there is a transform op in the body that consumes the iteration variable. - This op does not return anything. - - The transformations inside the body are applied in order of their - appearance. During application, if any transformation in the sequence fails, - the entire sequence fails immediately leaving the payload IR in potentially - invalid state, i.e., this operation offers no transformation rollback - capabilities. - - This op generates as many handles as the terminating YieldOp has operands. - For each result, the payload ops of the corresponding YieldOp operand are - merged and mapped to the same resulting handle. + Execute the op's body - its single region block - exactly once per + element of the payload associated to a target handle. The body's + transformations are applied in order of appearance until reaching the + (implicit) YieldOp terminator. 
+ + Each iteration gets executed by co-indexing the payloads of the arguments + and mapping the body's arguments to these tuples, as though iterating over + the zipped together `targets`. As such, in each iteration, the size of the + payload of each of the body's block arguments is exactly one. + + This op always reads the target handles. Furthermore, it consumes a handle + if there is a transform op in the body that consumes the corresponding + block argument. Handles can point to ops, values, or parameters. + + #### Return Modes + + This op produces as many result handles as the body's terminating YieldOp + has operands. For each result, the payloads of the corresponding YieldOp + operand are merged and mapped to the same resulting handle. + + If the target handles do not associate payloads of the same size, a + silenceable failure will be generated. + + During application, if any transformation in the sequence fails, the entire + sequence fails immediately with the same failure, leaving the payload IR in + a potentially invalid state, i.e., this operation offers no transformation + rollback capabilities. }]; - let arguments = (ins TransformHandleTypeInterface:$target); - let results = (outs Variadic<TransformHandleTypeInterface>:$results); + let arguments = (ins Variadic<Transform_AnyHandleOrParamType>:$targets); + let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results); let regions = (region SizedRegion<1>:$body); let assemblyFormat = - "$target `:` type($target) (`->` type($results)^)? $body attr-dict"; + "$targets `:` type($targets) (`->` type($results)^)? $body attr-dict"; let hasVerifier = 1; let extraClassDeclaration = [{ /// Allow the dialect prefix to be omitted. 
static StringRef getDefaultDialect() { return "transform"; } - BlockArgument getIterationVariable() { - return getBody().front().getArgument(0); - } - transform::YieldOp getYieldOp(); }]; } diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 247759e..1a7ec03 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -1391,46 +1391,83 @@ DiagnosedSilenceableFailure transform::ForeachOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { - SmallVector<SmallVector<Operation *>> resultOps(getNumResults(), {}); - // Store payload ops in a vector because ops may be removed from the mapping - // by the TrackingRewriter while the iteration is in progress. - SmallVector<Operation *> targets = - llvm::to_vector(state.getPayloadOps(getTarget())); - for (Operation *op : targets) { + // We store the payloads before executing the body as ops may be removed from + // the mapping by the TrackingRewriter while iteration is in progress. + SmallVector<SmallVector<MappedValue>> payloads; + detail::prepareValueMappings(payloads, getTargets(), state); + size_t numIterations = payloads.empty() ? 0 : payloads.front().size(); + + // As we will be "zipping" over them, check all payloads have the same size. + for (size_t argIdx = 1; argIdx < payloads.size(); argIdx++) { + if (payloads[argIdx].size() != numIterations) { + return emitSilenceableError() + << "prior targets' payload size (" << numIterations + << ") differs from payload size (" << payloads[argIdx].size() + << ") of target " << getTargets()[argIdx]; + } + } + + // Start iterating, indexing into payloads to obtain the right arguments to + // call the body with - each slice of payloads at the same argument index + // corresponding to a tuple to use as the body's block arguments. 
+ ArrayRef<BlockArgument> blockArguments = getBody().front().getArguments(); + SmallVector<SmallVector<MappedValue>> zippedResults(getNumResults(), {}); + for (size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) { auto scope = state.make_region_scope(getBody()); - if (failed(state.mapBlockArguments(getIterationVariable(), {op}))) - return DiagnosedSilenceableFailure::definiteFailure(); + // Set up arguments to the region's block. + for (auto &&[argIdx, blockArg] : llvm::enumerate(blockArguments)) { + MappedValue argument = payloads[argIdx][iterIdx]; + // Note that each blockArg's handle gets associated with just a single + // element from the corresponding target's payload. + if (failed(state.mapBlockArgument(blockArg, {argument}))) + return DiagnosedSilenceableFailure::definiteFailure(); + } // Execute loop body. for (Operation &transform : getBody().front().without_terminator()) { DiagnosedSilenceableFailure result = state.applyTransform( - cast<transform::TransformOpInterface>(transform)); + llvm::cast<transform::TransformOpInterface>(transform)); if (!result.succeeded()) return result; } - // Append yielded payload ops to result list (if any). - for (unsigned i = 0; i < getNumResults(); ++i) { - auto yieldedOps = state.getPayloadOps(getYieldOp().getOperand(i)); - resultOps[i].append(yieldedOps.begin(), yieldedOps.end()); - } - } - - for (unsigned i = 0; i < getNumResults(); ++i) - results.set(llvm::cast<OpResult>(getResult(i)), resultOps[i]); + // Append yielded payloads to corresponding results from prior iterations. + OperandRange yieldOperands = getYieldOp().getOperands(); + for (auto &&[result, yieldOperand, resTuple] : + llvm::zip_equal(getResults(), yieldOperands, zippedResults)) + // NB: each iteration we add any number of ops/vals/params to a result. 
+ if (isa<TransformHandleTypeInterface>(result.getType())) + llvm::append_range(resTuple, state.getPayloadOps(yieldOperand)); + else if (isa<TransformValueHandleTypeInterface>(result.getType())) + llvm::append_range(resTuple, state.getPayloadValues(yieldOperand)); + else if (isa<TransformParamTypeInterface>(result.getType())) + llvm::append_range(resTuple, state.getParams(yieldOperand)); + else + assert(false && "unhandled handle type"); + } + + // Associate the accumulated result payloads to the op's actual results. + for (auto &&[result, resPayload] : zip_equal(getResults(), zippedResults)) + results.setMappedValues(llvm::cast<OpResult>(result), resPayload); return DiagnosedSilenceableFailure::success(); } void transform::ForeachOp::getEffects( SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { - BlockArgument iterVar = getIterationVariable(); - if (any_of(getBody().front().without_terminator(), [&](Operation &op) { - return isHandleConsumed(iterVar, cast<TransformOpInterface>(&op)); - })) { - consumesHandle(getTarget(), effects); - } else { - onlyReadsHandle(getTarget(), effects); + // NB: this `zip` should be `zip_equal` - while this op's verifier catches + // arity errors, this method might get called before/in absence of `verify()`. 
+ for (auto &&[target, blockArg] : + llvm::zip(getTargets(), getBody().front().getArguments())) { + BlockArgument blockArgument = blockArg; + if (any_of(getBody().front().without_terminator(), [&](Operation &op) { + return isHandleConsumed(blockArgument, + cast<TransformOpInterface>(&op)); + })) { + consumesHandle(target, effects); + } else { + onlyReadsHandle(target, effects); + } } if (any_of(getBody().front().without_terminator(), [&](Operation &op) { @@ -1463,8 +1500,8 @@ void transform::ForeachOp::getSuccessorRegions( OperandRange transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) { - // The iteration variable op handle is mapped to a subset (one op to be - // precise) of the payload ops of the ForeachOp operand. + // Each block argument handle is mapped to a subset (one op to be precise) + // of the payload of the corresponding `targets` operand of ForeachOp. assert(point == getBody() && "unexpected region index"); return getOperation()->getOperands(); } @@ -1474,14 +1511,27 @@ transform::YieldOp transform::ForeachOp::getYieldOp() { } LogicalResult transform::ForeachOp::verify() { - auto yieldOp = getYieldOp(); - if (getNumResults() != yieldOp.getNumOperands()) - return emitOpError() << "expects the same number of results as the " - "terminator has operands"; - for (Value v : yieldOp.getOperands()) - if (!llvm::isa<TransformHandleTypeInterface>(v.getType())) - return yieldOp->emitOpError("expects operands to have types implementing " - "TransformHandleTypeInterface"); + for (auto [targetOpt, bodyArgOpt] : + llvm::zip_longest(getTargets(), getBody().front().getArguments())) { + if (!targetOpt || !bodyArgOpt) + return emitOpError() << "expects the same number of targets as the body " + "has block arguments"; + if (targetOpt.value().getType() != bodyArgOpt.value().getType()) + return emitOpError( + "expects co-indexed targets and the body's " + "block arguments to have the same op/value/param type"); + } + + for (auto [resultOpt, 
yieldOperandOpt] : + llvm::zip_longest(getResults(), getYieldOp().getOperands())) { + if (!resultOpt || !yieldOperandOpt) + return emitOpError() << "expects the same number of results as the " + "yield terminator has operands"; + if (resultOpt.value().getType() != yieldOperandOpt.value().getType()) + return emitOpError("expects co-indexed results and yield " + "operands to have the same op/value/param type"); + } + return success(); } diff --git a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir index 15b24b5..51332ff 100644 --- a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir +++ b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir @@ -6,15 +6,17 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1:3 = transform.structured.multitile_sizes %0 { dimension = 0, target_size = 3} : (!transform.any_op) -> !transform.any_op - %t:3 = transform.structured.multitile_sizes %0 { dimension = 1, target_size = 10} : (!transform.any_op) -> !transform.any_op %2:2 = transform.structured.split %0 after %1#2 { dimension = 0 } : !transform.any_op, !transform.any_op %3:2 = transform.structured.tile_using_for %2#0 tile_sizes [%1#0] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) %4:2 = transform.structured.tile_using_for %2#1 tile_sizes [%1#1] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) %5 = transform.merge_handles %3#0, %4#0 : !transform.any_op - %tt:3 = transform.replicate num(%5) %t#0, %t#1, %t#2 : !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op - %6:2 = transform.structured.split %5 after %tt#2 { dimension = 1 } : !transform.any_op, !transform.any_op - transform.structured.tile_using_for %6#0 tile_sizes [0, %tt#0] : 
(!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.structured.tile_using_for %6#1 tile_sizes [0, %tt#1] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.foreach %5 : !transform.any_op { + ^bb0(%inner_linalg: !transform.any_op): + %low, %high, %split_point = transform.structured.multitile_sizes %inner_linalg { dimension = 1, target_size = 10} : (!transform.any_op) -> !transform.any_op + %inner_linalg_low, %inner_linalg_high = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.any_op + transform.structured.tile_using_for %inner_linalg_low tile_sizes [0, %low] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.structured.tile_using_for %inner_linalg_high tile_sizes [0, %high] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + } transform.yield } } @@ -114,9 +116,12 @@ module attributes {transform.with_named_sequence} { %4:2 = transform.structured.tile_using_for %2#1 tile_sizes [%1#1] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op) %5 = transform.merge_handles %3#0, %4#0 : !transform.any_op %tt:3 = transform.replicate num(%5) %t#0, %t#1, %t#2 : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64> - %6:2 = transform.structured.split %5 after %tt#2 { dimension = 1 } : !transform.any_op, !transform.param<i64> - transform.structured.tile_using_for %6#0 tile_sizes [0, %tt#0] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op) - transform.structured.tile_using_for %6#1 tile_sizes [0, %tt#1] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op) + transform.foreach %5, %tt#0, %tt#1, %tt#2 : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64> { + ^bb0(%inner_linalg: !transform.any_op, 
%low: !transform.param<i64>, %high: !transform.param<i64>, %split_point: !transform.param<i64>): + %inner_linalg_low, %inner_linalg_high = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.param<i64> + transform.structured.tile_using_for %inner_linalg_low tile_sizes [0, %low] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op) + transform.structured.tile_using_for %inner_linalg_high tile_sizes [0, %high] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op) + } transform.yield } } diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir index 0f51b1c..54dd2bd 100644 --- a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir +++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir @@ -328,3 +328,76 @@ module attributes {transform.with_named_sequence} { transform.yield } } +// ----- + +// CHECK: func.func @foreach_loop_pair_fuse([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}} +func.func @foreach_loop_pair_fuse(%arg1: tensor<128xf32>, %arg2: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>, tensor<128xf32>, tensor<128xf32>) { + // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index + // CHECK-DAG: [[C16:%.*]] = arith.constant 16 : index + // CHECK-DAG: [[C128:%.*]] = arith.constant 128 : index + // CHECK-DAG: [[ZERO:%.*]] = arith.constant 0.000000e+00 : f32 + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %cst = arith.constant 0.000000e+00 : f32 + // CHECK: [[RST:%.*]]:2 = scf.for [[IV:%.*]] = [[C0]] to [[C128]] step [[C16]] iter_args([[IB0:%.*]] = [[B]], [[IB1:%.*]] = [[B]]) {{.*}} + %1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) { + // CHECK-DAG: [[ASLICE:%.*]] = vector.transfer_read [[A]][[[IV]]], [[ZERO]] + // CHECK-DAG: [[SLICE0:%.*]] = 
vector.transfer_read [[IB0]][[[IV]]], [[ZERO]] + // CHECK: [[OUT1:%.*]] = arith.addf [[SLICE0]], [[ASLICE]] + // CHECK-NEXT: [[WRT0:%.*]] = vector.transfer_write [[OUT1]], [[IB0]][[[IV]]] + %2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32> + %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32> + %5 = arith.addf %3, %2 : vector<16xf32> + %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32> + scf.yield %6 : tensor<128xf32> + } {target_loops} + %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) { + // CHECK-DAG: [[SLICE1:%.*]] = vector.transfer_read [[IB1]][[[IV]]], [[ZERO]] + // CHECK: [[OUT2:%.*]] = arith.addf [[SLICE1]], [[ASLICE]] + // CHECK-NEXT: [[WRT1:%.*]] = vector.transfer_write [[OUT2]], [[IB1]][[[IV]]] + %dup2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32> + %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32> + %dup5 = arith.addf %dup3, %dup2 : vector<16xf32> + %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32> + // CHECK: scf.yield [[WRT0]], [[WRT1]] : {{.*}} + scf.yield %dup6 : tensor<128xf32> + } {source_loops} + %2 = scf.for %arg3 = %c0 to %c128 step %c32 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) { + // CHECK-DAG: [[ASLICE:%.*]] = vector.transfer_read [[A]][[[IV]]], [[ZERO]] + // CHECK-DAG: [[SLICE0:%.*]] = vector.transfer_read [[IB0]][[[IV]]], [[ZERO]] + // CHECK: [[OUT1:%.*]] = arith.addf [[SLICE0]], [[ASLICE]] + // CHECK-NEXT: [[WRT0:%.*]] = vector.transfer_write [[OUT1]], [[IB0]][[[IV]]] + %2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<32xf32> + %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<32xf32> + %5 = arith.addf %3, %2 
: vector<32xf32> + %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<32xf32>, tensor<128xf32> + scf.yield %6 : tensor<128xf32> + } {target_loops} + %dup2 = scf.for %arg3 = %c0 to %c128 step %c32 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) { + // CHECK-DAG: [[SLICE1:%.*]] = vector.transfer_read [[IB1]][[[IV]]], [[ZERO]] + // CHECK: [[OUT2:%.*]] = arith.addf [[SLICE1]], [[ASLICE]] + // CHECK-NEXT: [[WRT1:%.*]] = vector.transfer_write [[OUT2]], [[IB1]][[[IV]]] + %dup2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<32xf32> + %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<32xf32> + %dup5 = arith.addf %dup3, %dup2 : vector<32xf32> + %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<32xf32>, tensor<128xf32> + // CHECK: scf.yield [[WRT0]], [[WRT1]] : {{.*}} + scf.yield %dup6 : tensor<128xf32> + } {source_loops} + return %1, %dup1, %2, %dup2 : tensor<128xf32>, tensor<128xf32>, tensor<128xf32>, tensor<128xf32> +} + + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %target_loops = transform.structured.match ops{["scf.for"]} attributes {target_loops} in %arg0 : (!transform.any_op) -> !transform.any_op + %source_loops = transform.structured.match ops{["scf.for"]} attributes {source_loops} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.foreach %target_loops, %source_loops : !transform.any_op, !transform.any_op { + ^bb0(%target_loop: !transform.any_op, %source_loop: !transform.any_op): + %fused = transform.loop.fuse_sibling %target_loop into %source_loop : (!transform.any_op,!transform.any_op) -> !transform.any_op + } + transform.yield + } +} diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir index 30a68cc..71a260f 100644 --- 
a/mlir/test/Dialect/Transform/ops-invalid.mlir +++ b/mlir/test/Dialect/Transform/ops-invalid.mlir @@ -279,6 +279,55 @@ transform.sequence failures(propagate) { // ----- +transform.sequence failures(propagate) { + ^bb0(%root: !transform.any_op): + %op = test_produce_self_handle_or_forward_operand : () -> !transform.any_op + // expected-error @below {{op expects the same number of targets as the body has block arguments}} + transform.foreach %op : !transform.any_op -> !transform.any_op, !transform.any_value { + ^bb1(%op_arg: !transform.any_op, %val_arg: !transform.any_value): + transform.yield %op_arg, %val_arg : !transform.any_op, !transform.any_value + } +} + +// ----- + +transform.sequence failures(propagate) { + ^bb0(%root: !transform.any_op): + %op = test_produce_self_handle_or_forward_operand : () -> !transform.any_op + // expected-error @below {{op expects co-indexed targets and the body's block arguments to have the same op/value/param type}} + transform.foreach %op : !transform.any_op -> !transform.any_value { + ^bb1(%val_arg: !transform.any_value): + transform.yield %val_arg : !transform.any_value + } +} + +// ----- + +transform.sequence failures(propagate) { + ^bb0(%root: !transform.any_op): + %op = test_produce_self_handle_or_forward_operand : () -> !transform.any_op + // expected-error @below {{op expects the same number of results as the yield terminator has operands}} + transform.foreach %op : !transform.any_op -> !transform.any_op, !transform.any_op { + ^bb1(%arg_op: !transform.any_op): + transform.yield %arg_op : !transform.any_op + } +} + +// ----- + +transform.sequence failures(propagate) { + ^bb0(%root: !transform.any_op): + %op = test_produce_self_handle_or_forward_operand : () -> !transform.any_op + %val = transform.test_produce_value_handle_to_self_operand %op : (!transform.any_op) -> !transform.any_value + // expected-error @below {{expects co-indexed results and yield operands to have the same op/value/param type}} + transform.foreach %op, 
%val : !transform.any_op, !transform.any_value -> !transform.any_op, !transform.any_value { + ^bb1(%op_arg: !transform.any_op, %val_arg: !transform.any_value): + transform.yield %val_arg, %op_arg : !transform.any_value, !transform.any_op + } +} + +// ----- + transform.sequence failures(suppress) { ^bb0(%arg0: !transform.any_op): // expected-error @below {{TransformOpInterface requires memory effects on operands to be specified}} diff --git a/mlir/test/Dialect/Transform/ops.mlir b/mlir/test/Dialect/Transform/ops.mlir index b03a9f4..e9baffd 100644 --- a/mlir/test/Dialect/Transform/ops.mlir +++ b/mlir/test/Dialect/Transform/ops.mlir @@ -68,11 +68,25 @@ transform.sequence failures(propagate) { } // CHECK: transform.sequence -// CHECK: foreach transform.sequence failures(propagate) { -^bb0(%arg0: !transform.any_op): - transform.foreach %arg0 : !transform.any_op { - ^bb1(%arg1: !transform.any_op): +^bb0(%op0: !transform.any_op, %val0: !transform.any_value, %par0: !transform.any_param): + // CHECK: foreach %{{.*}} : !transform.any_op + transform.foreach %op0 : !transform.any_op { + ^bb1(%op1: !transform.any_op): + } + // CHECK: foreach %{{.*}} : !transform.any_op, !transform.any_value, !transform.any_param + transform.foreach %op0, %val0, %par0 : !transform.any_op, !transform.any_value, !transform.any_param { + ^bb1(%op1: !transform.any_op, %val1: !transform.any_value, %par1: !transform.any_param): + } + // CHECK: foreach %{{.*}} : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_op + transform.foreach %op0, %val0, %par0 : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_op { + ^bb1(%op1: !transform.any_op, %val1: !transform.any_value, %par1: !transform.any_param): + transform.yield %op1 : !transform.any_op + } + // CHECK: foreach %{{.*}} : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_param, !transform.any_value + transform.foreach %op0, %val0, %par0 : !transform.any_op, 
!transform.any_value, !transform.any_param -> !transform.any_param, !transform.any_value { + ^bb1(%op1: !transform.any_op, %val1: !transform.any_value, %par1: !transform.any_param): + transform.yield %par1, %val1 : !transform.any_param, !transform.any_value } } diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir index b6850e2..4fe2dbe 100644 --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -830,6 +830,91 @@ module attributes {transform.with_named_sequence} { // ----- +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %results, %types = transform.foreach %0 : !transform.any_op -> !transform.any_value, !transform.any_param { + ^bb0(%op0 : !transform.any_op): + %result = transform.get_result %op0[0] : (!transform.any_op) -> !transform.any_value + %type = transform.get_type elemental %result : (!transform.any_value) -> !transform.any_param + transform.yield %result, %type : !transform.any_value, !transform.any_param + } + transform.debug.emit_remark_at %results, "result selected" : !transform.any_value + transform.debug.emit_param_as_remark %types, "elemental types" at %0 : !transform.any_param, !transform.any_op + + transform.yield + } +} + +func.func @payload(%lhs: tensor<10x20xf16>, + %rhs: tensor<20x15xf32>) -> (tensor<10x15xf64>, tensor<10x15xf32>) { + %cst64 = arith.constant 0.0 : f64 + %empty64 = tensor.empty() : tensor<10x15xf64> + %fill64 = linalg.fill ins(%cst64 : f64) outs(%empty64 : tensor<10x15xf64>) -> tensor<10x15xf64> + // expected-remark @below {{result selected}} + // expected-note @below {{value handle points to an op result #0}} + // expected-remark @below {{elemental types f64, f32}} + %result64 = linalg.matmul 
ins(%lhs, %rhs: tensor<10x20xf16>, tensor<20x15xf32>) + outs(%fill64: tensor<10x15xf64>) -> tensor<10x15xf64> + + %cst32 = arith.constant 0.0 : f32 + %empty32 = tensor.empty() : tensor<10x15xf32> + %fill32 = linalg.fill ins(%cst32 : f32) outs(%empty32 : tensor<10x15xf32>) -> tensor<10x15xf32> + // expected-remark @below {{result selected}} + // expected-note @below {{value handle points to an op result #0}} + // expected-remark @below {{elemental types f64, f32}} + %result32 = linalg.matmul ins(%lhs, %rhs: tensor<10x20xf16>, tensor<20x15xf32>) + outs(%fill32: tensor<10x15xf32>) -> tensor<10x15xf32> + + return %result64, %result32 : tensor<10x15xf64>, tensor<10x15xf32> + +} + +// ----- + +func.func @two_const_ops() { + %0 = arith.constant 0 : index + %1 = arith.constant 1 : index + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op) { + %two_ops = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %one_param = transform.param.constant 1 : i32 -> !transform.test_dialect_param + // expected-error @below {{prior targets' payload size (2) differs from payload size (1) of target}} + transform.foreach %two_ops, %one_param : !transform.any_op, !transform.test_dialect_param { + ^bb2(%op: !transform.any_op, %param: !transform.test_dialect_param): + } + transform.yield + } +} + +// ----- + +func.func @one_const_op() { + %0 = arith.constant 0 : index + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op) { + %one_op = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %one_val = transform.test_produce_value_handle_to_self_operand %one_op : (!transform.any_op) -> !transform.any_value + %param_one = transform.param.constant 1 : i32 -> !transform.test_dialect_param + %param_two = transform.param.constant 2 : 
i32 -> !transform.test_dialect_param + %two_params = transform.merge_handles %param_one, %param_two : !transform.test_dialect_param + + // expected-error @below {{prior targets' payload size (1) differs from payload size (2) of target}} + transform.foreach %one_val, %one_op, %two_params : !transform.any_value, !transform.any_op, !transform.test_dialect_param { + ^bb2(%val: !transform.any_value, %op: !transform.any_op, %param: !transform.test_dialect_param): + } + transform.yield + } +} + +// ----- + // CHECK-LABEL: func @consume_in_foreach() // CHECK-NEXT: return func.func @consume_in_foreach() { |