diff options
-rw-r--r-- | mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td | 23 | ||||
-rw-r--r-- | mlir/include/mlir/Dialect/SCF/Utils/Utils.h | 10 | ||||
-rw-r--r-- | mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp | 66 | ||||
-rw-r--r-- | mlir/lib/Dialect/SCF/Utils/Utils.cpp | 99 | ||||
-rw-r--r-- | mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir | 251 |
5 files changed, 357 insertions, 92 deletions
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td index 6f94cee..5eefe26 100644 --- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td +++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td @@ -333,23 +333,24 @@ def TakeAssumedBranchOp : Op<Transform_Dialect, "scf.take_assumed_branch", [ }]; } -def LoopFuseSibling : Op<Transform_Dialect, "loop.fuse_sibling", +def LoopFuseSiblingOp : Op<Transform_Dialect, "loop.fuse_sibling", [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface, DeclareOpInterfaceMethods<TransformOpInterface>]> { let summary = "Fuse a loop into another loop, assuming the fusion is legal."; let description = [{ Fuses the `target` loop into the `source` loop assuming they are - independent of each other. It is the responsibility of the user to ensure - that the given two loops are independent of each other, this operation will - not performa any legality checks and will simply fuse the two given loops. + independent of each other. In the fused loop, the arguments, body and + results of `target` are placed _before_ those of `source`. - Currently, the only fusion supported is when both `target` and `source` - are `scf.forall` operations. For `scf.forall` fusion, the bounds and the - mapping must match, otherwise a silencable failure is produced. + For fusion of two `scf.for` loops, the bounds and step size must match. For + fusion of two `scf.forall` loops, the bounds and the mapping must match. + Otherwise a silencable failure is produced. - The input handles `target` and `source` must map to exactly one operation, - a definite failure is produced otherwise. + The `target` and `source` handles must refer to exactly one operation, + otherwise a definite failure is produced. It is the responsibility of the + user to ensure that the `target` and `source` loops are independent of each + other -- this op will only perform rudimentary legality checks. #### Return modes @@ -362,10 +363,6 @@ def LoopFuseSibling : Op<Transform_Dialect, "loop.fuse_sibling", let results = (outs TransformHandleTypeInterface:$fused_loop); let assemblyFormat = "$target `into` $source attr-dict " " `:` functional-type(operands, results)"; - - let builders = [ - OpBuilder<(ins "Value":$loop, "Value":$fused_loop)> - ]; } #endif // SCF_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h index 9bdd6eb..883d11b 100644 --- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h @@ -162,6 +162,16 @@ scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForallOp source, RewriterBase &rewriter); +/// Given two scf.for loops, `target` and `source`, fuses `target` into +/// `source`. Assumes that the given loops are siblings and are independent of +/// each other. +/// +/// This function does not perform any legality checks and simply fuses the +/// loops. The caller is responsible for ensuring that the loops are legal to +/// fuse. +scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source, + RewriterBase &rewriter); + } // namespace mlir #endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_ diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp index 4d8d93f..c091841 100644 --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -384,7 +384,7 @@ void transform::TakeAssumedBranchOp::getEffects( } //===----------------------------------------------------------------------===// -// LoopFuseSibling +// LoopFuseSiblingOp //===----------------------------------------------------------------------===// /// Check if `target` and `source` are siblings, in the context that `target` @@ -408,7 +408,7 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target, // Check if fusion will violate dominance. DominanceInfo domInfo(source); if (target->isBeforeInBlock(source)) { - // Since, `target` is before `source`, all users of results of `target` + // Since `target` is before `source`, all users of results of `target` // need to be dominated by `source`. for (Operation *user : target->getUsers()) { if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) { @@ -424,9 +424,8 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target, // Check if operands of `target` are dominated by `source`. for (Value operand : target->getOperands()) { Operation *operandOp = operand.getDefiningOp(); - // If operand does not have a defining operation, it is a block arguement, - // which will always dominate `source`, since `target` and `source` are in - // the same block and the operand dominated `source` before. + // Operands without defining operations are block arguments. When `target` + // and `source` occur in the same block, these operands dominate `source`. if (!operandOp) continue; @@ -441,8 +440,11 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target, bool failed = false; OpOperand *failedValue = nullptr; visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) { - if (!domInfo.properlyDominates(operand->getOwner(), source, - /*enclosingOpOk=*/false)) { + Operation *operandOp = operand->get().getDefiningOp(); + if (operandOp && !domInfo.properlyDominates(operandOp, source, + /*enclosingOpOk=*/false)) { + // `operand` is not an argument of an enclosing block and the defining + // op of `operand` is outside `target` but does not dominate `source`. failed = true; failedValue = operand; } @@ -457,12 +459,11 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target, return DiagnosedSilenceableFailure::success(); } -/// Check if `target` can be fused into `source`. +/// Check if `target` scf.forall can be fused into `source` scf.forall. /// -/// This is a simple check that just checks if both loops have same -/// bounds, steps and mapping. This check does not ensure that the side effects -/// of `target` are independent of `source` or vice-versa. It is the -/// responsibility of the caller to ensure that. +/// This simply checks if both loops have the same bounds, steps and mapping. +/// No attempt is made at checking that the side effects of `target` and +/// `source` are independent of each other. static bool isForallWithIdenticalConfiguration(Operation *target, Operation *source) { auto targetOp = dyn_cast<scf::ForallOp>(target); @@ -476,21 +477,27 @@ static bool isForallWithIdenticalConfiguration(Operation *target, targetOp.getMapping() == sourceOp.getMapping(); } -/// Fuse `target` into `source` assuming they are siblings and indepndent. -/// TODO: Add fusion for more operations. Currently, we handle only scf.forall. -static Operation *fuseSiblings(Operation *target, Operation *source, - RewriterBase &rewriter) { - auto targetOp = dyn_cast<scf::ForallOp>(target); - auto sourceOp = dyn_cast<scf::ForallOp>(source); +/// Check if `target` scf.for can be fused into `source` scf.for. +/// +/// This simply checks if both loops have the same bounds and steps. No attempt +/// is made at checking that the side effects of `target` and `source` are +/// independent of each other. +static bool isForWithIdenticalConfiguration(Operation *target, + Operation *source) { + auto targetOp = dyn_cast<scf::ForOp>(target); + auto sourceOp = dyn_cast<scf::ForOp>(source); if (!targetOp || !sourceOp) - return nullptr; - return fuseIndependentSiblingForallLoops(targetOp, sourceOp, rewriter); + return false; + + return targetOp.getLowerBound() == sourceOp.getLowerBound() && + targetOp.getUpperBound() == sourceOp.getUpperBound() && + targetOp.getStep() == sourceOp.getStep(); } DiagnosedSilenceableFailure -transform::LoopFuseSibling::apply(transform::TransformRewriter &rewriter, - transform::TransformResults &results, - transform::TransformState &state) { +transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { auto targetOps = state.getPayloadOps(getTarget()); auto sourceOps = state.getPayloadOps(getSource()); @@ -510,13 +517,18 @@ transform::LoopFuseSibling::apply(transform::TransformRewriter &rewriter, if (!diag.succeeded()) return diag; - // Check if the target can be fused into source. - if (!isForallWithIdenticalConfiguration(target, source)) { + Operation *fusedLoop; + /// TODO: Support fusion for loop-like ops besides scf.for and scf.forall. + if (isForWithIdenticalConfiguration(target, source)) { + fusedLoop = fuseIndependentSiblingForLoops( + cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter); + } else if (isForallWithIdenticalConfiguration(target, source)) { + fusedLoop = fuseIndependentSiblingForallLoops( + cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter); + } else return emitSilenceableFailure(target->getLoc()) << "operations cannot be fused"; - } - Operation *fusedLoop = fuseSiblings(target, source, rewriter); assert(fusedLoop && "failed to fuse operations"); results.set(cast<OpResult>(getFusedLoop()), {fusedLoop}); diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index 502d7e1..914aeb4 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -910,61 +910,98 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target, unsigned numTargetOuts = target.getNumResults(); unsigned numSourceOuts = source.getNumResults(); - OperandRange targetOuts = target.getOutputs(); - OperandRange sourceOuts = source.getOutputs(); - // Create fused shared_outs. SmallVector<Value> fusedOuts; - fusedOuts.reserve(numTargetOuts + numSourceOuts); - fusedOuts.append(targetOuts.begin(), targetOuts.end()); - fusedOuts.append(sourceOuts.begin(), sourceOuts.end()); + llvm::append_range(fusedOuts, target.getOutputs()); + llvm::append_range(fusedOuts, source.getOutputs()); - // Create a new scf::forall op after the source loop. + // Create a new scf.forall op after the source loop. rewriter.setInsertionPointAfter(source); scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>( source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(), source.getMixedStep(), fusedOuts, source.getMapping()); // Map control operands. - IRMapping fusedMapping; - fusedMapping.map(target.getInductionVars(), fusedLoop.getInductionVars()); - fusedMapping.map(source.getInductionVars(), fusedLoop.getInductionVars()); + IRMapping mapping; + mapping.map(target.getInductionVars(), fusedLoop.getInductionVars()); + mapping.map(source.getInductionVars(), fusedLoop.getInductionVars()); // Map shared outs. - fusedMapping.map(target.getRegionIterArgs(), - fusedLoop.getRegionIterArgs().slice(0, numTargetOuts)); - fusedMapping.map( - source.getRegionIterArgs(), - fusedLoop.getRegionIterArgs().slice(numTargetOuts, numSourceOuts)); + mapping.map(target.getRegionIterArgs(), + fusedLoop.getRegionIterArgs().take_front(numTargetOuts)); + mapping.map(source.getRegionIterArgs(), + fusedLoop.getRegionIterArgs().take_back(numSourceOuts)); // Append everything except the terminator into the fused operation. rewriter.setInsertionPointToStart(fusedLoop.getBody()); for (Operation &op : target.getBody()->without_terminator()) - rewriter.clone(op, fusedMapping); + rewriter.clone(op, mapping); for (Operation &op : source.getBody()->without_terminator()) - rewriter.clone(op, fusedMapping); + rewriter.clone(op, mapping); // Fuse the old terminator in_parallel ops into the new one. scf::InParallelOp targetTerm = target.getTerminator(); scf::InParallelOp sourceTerm = source.getTerminator(); scf::InParallelOp fusedTerm = fusedLoop.getTerminator(); - rewriter.setInsertionPointToStart(fusedTerm.getBody()); for (Operation &op : targetTerm.getYieldingOps()) - rewriter.clone(op, fusedMapping); + rewriter.clone(op, mapping); for (Operation &op : sourceTerm.getYieldingOps()) - rewriter.clone(op, fusedMapping); - - // Replace all uses of the old loops with the fused loop. - rewriter.replaceAllUsesWith(target.getResults(), - fusedLoop.getResults().slice(0, numTargetOuts)); - rewriter.replaceAllUsesWith( - source.getResults(), - fusedLoop.getResults().slice(numTargetOuts, numSourceOuts)); - - // Erase the old loops. - rewriter.eraseOp(target); - rewriter.eraseOp(source); + rewriter.clone(op, mapping); + + // Replace old loops by substituting their uses by results of the fused loop. + rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts)); + rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts)); + + return fusedLoop; +} + +scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target, + scf::ForOp source, + RewriterBase &rewriter) { + unsigned numTargetOuts = target.getNumResults(); + unsigned numSourceOuts = source.getNumResults(); + + // Create fused init_args, with target's init_args before source's init_args. + SmallVector<Value> fusedInitArgs; + llvm::append_range(fusedInitArgs, target.getInitArgs()); + llvm::append_range(fusedInitArgs, source.getInitArgs()); + + // Create a new scf.for op after the source loop (with scf.yield terminator + // (without arguments) only in case its init_args is empty). + rewriter.setInsertionPointAfter(source); + scf::ForOp fusedLoop = rewriter.create<scf::ForOp>( + source.getLoc(), source.getLowerBound(), source.getUpperBound(), + source.getStep(), fusedInitArgs); + + // Map original induction variables and operands to those of the fused loop. + IRMapping mapping; + mapping.map(target.getInductionVar(), fusedLoop.getInductionVar()); + mapping.map(target.getRegionIterArgs(), + fusedLoop.getRegionIterArgs().take_front(numTargetOuts)); + mapping.map(source.getInductionVar(), fusedLoop.getInductionVar()); + mapping.map(source.getRegionIterArgs(), + fusedLoop.getRegionIterArgs().take_back(numSourceOuts)); + + // Merge target's body into the new (fused) for loop and then source's body. + rewriter.setInsertionPointToStart(fusedLoop.getBody()); + for (Operation &op : target.getBody()->without_terminator()) + rewriter.clone(op, mapping); + for (Operation &op : source.getBody()->without_terminator()) + rewriter.clone(op, mapping); + + // Build fused yield results by appropriately mapping original yield operands. + SmallVector<Value> yieldResults; + for (Value operand : target.getBody()->getTerminator()->getOperands()) + yieldResults.push_back(mapping.lookupOrDefault(operand)); + for (Value operand : source.getBody()->getTerminator()->getOperands()) + yieldResults.push_back(mapping.lookupOrDefault(operand)); + if (!yieldResults.empty()) + rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults); + + // Replace old loops by substituting their uses by results of the fused loop. + rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts)); + rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts)); return fusedLoop; } diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir index faaa2db..0f51b1c 100644 --- a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir +++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir @@ -1,14 +1,113 @@ // RUN: mlir-opt %s -transform-interpreter --cse --canonicalize -split-input-file -verify-diagnostics | FileCheck %s +// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s --check-prefix CHECK-NOCLEANUP -func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) { +// CHECK: func.func @fuse_1st_for_into_2nd([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}} +func.func @fuse_1st_for_into_2nd(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) { + // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index + // CHECK-DAG: [[C16:%.*]] = arith.constant 16 : index + // CHECK-DAG: [[C128:%.*]] = arith.constant 128 : index + // CHECK-DAG: [[ZERO:%.*]] = arith.constant 0.000000e+00 : f32 + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %cst = arith.constant 0.000000e+00 : f32 + // CHECK: [[R0:%.*]]:2 = scf.for [[IV:%.*]] = [[C0]] to [[C128]] step [[C16]] iter_args([[IA:%.*]] = [[A]], [[IB:%.*]] = [[B]]) {{.*}} + %1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %A) -> (tensor<128xf32>) { + // CHECK-DAG: [[ASLICE:%.*]] = vector.transfer_read [[A]][[[IV]]], [[ZERO]] + // CHECK-DAG: [[SLICE0:%.*]] = vector.transfer_read [[IA]][[[IV]]], [[ZERO]] + // CHECK: [[OUT1:%.*]] = arith.addf [[SLICE0]], [[ASLICE]] + // CHECK-NEXT: [[WRT0:%.*]] = vector.transfer_write [[OUT1]], [[IA]][[[IV]]] + %2 = vector.transfer_read %A[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32> + %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32> + %5 = arith.addf %3, %2 : vector<16xf32> + %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32> + scf.yield %6 : tensor<128xf32> + } + %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %B) -> (tensor<128xf32>) { + // CHECK-DAG: [[SLICE1:%.*]] = vector.transfer_read [[IB]][[[IV]]], [[ZERO]] + // CHECK: [[OUT2:%.*]] = arith.addf [[SLICE1]], [[ASLICE]] + // CHECK-NEXT: [[WRT1:%.*]] = vector.transfer_write [[OUT2]], [[IB]][[[IV]]] + %dup2 = vector.transfer_read %A[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32> + %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32> + %dup5 = arith.addf %dup3, %dup2 : vector<16xf32> + %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32> + // CHECK: scf.yield [[WRT0]], [[WRT1]] : {{.*}} + scf.yield %dup6 : tensor<128xf32> + } + return %1, %dup1 : tensor<128xf32>, tensor<128xf32> +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %for:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %fused = transform.loop.fuse_sibling %for#0 into %for#1 : (!transform.any_op,!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK: func.func @fuse_2nd_for_into_1st([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}} +func.func @fuse_2nd_for_into_1st(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) { + // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index + // CHECK-DAG: [[C16:%.*]] = arith.constant 16 : index + // CHECK-DAG: [[C128:%.*]] = arith.constant 128 : index + // CHECK-DAG: [[ZERO:%.*]] = arith.constant 0.000000e+00 : f32 + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %cst = arith.constant 0.000000e+00 : f32 + // CHECK: [[R0:%.*]]:2 = scf.for [[IV:%.*]] = [[C0]] to [[C128]] step [[C16]] iter_args([[IB:%.*]] = [[B]], [[IA:%.*]] = [[A]]) {{.*}} + %1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %A) -> (tensor<128xf32>) { + // CHECK-DAG: [[ASLICE:%.*]] = vector.transfer_read [[A]][[[IV]]], [[ZERO]] + // CHECK-DAG: [[SLICE0:%.*]] = vector.transfer_read [[IB]][[[IV]]], [[ZERO]] + // CHECK: [[OUT1:%.*]] = arith.addf [[SLICE0]], [[ASLICE]] + // CHECK-NEXT: [[WRT0:%.*]] = vector.transfer_write [[OUT1]], [[IB]][[[IV]]] + %2 = vector.transfer_read %A[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32> + %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32> + %5 = arith.addf %3, %2 : vector<16xf32> + %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32> + scf.yield %6 : tensor<128xf32> + } + %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %B) -> (tensor<128xf32>) { + // CHECK-DAG: [[SLICE1:%.*]] = vector.transfer_read [[IA]][[[IV]]], [[ZERO]] + // CHECK: [[OUT2:%.*]] = arith.addf [[SLICE1]], [[ASLICE]] + // CHECK-NEXT: [[WRT1:%.*]] = vector.transfer_write [[OUT2]], [[IA]][[[IV]]] + %dup2 = vector.transfer_read %A[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32> + // NB: the dominance check used to fail on the following line, + // however the defining op for the value of %arg3 occurs above the source loop and hence is safe + // and %arg4 is a block argument of the scope of the loops and hence is safe + %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32> + %dup5 = arith.addf %dup3, %dup2 : vector<16xf32> + %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32> + // CHECK: scf.yield [[WRT0]], [[WRT1]] : {{.*}} + scf.yield %dup6 : tensor<128xf32> + } + return %1, %dup1 : tensor<128xf32>, tensor<128xf32> +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %for:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %fused = transform.loop.fuse_sibling %for#1 into %for#0 : (!transform.any_op,!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK: func.func @matmul_fuse_1st_forall_into_2nd([[A1:%.*]]: {{.*}}, [[A2:%.*]]: {{.*}}, [[B:%.*]]: {{.*}} +func.func @matmul_fuse_1st_forall_into_2nd(%A1 : tensor<128x128xf32>, %A2 : tensor<128x128xf32>, %B : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) { %zero = arith.constant 0.0 : f32 %out_alloc = tensor.empty() : tensor<128x128xf32> %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32> // CHECK: scf.forall ([[I:%.*]]) in (4) shared_outs([[S1:%.*]] = [[IN1:%.*]], [[S2:%.*]] = [[IN2:%.*]]) -> (tensor<128x128xf32>, tensor<128x128xf32>) { // CHECK: [[T:%.*]] = affine.apply + // CHECK: tensor.extract_slice [[A2]][[[T]], 0] [32, 128] [1, 1] // CHECK: tensor.extract_slice [[S1]][[[T]], 0] [32, 128] [1, 1] // CHECK: [[OUT1:%.*]] = linalg.matmul + // CHECK: tensor.extract_slice [[A1]][[[T]], 0] [32, 128] [1, 1] // CHECK: tensor.extract_slice [[S2]][[[T]], 0] [32, 128] [1, 1] // CHECK: [[OUT2:%.*]] = linalg.matmul // CHECK: scf.forall.in_parallel { @@ -16,12 +115,11 @@ func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tenso // CHECK: tensor.parallel_insert_slice [[OUT2]] into [[S2]][[[T]], 0] [32, 128] [1, 1] // CHECK: } // CHECK: } - %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32> - %out2 = linalg.matmul ins(%A, %B2 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32> + %out1 = linalg.matmul ins(%A1, %B : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32> + %out2 = linalg.matmul ins(%A2, %B : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32> func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32> } - module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%variant_op : !transform.any_op {transform.readonly}) { %matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op) @@ -31,25 +129,37 @@ module attributes {transform.with_named_sequence} { %tiled_mm1, %loop1 = transform.structured.tile_using_forall %mm1 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) %tiled_mm2, %loop2 = transform.structured.tile_using_forall %mm2 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %fused_loop = transform.loop.fuse_sibling %loop1 into %loop2 : (!transform.any_op, !transform.any_op) -> !transform.any_op + %fused_loop = transform.loop.fuse_sibling %loop2 into %loop1 : (!transform.any_op, !transform.any_op) -> !transform.any_op transform.yield } } // ----- -func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) { +// CHECK: func.func @matmul_fuse_2nd_forall_into_1st([[A1:%.*]]: {{.*}}, [[A2:%.*]]: {{.*}}, [[B:%.*]]: {{.*}} +func.func @matmul_fuse_2nd_forall_into_1st(%A1 : tensor<128x128xf32>, %A2 : tensor<128x128xf32>, %B : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) { %zero = arith.constant 0.0 : f32 %out_alloc = tensor.empty() : tensor<128x128xf32> %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32> - // expected-error @below {{user of results of target should be properly dominated by source}} - %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32> - %out2 = linalg.matmul ins(%A, %out1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32> + // CHECK: scf.forall ([[I:%.*]]) in (4) shared_outs([[S1:%.*]] = [[IN1:%.*]], [[S2:%.*]] = [[IN2:%.*]]) -> (tensor<128x128xf32>, tensor<128x128xf32>) { + // CHECK: [[T:%.*]] = affine.apply + // CHECK: tensor.extract_slice [[A1]][[[T]], 0] [32, 128] [1, 1] + // CHECK: tensor.extract_slice [[S1]][[[T]], 0] [32, 128] [1, 1] + // CHECK: [[OUT1:%.*]] = linalg.matmul + // CHECK: tensor.extract_slice [[A2]][[[T]], 0] [32, 128] [1, 1] + // CHECK: tensor.extract_slice [[S2]][[[T]], 0] [32, 128] [1, 1] + // CHECK: [[OUT2:%.*]] = linalg.matmul + // CHECK: scf.forall.in_parallel { + // CHECK: tensor.parallel_insert_slice [[OUT1]] into [[S1]][[[T]], 0] [32, 128] [1, 1] + // CHECK: tensor.parallel_insert_slice [[OUT2]] into [[S2]][[[T]], 0] [32, 128] [1, 1] + // CHECK: } + // CHECK: } + %out1 = linalg.matmul ins(%A1, %B : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32> + %out2 = linalg.matmul ins(%A2, %B : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32> func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32> } - module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%variant_op : !transform.any_op {transform.readonly}) { %matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op) @@ -66,18 +176,84 @@ module attributes {transform.with_named_sequence} { // ----- -func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) { +// CHECK-NOCLEANUP: func.func @fuse_no_iter_args([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}} +func.func @fuse_no_iter_args(%A: tensor<128xf32>, %B: tensor<128xf32>) { + // CHECK-NOCLEANUP: [[C0:%.*]] = arith.constant 0 : index + // CHECK-NOCLEANUP: [[C16:%.*]] = arith.constant 16 : index + // CHECK-NOCLEANUP: [[C128:%.*]] = arith.constant 128 : index + // CHECK-NOCLEANUP: [[ZERO:%.*]] = arith.constant 0.000000e+00 : f32 + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %cst = arith.constant 0.000000e+00 : f32 + // CHECK-NOCLEANUP: scf.for [[IV:%.*]] = [[C0]] to [[C128]] step [[C16]] {{.*}} + scf.for %arg0 = %c0 to %c128 step %c16 { + // CHECK-NOCLEANUP: [[ASLICE:%.*]] = vector.transfer_read [[A]][[[IV]]], [[ZERO]] + %2 = vector.transfer_read %A[%arg0], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32> + scf.yield + } + scf.for %arg0 = %c0 to %c128 step %c16 { + // CHECK-NOCLEANUP: [[BSLICE:%.*]] = vector.transfer_read [[B]][[[IV]]], [[ZERO]] + %dup2 = vector.transfer_read %B[%arg0], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32> + scf.yield + } + return +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %for:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %fused = transform.loop.fuse_sibling %for#0 into %for#1 : (!transform.any_op,!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +func.func @source_for_uses_result_of_target_for_err(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) { + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %cst = arith.constant 0.000000e+00 : f32 + // expected-error @below {{user of results of target should be properly dominated by source}} + %1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %A) -> (tensor<128xf32>) { + %2 = vector.transfer_read %A[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32> + %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32> + %5 = arith.addf %3, %2 : vector<16xf32> + %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32> + scf.yield %6 : tensor<128xf32> + } + %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %1) -> (tensor<128xf32>) { + %dup2 = vector.transfer_read %A[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32> + %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32> + %dup5 = arith.addf %dup3, %dup2 : vector<16xf32> + %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32> + scf.yield %dup6 : tensor<128xf32> + } + return %1, %dup1 : tensor<128xf32>, tensor<128xf32> +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %for:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %fused = transform.loop.fuse_sibling %for#0 into %for#1 : (!transform.any_op,!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +func.func @source_forall_uses_result_of_target_forall_err(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) { %zero = arith.constant 0.0 : f32 %out_alloc = tensor.empty() : tensor<128x128xf32> %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32> + // expected-error @below {{user of results of target should be properly dominated by source}} %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32> - // expected-error @below {{values used inside regions of target should be properly dominated by source}} %out2 = linalg.matmul ins(%A, %out1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32> func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32> } - module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%variant_op : !transform.any_op {transform.readonly}) { %matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op) @@ -87,25 +263,58 @@ module attributes {transform.with_named_sequence} { %tiled_mm1, %loop1 = transform.structured.tile_using_forall %mm1 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) %tiled_mm2, %loop2 = transform.structured.tile_using_forall %mm2 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %fused_loop = transform.loop.fuse_sibling %loop2 into %loop1 : (!transform.any_op, !transform.any_op) -> !transform.any_op + %fused_loop = transform.loop.fuse_sibling %loop1 into %loop2 : (!transform.any_op, !transform.any_op) -> !transform.any_op transform.yield } } // ----- -func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) { - %zero = arith.constant 0.0 : f32 - %out_alloc = tensor.empty() : tensor<128x128xf32> - %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32> +func.func @target_for_region_uses_result_of_source_for_err(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) { + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %cst = arith.constant 0.000000e+00 : f32 + %1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %A) -> (tensor<128xf32>) { + %2 = vector.transfer_read %A[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32> + %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32> + %5 = arith.addf %3, %2 : vector<16xf32> + %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32> + scf.yield %6 : tensor<128xf32> + } + %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %B) -> (tensor<128xf32>) { + // expected-error @below {{values used inside regions of target should be properly dominated by source}} + %dup2 = vector.transfer_read %1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32> + %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32> + %dup5 = arith.addf %dup3, %dup2 : vector<16xf32> + %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32> + scf.yield %dup6 : tensor<128xf32> + } + return %1, %dup1 : tensor<128xf32>, tensor<128xf32> +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %for:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %fused = transform.loop.fuse_sibling %for#1 into %for#0 : (!transform.any_op,!transform.any_op) -> !transform.any_op + transform.yield + } +} - %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32> +// ----- + +func.func @target_forall_depends_on_value_not_dominated_by_source_forall_err(%A1 : tensor<128x128xf32>, %A2 : tensor<128x128xf32>, %B : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) { + %zero = arith.constant 0.0 : f32 + %buf1_alloc = tensor.empty() : tensor<128x128xf32> + %buf1 = linalg.fill ins(%zero : f32) outs(%buf1_alloc : tensor<128x128xf32>) -> tensor<128x128xf32> + %out1 = linalg.matmul ins(%A1, %B : tensor<128x128xf32>, tensor<128x128xf32>) outs(%buf1 : tensor<128x128xf32>) -> tensor<128x128xf32> + %out_alloc2 = tensor.empty() : tensor<128x128xf32> + %buf2 = linalg.fill ins(%zero : f32) outs(%buf1_alloc : tensor<128x128xf32>) -> tensor<128x128xf32> // expected-error @below {{operands of target should be properly dominated by source}} - %out2 = linalg.matmul ins(%A, %B2 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out1 : tensor<128x128xf32>) -> tensor<128x128xf32> + %out2 = linalg.matmul ins(%A2, %B : tensor<128x128xf32>, tensor<128x128xf32>) outs(%buf2 : tensor<128x128xf32>) -> tensor<128x128xf32> func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32> } - module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%variant_op : !transform.any_op {transform.readonly}) { %matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op) |