diff options
-rw-r--r-- | mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 3 | ||||
-rw-r--r-- | mlir/include/mlir/Dialect/SCF/Utils/Utils.h | 20 | ||||
-rw-r--r-- | mlir/include/mlir/Interfaces/LoopLikeInterface.h | 20 | ||||
-rw-r--r-- | mlir/lib/Dialect/SCF/IR/SCF.cpp | 38 | ||||
-rw-r--r-- | mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp | 140 | ||||
-rw-r--r-- | mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp | 80 | ||||
-rw-r--r-- | mlir/lib/Dialect/SCF/Utils/Utils.cpp | 279 | ||||
-rw-r--r-- | mlir/lib/Interfaces/LoopLikeInterface.cpp | 55 | ||||
-rw-r--r-- | mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir | 234 |
9 files changed, 586 insertions, 283 deletions
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index f35ea962..bf95fbe 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -303,7 +303,8 @@ def ForallOp : SCF_Op<"forall", [ DeclareOpInterfaceMethods<LoopLikeOpInterface, ["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars", "getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps", - "promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>, + "replaceWithAdditionalYields", "promoteIfSingleIteration", + "yieldTiledValuesAndReplace"]>, RecursiveMemoryEffects, SingleBlockImplicitTerminator<"scf::InParallelOp">, DeclareOpInterfaceMethods<RegionBranchOpInterface>, diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h index de807c3..6a40304 100644 --- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h @@ -181,6 +181,16 @@ Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes); void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops, scf::ForOp root); +//===----------------------------------------------------------------------===// +// Fusion related helpers +//===----------------------------------------------------------------------===// + +/// Check structural compatibility between two loops such as iteration space +/// and dominance. +bool checkFusionStructuralLegality(LoopLikeOpInterface target, + LoopLikeOpInterface source, + Diagnostic &diag); + /// Given two scf.forall loops, `target` and `source`, fuses `target` into /// `source`. Assumes that the given loops are siblings and are independent of /// each other. @@ -202,6 +212,16 @@ scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source, RewriterBase &rewriter); +/// Given two scf.parallel loops, `target` and `source`, fuses `target` into +/// `source`. Assumes that the given loops are siblings and are independent of +/// each other. +/// +/// This function does not perform any legality checks and simply fuses the +/// loops. The caller is responsible for ensuring that the loops are legal to +/// fuse. +scf::ParallelOp fuseIndependentSiblingParallelLoops(scf::ParallelOp target, + scf::ParallelOp source, + RewriterBase &rewriter); } // namespace mlir #endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_ diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.h b/mlir/include/mlir/Interfaces/LoopLikeInterface.h index 9925fc6..d08e097 100644 --- a/mlir/include/mlir/Interfaces/LoopLikeInterface.h +++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.h @@ -90,4 +90,24 @@ struct JamBlockGatherer { /// Include the generated interface declarations. #include "mlir/Interfaces/LoopLikeInterface.h.inc" +namespace mlir { +/// A function that rewrites `target`'s terminator as a teminator obtained by +/// fusing `source` into `target`. +using FuseTerminatorFn = + function_ref<void(RewriterBase &rewriter, LoopLikeOpInterface source, + LoopLikeOpInterface &target, IRMapping mapping)>; + +/// Returns a fused `LoopLikeOpInterface` created by fusing `source` to +/// `target`. The `NewYieldValuesFn` callback is used to pass to the +/// `replaceWithAdditionalYields` interface method to replace the loop with a +/// new loop with (possibly) additional yields, while the `FuseTerminatorFn` +/// callback is repsonsible for updating the fused loop terminator. +LoopLikeOpInterface createFused(LoopLikeOpInterface target, + LoopLikeOpInterface source, + RewriterBase &rewriter, + NewYieldValuesFn newYieldValuesFn, + FuseTerminatorFn fuseTerminatorFn); + +} // namespace mlir + #endif // MLIR_INTERFACES_LOOPLIKEINTERFACE_H_ diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 907d7f7..cb15e0e 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -618,6 +618,44 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point, SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; } +FailureOr<LoopLikeOpInterface> ForallOp::replaceWithAdditionalYields( + RewriterBase &rewriter, ValueRange newInitOperands, + bool replaceInitOperandUsesInLoop, + const NewYieldValuesFn &newYieldValuesFn) { + // Create a new loop before the existing one, with the extra operands. + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(getOperation()); + SmallVector<Value> inits(getOutputs()); + llvm::append_range(inits, newInitOperands); + scf::ForallOp newLoop = rewriter.create<scf::ForallOp>( + getLoc(), getMixedLowerBound(), getMixedUpperBound(), getMixedStep(), + inits, getMapping(), + /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {}); + + // Move the loop body to the new op. + rewriter.mergeBlocks(getBody(), newLoop.getBody(), + newLoop.getBody()->getArguments().take_front( + getBody()->getNumArguments())); + + if (replaceInitOperandUsesInLoop) { + // Replace all uses of `newInitOperands` with the corresponding basic block + // arguments. + for (auto &&[newOperand, oldOperand] : + llvm::zip(newInitOperands, newLoop.getBody()->getArguments().take_back( + newInitOperands.size()))) { + rewriter.replaceUsesWithIf(newOperand, oldOperand, [&](OpOperand &use) { + Operation *user = use.getOwner(); + return newLoop->isProperAncestor(user); + }); + } + } + + // Replace the old loop. + rewriter.replaceOp(getOperation(), + newLoop->getResults().take_front(getNumResults())); + return cast<LoopLikeOpInterface>(newLoop.getOperation()); +} + /// Promotes the loop body of a forallOp to its containing block if it can be /// determined that the loop has a single iteration. LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) { diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp index 56ff270..41834fe 100644 --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -261,8 +261,10 @@ loopScheduling(scf::ForOp forOp, return 1; }; - std::optional<int64_t> ubConstant = getConstantIntValue(forOp.getUpperBound()); - std::optional<int64_t> lbConstant = getConstantIntValue(forOp.getLowerBound()); + std::optional<int64_t> ubConstant = + getConstantIntValue(forOp.getUpperBound()); + std::optional<int64_t> lbConstant = + getConstantIntValue(forOp.getLowerBound()); DenseMap<Operation *, unsigned> opCycles; std::map<unsigned, std::vector<Operation *>> wrappedSchedule; for (Operation &op : forOp.getBody()->getOperations()) { @@ -447,113 +449,6 @@ void transform::TakeAssumedBranchOp::getEffects( // LoopFuseSiblingOp //===----------------------------------------------------------------------===// -/// Check if `target` and `source` are siblings, in the context that `target` -/// is being fused into `source`. -/// -/// This is a simple check that just checks if both operations are in the same -/// block and some checks to ensure that the fused IR does not violate -/// dominance. -static DiagnosedSilenceableFailure isOpSibling(Operation *target, - Operation *source) { - // Check if both operations are same. - if (target == source) - return emitSilenceableFailure(source) - << "target and source need to be different loops"; - - // Check if both operations are in the same block. - if (target->getBlock() != source->getBlock()) - return emitSilenceableFailure(source) - << "target and source are not in the same block"; - - // Check if fusion will violate dominance. - DominanceInfo domInfo(source); - if (target->isBeforeInBlock(source)) { - // Since `target` is before `source`, all users of results of `target` - // need to be dominated by `source`. - for (Operation *user : target->getUsers()) { - if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) { - return emitSilenceableFailure(target) - << "user of results of target should be properly dominated by " - "source"; - } - } - } else { - // Since `target` is after `source`, all values used by `target` need - // to dominate `source`. - - // Check if operands of `target` are dominated by `source`. - for (Value operand : target->getOperands()) { - Operation *operandOp = operand.getDefiningOp(); - // Operands without defining operations are block arguments. When `target` - // and `source` occur in the same block, these operands dominate `source`. - if (!operandOp) - continue; - - // Operand's defining operation should properly dominate `source`. - if (!domInfo.properlyDominates(operandOp, source, - /*enclosingOpOk=*/false)) - return emitSilenceableFailure(target) - << "operands of target should be properly dominated by source"; - } - - // Check if values used by `target` are dominated by `source`. - bool failed = false; - OpOperand *failedValue = nullptr; - visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) { - Operation *operandOp = operand->get().getDefiningOp(); - if (operandOp && !domInfo.properlyDominates(operandOp, source, - /*enclosingOpOk=*/false)) { - // `operand` is not an argument of an enclosing block and the defining - // op of `operand` is outside `target` but does not dominate `source`. - failed = true; - failedValue = operand; - } - }); - - if (failed) - return emitSilenceableFailure(failedValue->getOwner()) - << "values used inside regions of target should be properly " - "dominated by source"; - } - - return DiagnosedSilenceableFailure::success(); -} - -/// Check if `target` scf.forall can be fused into `source` scf.forall. -/// -/// This simply checks if both loops have the same bounds, steps and mapping. -/// No attempt is made at checking that the side effects of `target` and -/// `source` are independent of each other. -static bool isForallWithIdenticalConfiguration(Operation *target, - Operation *source) { - auto targetOp = dyn_cast<scf::ForallOp>(target); - auto sourceOp = dyn_cast<scf::ForallOp>(source); - if (!targetOp || !sourceOp) - return false; - - return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() && - targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() && - targetOp.getMixedStep() == sourceOp.getMixedStep() && - targetOp.getMapping() == sourceOp.getMapping(); -} - -/// Check if `target` scf.for can be fused into `source` scf.for. -/// -/// This simply checks if both loops have the same bounds and steps. No attempt -/// is made at checking that the side effects of `target` and `source` are -/// independent of each other. -static bool isForWithIdenticalConfiguration(Operation *target, - Operation *source) { - auto targetOp = dyn_cast<scf::ForOp>(target); - auto sourceOp = dyn_cast<scf::ForOp>(source); - if (!targetOp || !sourceOp) - return false; - - return targetOp.getLowerBound() == sourceOp.getLowerBound() && - targetOp.getUpperBound() == sourceOp.getUpperBound() && - targetOp.getStep() == sourceOp.getStep(); -} - DiagnosedSilenceableFailure transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, @@ -569,25 +464,32 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter, << "source handle (got " << llvm::range_size(sourceOps) << ")"; } - Operation *target = *targetOps.begin(); - Operation *source = *sourceOps.begin(); + auto target = dyn_cast<LoopLikeOpInterface>(*targetOps.begin()); + auto source = dyn_cast<LoopLikeOpInterface>(*sourceOps.begin()); + if (!target || !source) + return emitSilenceableFailure(target->getLoc()) + << "target or source is not a loop op"; - // Check if the target and source are siblings. - DiagnosedSilenceableFailure diag = isOpSibling(target, source); - if (!diag.succeeded()) - return diag; + // Check if loops can be fused + Diagnostic diag(target.getLoc(), DiagnosticSeverity::Error); + if (!mlir::checkFusionStructuralLegality(target, source, diag)) + return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); Operation *fusedLoop; - /// TODO: Support fusion for loop-like ops besides scf.for and scf.forall. - if (isForWithIdenticalConfiguration(target, source)) { + // TODO: Support fusion for loop-like ops besides scf.for, scf.forall + // and scf.parallel. + if (isa<scf::ForOp>(target) && isa<scf::ForOp>(source)) { fusedLoop = fuseIndependentSiblingForLoops( cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter); - } else if (isForallWithIdenticalConfiguration(target, source)) { + } else if (isa<scf::ForallOp>(target) && isa<scf::ForallOp>(source)) { fusedLoop = fuseIndependentSiblingForallLoops( cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter); + } else if (isa<scf::ParallelOp>(target) && isa<scf::ParallelOp>(source)) { + fusedLoop = fuseIndependentSiblingParallelLoops( + cast<scf::ParallelOp>(target), cast<scf::ParallelOp>(source), rewriter); } else return emitSilenceableFailure(target->getLoc()) - << "operations cannot be fused"; + << "unsupported loop type for fusion"; assert(fusedLoop && "failed to fuse operations"); diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp index 5934d85..b775f98 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Transforms.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/OpDefinition.h" @@ -37,24 +38,6 @@ static bool hasNestedParallelOp(ParallelOp ploop) { return walkResult.wasInterrupted(); } -/// Verify equal iteration spaces. -static bool equalIterationSpaces(ParallelOp firstPloop, - ParallelOp secondPloop) { - if (firstPloop.getNumLoops() != secondPloop.getNumLoops()) - return false; - - auto matchOperands = [&](const OperandRange &lhs, - const OperandRange &rhs) -> bool { - // TODO: Extend this to support aliases and equal constants. - return std::equal(lhs.begin(), lhs.end(), rhs.begin()); - }; - return matchOperands(firstPloop.getLowerBound(), - secondPloop.getLowerBound()) && - matchOperands(firstPloop.getUpperBound(), - secondPloop.getUpperBound()) && - matchOperands(firstPloop.getStep(), secondPloop.getStep()); -} - /// Checks if the parallel loops have mixed access to the same buffers. Returns /// `true` if the first parallel loop writes to the same indices that the second /// loop reads. @@ -153,9 +136,10 @@ verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop, static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop, const IRMapping &firstToSecondPloopIndices, llvm::function_ref<bool(Value, Value)> mayAlias) { + Diagnostic diag(firstPloop.getLoc(), DiagnosticSeverity::Remark); return !hasNestedParallelOp(firstPloop) && !hasNestedParallelOp(secondPloop) && - equalIterationSpaces(firstPloop, secondPloop) && + checkFusionStructuralLegality(firstPloop, secondPloop, diag) && succeeded(verifyDependencies(firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias)); } @@ -174,61 +158,9 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop, mayAlias)) return; - DominanceInfo dom; - // We are fusing first loop into second, make sure there are no users of the - // first loop results between loops. - for (Operation *user : firstPloop->getUsers()) - if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false)) - return; - - ValueRange inits1 = firstPloop.getInitVals(); - ValueRange inits2 = secondPloop.getInitVals(); - - SmallVector<Value> newInitVars(inits1.begin(), inits1.end()); - newInitVars.append(inits2.begin(), inits2.end()); - - IRRewriter b(builder); - b.setInsertionPoint(secondPloop); - auto newSecondPloop = b.create<ParallelOp>( - secondPloop.getLoc(), secondPloop.getLowerBound(), - secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars); - - Block *newBlock = newSecondPloop.getBody(); - auto term1 = cast<ReduceOp>(block1->getTerminator()); - auto term2 = cast<ReduceOp>(block2->getTerminator()); - - b.inlineBlockBefore(block2, newBlock, newBlock->begin(), - newBlock->getArguments()); - b.inlineBlockBefore(block1, newBlock, newBlock->begin(), - newBlock->getArguments()); - - ValueRange results = newSecondPloop.getResults(); - if (!results.empty()) { - b.setInsertionPointToEnd(newBlock); - - ValueRange reduceArgs1 = term1.getOperands(); - ValueRange reduceArgs2 = term2.getOperands(); - SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end()); - newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end()); - - auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs); - - for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>( - term1.getReductions(), term2.getReductions()))) { - Block &oldRedBlock = reg.front(); - Block &newRedBlock = newReduceOp.getReductions()[i].front(); - b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(), - newRedBlock.getArguments()); - } - - firstPloop.replaceAllUsesWith(results.take_front(inits1.size())); - secondPloop.replaceAllUsesWith(results.take_back(inits2.size())); - } - term1->erase(); - term2->erase(); - firstPloop.erase(); - secondPloop.erase(); - secondPloop = newSecondPloop; + IRRewriter rewriter(builder); + secondPloop = mlir::fuseIndependentSiblingParallelLoops( + firstPloop, secondPloop, rewriter); } void mlir::scf::naivelyFuseParallelOps( diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index c0ee9d2..abfc9a1 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dominance.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" @@ -1262,54 +1263,131 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp, return tileLoops; } +//===----------------------------------------------------------------------===// +// Fusion related helpers +//===----------------------------------------------------------------------===// + +/// Check if `target` and `source` are siblings, in the context that `target` +/// is being fused into `source`. +/// +/// This is a simple check that just checks if both operations are in the same +/// block and some checks to ensure that the fused IR does not violate +/// dominance. +static bool isOpSibling(Operation *target, Operation *source, + Diagnostic &diag) { + // Check if both operations are same. + if (target == source) { + diag << "target and source need to be different loops"; + return false; + } + + // Check if both operations are in the same block. + if (target->getBlock() != source->getBlock()) { + diag << "target and source are not in the same block"; + return false; + } + + // Check if fusion will violate dominance. + DominanceInfo domInfo(source); + if (target->isBeforeInBlock(source)) { + // Since `target` is before `source`, all users of results of `target` + // need to be dominated by `source`. + for (Operation *user : target->getUsers()) { + if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) { + diag << "user of results of target should " + "be properly dominated by source"; + return false; + } + } + } else { + // Since `target` is after `source`, all values used by `target` need + // to dominate `source`. + + // Check if operands of `target` are dominated by `source`. + for (Value operand : target->getOperands()) { + Operation *operandOp = operand.getDefiningOp(); + // Operands without defining operations are block arguments. When `target` + // and `source` occur in the same block, these operands dominate `source`. + if (!operandOp) + continue; + + // Operand's defining operation should properly dominate `source`. + if (!domInfo.properlyDominates(operandOp, source, + /*enclosingOpOk=*/false)) { + diag << "operands of target should be properly dominated by source"; + return false; + } + } + + // Check if values used by `target` are dominated by `source`. + bool failed = false; + OpOperand *failedValue = nullptr; + visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) { + Operation *operandOp = operand->get().getDefiningOp(); + if (operandOp && !domInfo.properlyDominates(operandOp, source, + /*enclosingOpOk=*/false)) { + // `operand` is not an argument of an enclosing block and the defining + // op of `operand` is outside `target` but does not dominate `source`. + failed = true; + failedValue = operand; + } + }); + + if (failed) { + diag << "values used inside regions of target should be properly " + "dominated by source"; + diag.attachNote(failedValue->getOwner()->getLoc()) << "see operation"; + return false; + } + } + + return true; +} + +bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface target, + LoopLikeOpInterface source, + Diagnostic &diag) { + if (target->getName() != source->getName()) { + diag << "target and source must be same loop type"; + return false; + } + + bool iterSpaceEq = + target.getLoopLowerBounds() == source.getLoopLowerBounds() && + target.getLoopUpperBounds() == source.getLoopUpperBounds() && + target.getLoopSteps() == source.getLoopSteps(); + // TODO: Decouple checks on concrete loop types and move this function + // somewhere for general utility for `LoopLikeOpInterface` + if (auto forAllTarget = dyn_cast<scf::ForallOp>(*target)) + iterSpaceEq = iterSpaceEq && forAllTarget.getMapping() == + cast<scf::ForallOp>(*source).getMapping(); + if (!iterSpaceEq) { + diag << "target and source iteration spaces must be equal"; + return false; + } + return isOpSibling(target, source, diag); +} + scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForallOp source, RewriterBase &rewriter) { - unsigned numTargetOuts = target.getNumResults(); - unsigned numSourceOuts = source.getNumResults(); - - // Create fused shared_outs. - SmallVector<Value> fusedOuts; - llvm::append_range(fusedOuts, target.getOutputs()); - llvm::append_range(fusedOuts, source.getOutputs()); - - // Create a new scf.forall op after the source loop. - rewriter.setInsertionPointAfter(source); - scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>( - source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(), - source.getMixedStep(), fusedOuts, source.getMapping()); - - // Map control operands. - IRMapping mapping; - mapping.map(target.getInductionVars(), fusedLoop.getInductionVars()); - mapping.map(source.getInductionVars(), fusedLoop.getInductionVars()); - - // Map shared outs. - mapping.map(target.getRegionIterArgs(), - fusedLoop.getRegionIterArgs().take_front(numTargetOuts)); - mapping.map(source.getRegionIterArgs(), - fusedLoop.getRegionIterArgs().take_back(numSourceOuts)); - - // Append everything except the terminator into the fused operation. - rewriter.setInsertionPointToStart(fusedLoop.getBody()); - for (Operation &op : target.getBody()->without_terminator()) - rewriter.clone(op, mapping); - for (Operation &op : source.getBody()->without_terminator()) - rewriter.clone(op, mapping); - - // Fuse the old terminator in_parallel ops into the new one. - scf::InParallelOp targetTerm = target.getTerminator(); - scf::InParallelOp sourceTerm = source.getTerminator(); - scf::InParallelOp fusedTerm = fusedLoop.getTerminator(); - rewriter.setInsertionPointToStart(fusedTerm.getBody()); - for (Operation &op : targetTerm.getYieldingOps()) - rewriter.clone(op, mapping); - for (Operation &op : sourceTerm.getYieldingOps()) - rewriter.clone(op, mapping); - - // Replace old loops by substituting their uses by results of the fused loop. - rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts)); - rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts)); + scf::ForallOp fusedLoop = cast<scf::ForallOp>(createFused( + target, source, rewriter, + [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) { + // `ForallOp` does not have yields, rather an `InParallelOp` terminator. + return ValueRange{}; + }, + [&](RewriterBase &b, LoopLikeOpInterface source, + LoopLikeOpInterface &target, IRMapping mapping) { + auto sourceForall = cast<scf::ForallOp>(source); + auto targetForall = cast<scf::ForallOp>(target); + scf::InParallelOp fusedTerm = targetForall.getTerminator(); + b.setInsertionPointToEnd(fusedTerm.getBody()); + for (Operation &op : sourceForall.getTerminator().getYieldingOps()) + b.clone(op, mapping); + })); + rewriter.replaceOp(source, + fusedLoop.getResults().take_back(source.getNumResults())); return fusedLoop; } @@ -1317,49 +1395,74 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source, RewriterBase &rewriter) { - unsigned numTargetOuts = target.getNumResults(); - unsigned numSourceOuts = source.getNumResults(); - - // Create fused init_args, with target's init_args before source's init_args. - SmallVector<Value> fusedInitArgs; - llvm::append_range(fusedInitArgs, target.getInitArgs()); - llvm::append_range(fusedInitArgs, source.getInitArgs()); - - // Create a new scf.for op after the source loop (with scf.yield terminator - // (without arguments) only in case its init_args is empty). - rewriter.setInsertionPointAfter(source); - scf::ForOp fusedLoop = rewriter.create<scf::ForOp>( - source.getLoc(), source.getLowerBound(), source.getUpperBound(), - source.getStep(), fusedInitArgs); - - // Map original induction variables and operands to those of the fused loop. - IRMapping mapping; - mapping.map(target.getInductionVar(), fusedLoop.getInductionVar()); - mapping.map(target.getRegionIterArgs(), - fusedLoop.getRegionIterArgs().take_front(numTargetOuts)); - mapping.map(source.getInductionVar(), fusedLoop.getInductionVar()); - mapping.map(source.getRegionIterArgs(), - fusedLoop.getRegionIterArgs().take_back(numSourceOuts)); - - // Merge target's body into the new (fused) for loop and then source's body. - rewriter.setInsertionPointToStart(fusedLoop.getBody()); - for (Operation &op : target.getBody()->without_terminator()) - rewriter.clone(op, mapping); - for (Operation &op : source.getBody()->without_terminator()) - rewriter.clone(op, mapping); - - // Build fused yield results by appropriately mapping original yield operands. - SmallVector<Value> yieldResults; - for (Value operand : target.getBody()->getTerminator()->getOperands()) - yieldResults.push_back(mapping.lookupOrDefault(operand)); - for (Value operand : source.getBody()->getTerminator()->getOperands()) - yieldResults.push_back(mapping.lookupOrDefault(operand)); - if (!yieldResults.empty()) - rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults); - - // Replace old loops by substituting their uses by results of the fused loop. - rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts)); - rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts)); + scf::ForOp fusedLoop = cast<scf::ForOp>(createFused( + target, source, rewriter, + [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) { + return source.getYieldedValues(); + }, + [&](RewriterBase &b, LoopLikeOpInterface source, + LoopLikeOpInterface &target, IRMapping mapping) { + auto targetFor = cast<scf::ForOp>(target); + auto newTerm = b.clone(*targetFor.getBody()->getTerminator(), mapping); + b.replaceOp(targetFor.getBody()->getTerminator(), newTerm); + })); + rewriter.replaceOp(source, + fusedLoop.getResults().take_back(source.getNumResults())); + return fusedLoop; +} + +// TODO: Finish refactoring this a la the above, but likely requires additional +// interface methods. +scf::ParallelOp mlir::fuseIndependentSiblingParallelLoops( + scf::ParallelOp target, scf::ParallelOp source, RewriterBase &rewriter) { + OpBuilder::InsertionGuard guard(rewriter); + Block *block1 = target.getBody(); + Block *block2 = source.getBody(); + auto term1 = cast<scf::ReduceOp>(block1->getTerminator()); + auto term2 = cast<scf::ReduceOp>(block2->getTerminator()); + + ValueRange inits1 = target.getInitVals(); + ValueRange inits2 = source.getInitVals(); + + SmallVector<Value> newInitVars(inits1.begin(), inits1.end()); + newInitVars.append(inits2.begin(), inits2.end()); + + rewriter.setInsertionPoint(source); + auto fusedLoop = rewriter.create<scf::ParallelOp>( + rewriter.getFusedLoc(target.getLoc(), source.getLoc()), + source.getLowerBound(), source.getUpperBound(), source.getStep(), + newInitVars); + Block *newBlock = fusedLoop.getBody(); + rewriter.inlineBlockBefore(block2, newBlock, newBlock->begin(), + newBlock->getArguments()); + rewriter.inlineBlockBefore(block1, newBlock, newBlock->begin(), + newBlock->getArguments()); + + ValueRange results = fusedLoop.getResults(); + if (!results.empty()) { + rewriter.setInsertionPointToEnd(newBlock); + + ValueRange reduceArgs1 = term1.getOperands(); + ValueRange reduceArgs2 = term2.getOperands(); + SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end()); + newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end()); + + auto newReduceOp = rewriter.create<scf::ReduceOp>( + rewriter.getFusedLoc(term1.getLoc(), term2.getLoc()), newReduceArgs); + + for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>( + term1.getReductions(), term2.getReductions()))) { + Block &oldRedBlock = reg.front(); + Block &newRedBlock = newReduceOp.getReductions()[i].front(); + rewriter.inlineBlockBefore(&oldRedBlock, &newRedBlock, + newRedBlock.begin(), + newRedBlock.getArguments()); + } + } + rewriter.replaceOp(target, results.take_front(inits1.size())); + rewriter.replaceOp(source, results.take_back(inits2.size())); + rewriter.eraseOp(term1); + rewriter.eraseOp(term2); return fusedLoop; } diff --git a/mlir/lib/Interfaces/LoopLikeInterface.cpp b/mlir/lib/Interfaces/LoopLikeInterface.cpp index 1e0e87b..6f0ebec 100644 --- a/mlir/lib/Interfaces/LoopLikeInterface.cpp +++ b/mlir/lib/Interfaces/LoopLikeInterface.cpp @@ -8,6 +8,8 @@ #include "mlir/Interfaces/LoopLikeInterface.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "llvm/ADT/DenseSet.h" @@ -113,3 +115,56 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) { return success(); } + +LoopLikeOpInterface mlir::createFused(LoopLikeOpInterface target, + LoopLikeOpInterface source, + RewriterBase &rewriter, + NewYieldValuesFn newYieldValuesFn, + FuseTerminatorFn fuseTerminatorFn) { + auto targetIterArgs = target.getRegionIterArgs(); + std::optional<SmallVector<Value>> targetInductionVar = + target.getLoopInductionVars(); + SmallVector<Value> targetYieldOperands(target.getYieldedValues()); + auto sourceIterArgs = source.getRegionIterArgs(); + std::optional<SmallVector<Value>> sourceInductionVar = + *source.getLoopInductionVars(); + SmallVector<Value> sourceYieldOperands(source.getYieldedValues()); + auto sourceRegion = source.getLoopRegions().front(); + + FailureOr<LoopLikeOpInterface> maybeFusedLoop = + target.replaceWithAdditionalYields(rewriter, source.getInits(), + /*replaceInitOperandUsesInLoop=*/false, + newYieldValuesFn); + if (failed(maybeFusedLoop)) + llvm_unreachable("failed to replace loop"); + LoopLikeOpInterface fusedLoop = *maybeFusedLoop; + + // Map control operands. + IRMapping mapping; + std::optional<SmallVector<Value>> fusedInductionVar = + fusedLoop.getLoopInductionVars(); + if (fusedInductionVar) { + if (!targetInductionVar || !sourceInductionVar) + llvm_unreachable("expected target and source loops to have induction vars"); + mapping.map(*targetInductionVar, *fusedInductionVar); + mapping.map(*sourceInductionVar, *fusedInductionVar); + } + mapping.map(targetIterArgs, + fusedLoop.getRegionIterArgs().take_front(targetIterArgs.size())); + mapping.map(targetYieldOperands, + fusedLoop.getYieldedValues().take_front(targetIterArgs.size())); + mapping.map(sourceIterArgs, + fusedLoop.getRegionIterArgs().take_back(sourceIterArgs.size())); + mapping.map(sourceYieldOperands, + fusedLoop.getYieldedValues().take_back(sourceIterArgs.size())); + // Append everything except the terminator into the fused operation. + rewriter.setInsertionPoint( + fusedLoop.getLoopRegions().front()->front().getTerminator()); + for (Operation &op : sourceRegion->front().without_terminator()) + rewriter.clone(op, mapping); + + // TODO: Replace with corresponding interface method if added + fuseTerminatorFn(rewriter, source, fusedLoop, mapping); + + return fusedLoop; +} diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir index 54dd2bd..91ed2a5 100644 --- a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir +++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir @@ -47,6 +47,169 @@ module attributes {transform.with_named_sequence} { // ----- +// CHECK-LABEL: func @fuse_two_parallel +// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) { +func.func @fuse_two_parallel(%A: memref<2x2xf32>, %B: memref<2x2xf32>) { +// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index +// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index +// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index +// CHECK-DAG: [[C1FP:%.*]] = arith.constant 1. + %c2 = arith.constant 2 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1fp = arith.constant 1.0 : f32 +// CHECK: [[SUM:%.*]] = memref.alloc() + %sum = memref.alloc() : memref<2x2xf32> +// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) +// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) { +// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]] +// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]] +// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]] +// CHECK-NOT: scf.parallel +// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]] +// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]] +// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]] +// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]] +// CHECK: scf.reduce +// CHECK: } + scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { + %B_elem = memref.load %B[%i, %j] : memref<2x2xf32> + %sum_elem = arith.addf %B_elem, %c1fp : f32 + memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32> + scf.reduce + } + scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { + %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32> + %A_elem = memref.load %A[%i, %j] : memref<2x2xf32> + %product_elem = arith.mulf %sum_elem, %A_elem : f32 + memref.store %product_elem, %B[%i, %j] : memref<2x2xf32> + scf.reduce + } +// CHECK: memref.dealloc [[SUM]] + memref.dealloc %sum : memref<2x2xf32> + return +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %parallel:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %fused = transform.loop.fuse_sibling %parallel#0 into %parallel#1 : (!transform.any_op,!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func @fuse_two_parallel_reverse +// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) { +func.func @fuse_two_parallel_reverse(%A: memref<2x2xf32>, %B: memref<2x2xf32>) { +// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index +// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index +// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index +// CHECK-DAG: [[C1FP:%.*]] = arith.constant 1. + %c2 = arith.constant 2 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1fp = arith.constant 1.0 : f32 +// CHECK: [[SUM:%.*]] = memref.alloc() + %sum = memref.alloc() : memref<2x2xf32> +// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) +// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) { +// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]] +// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]] +// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]] +// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]] +// CHECK-NOT: scf.parallel +// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]] +// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]] +// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]] +// CHECK: scf.reduce +// CHECK: } + scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { + %B_elem = memref.load %B[%i, %j] : memref<2x2xf32> + %sum_elem = arith.addf %B_elem, %c1fp : f32 + memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32> + scf.reduce + } + scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { + %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32> + %A_elem = memref.load %A[%i, %j] : memref<2x2xf32> + %product_elem = arith.mulf %sum_elem, %A_elem : f32 + memref.store %product_elem, %B[%i, %j] : memref<2x2xf32> + scf.reduce + } +// CHECK: memref.dealloc [[SUM]] + memref.dealloc %sum : memref<2x2xf32> + return +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %parallel:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %fused = transform.loop.fuse_sibling %parallel#1 into %parallel#0 : (!transform.any_op,!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func @fuse_reductions_two +// CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>) -> (f32, f32) +func.func @fuse_reductions_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) { +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[INIT1:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: %[[INIT2:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK: %[[RES:.*]]:2 = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) +// CHECK-SAME: to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]]) +// CHECK-SAME: init (%[[INIT1]], %[[INIT2]]) -> (f32, f32) +// CHECK: %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]] +// CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]] +// CHECK: scf.reduce(%[[VAL_A]], %[[VAL_B]] : f32, f32) { +// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32): +// CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32 +// CHECK: scf.reduce.return %[[R]] : f32 +// CHECK: } +// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32): +// CHECK: %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32 +// CHECK: scf.reduce.return %[[R]] : f32 +// CHECK: } +// CHECK: return %[[RES]]#0, %[[RES]]#1 : f32, f32 + %c2 = arith.constant 2 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %init1 = arith.constant 1.0 : f32 + %init2 = arith.constant 2.0 : f32 + %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 { + %A_elem = memref.load %A[%i, %j] : memref<2x2xf32> + scf.reduce(%A_elem : f32) { + ^bb0(%lhs: f32, %rhs: f32): + %1 = arith.addf %lhs, %rhs : f32 + scf.reduce.return %1 : f32 + } + } + %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 { + %B_elem = memref.load %B[%i, %j] : memref<2x2xf32> + scf.reduce(%B_elem : f32) { + ^bb0(%lhs: f32, %rhs: f32): + %1 = arith.mulf %lhs, %rhs : f32 + scf.reduce.return %1 : f32 + } + } + return %res1, %res2 : f32, f32 +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %parallel:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %fused = transform.loop.fuse_sibling %parallel#0 into %parallel#1 : (!transform.any_op,!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + // CHECK: func.func @fuse_2nd_for_into_1st([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}} func.func @fuse_2nd_for_into_1st(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) { // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index @@ -282,8 +445,9 @@ func.func @target_for_region_uses_result_of_source_for_err(%A: tensor<128xf32>, %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32> scf.yield %6 : tensor<128xf32> } - %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %B) -> (tensor<128xf32>) { // expected-error @below {{values used inside regions of target should be properly dominated by source}} + %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %B) -> (tensor<128xf32>) { + // expected-note @below {{see operation}} %dup2 = vector.transfer_read %1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32> %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32> %dup5 = arith.addf %dup3, %dup2 : vector<16xf32> @@ -328,6 +492,74 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +func.func @non_matching_iteration_spaces_err(%A: memref<2x2xf32>, %B: memref<2x2xf32>) { + %c2 = arith.constant 2 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1fp = arith.constant 1.0 : f32 + %sum = memref.alloc() : memref<2x2xf32> + // expected-error @below {{target and source iteration spaces must be equal}} + scf.parallel (%i) = (%c0) to (%c2) step (%c1) { + %B_elem = memref.load %B[%i, %c0] : memref<2x2xf32> + %sum_elem = arith.addf %B_elem, %c1fp : f32 + memref.store %sum_elem, %sum[%i, %c0] : memref<2x2xf32> + scf.reduce + } + scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { + %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32> + %A_elem = memref.load %A[%i, %j] : memref<2x2xf32> + %product_elem = arith.mulf %sum_elem, %A_elem : f32 + memref.store %product_elem, %B[%i, %j] : memref<2x2xf32> + scf.reduce + } + memref.dealloc %sum : memref<2x2xf32> + return +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %parallel:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %fused = transform.loop.fuse_sibling %parallel#0 into %parallel#1 : (!transform.any_op,!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +func.func @non_matching_loop_types_err(%A: memref<2xf32>, %B: memref<2xf32>) { + %c2 = arith.constant 2 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1fp = arith.constant 1.0 : f32 + %sum = memref.alloc() : memref<2xf32> + // expected-error @below {{target and source must be same loop type}} + scf.for %i = %c0 to %c2 step %c1 { + %B_elem = memref.load %B[%i] : memref<2xf32> + %sum_elem = arith.addf %B_elem, %c1fp : f32 + memref.store %sum_elem, %sum[%i] : memref<2xf32> + } + scf.parallel (%i) = (%c0) to (%c2) step (%c1) { + %sum_elem = memref.load %sum[%i] : memref<2xf32> + %A_elem = memref.load %A[%i] : memref<2xf32> + %product_elem = arith.mulf %sum_elem, %A_elem : f32 + memref.store %product_elem, %B[%i] : memref<2xf32> + scf.reduce + } + memref.dealloc %sum : memref<2xf32> + return +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %fused = transform.loop.fuse_sibling %0 into %1 : (!transform.any_op,!transform.any_op) -> !transform.any_op + transform.yield + } +} + // ----- // CHECK: func.func @foreach_loop_pair_fuse([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}} |