diff options
author | srcarroll <50210727+srcarroll@users.noreply.github.com> | 2024-07-02 11:12:51 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-07-02 11:12:51 -0500 |
commit | 6820b0871807abff07df118659e0de2ca741cb0b (patch) | |
tree | e6d4b4637118cd50e3cf4dfc0a1a421e72ca1cc8 /mlir/lib | |
parent | 6c3897d90eda4c39789ac9f4efa51db46734a249 (diff) | |
download | llvm-6820b0871807abff07df118659e0de2ca741cb0b.zip llvm-6820b0871807abff07df118659e0de2ca741cb0b.tar.gz llvm-6820b0871807abff07df118659e0de2ca741cb0b.tar.bz2 |
Refactor LoopFuseSiblingOp and support parallel fusion (#94391)
This patch refactors code related to `LoopFuseSiblingOp` transform in
attempt to reduce duplicate common code. The aim is to refactor as much
as possible to a functions on `LoopLikeOpInterface`s, but this is still
a work in progress. A full refactor will require more additions to the
`LoopLikeOpInterface`.
In addition, `scf.parallel` fusion support has been added.
Diffstat (limited to 'mlir/lib')
-rw-r--r-- | mlir/lib/Dialect/SCF/IR/SCF.cpp | 38 | ||||
-rw-r--r-- | mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp | 140 | ||||
-rw-r--r-- | mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp | 80 | ||||
-rw-r--r-- | mlir/lib/Dialect/SCF/Utils/Utils.cpp | 279 | ||||
-rw-r--r-- | mlir/lib/Interfaces/LoopLikeInterface.cpp | 55 |
5 files changed, 311 insertions, 281 deletions
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 907d7f7..cb15e0e 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -618,6 +618,44 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point, SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; } +FailureOr<LoopLikeOpInterface> ForallOp::replaceWithAdditionalYields( + RewriterBase &rewriter, ValueRange newInitOperands, + bool replaceInitOperandUsesInLoop, + const NewYieldValuesFn &newYieldValuesFn) { + // Create a new loop before the existing one, with the extra operands. + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(getOperation()); + SmallVector<Value> inits(getOutputs()); + llvm::append_range(inits, newInitOperands); + scf::ForallOp newLoop = rewriter.create<scf::ForallOp>( + getLoc(), getMixedLowerBound(), getMixedUpperBound(), getMixedStep(), + inits, getMapping(), + /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {}); + + // Move the loop body to the new op. + rewriter.mergeBlocks(getBody(), newLoop.getBody(), + newLoop.getBody()->getArguments().take_front( + getBody()->getNumArguments())); + + if (replaceInitOperandUsesInLoop) { + // Replace all uses of `newInitOperands` with the corresponding basic block + // arguments. + for (auto &&[newOperand, oldOperand] : + llvm::zip(newInitOperands, newLoop.getBody()->getArguments().take_back( + newInitOperands.size()))) { + rewriter.replaceUsesWithIf(newOperand, oldOperand, [&](OpOperand &use) { + Operation *user = use.getOwner(); + return newLoop->isProperAncestor(user); + }); + } + } + + // Replace the old loop. + rewriter.replaceOp(getOperation(), + newLoop->getResults().take_front(getNumResults())); + return cast<LoopLikeOpInterface>(newLoop.getOperation()); +} + /// Promotes the loop body of a forallOp to its containing block if it can be /// determined that the loop has a single iteration. LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) { diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp index 56ff270..41834fe 100644 --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -261,8 +261,10 @@ loopScheduling(scf::ForOp forOp, return 1; }; - std::optional<int64_t> ubConstant = getConstantIntValue(forOp.getUpperBound()); - std::optional<int64_t> lbConstant = getConstantIntValue(forOp.getLowerBound()); + std::optional<int64_t> ubConstant = + getConstantIntValue(forOp.getUpperBound()); + std::optional<int64_t> lbConstant = + getConstantIntValue(forOp.getLowerBound()); DenseMap<Operation *, unsigned> opCycles; std::map<unsigned, std::vector<Operation *>> wrappedSchedule; for (Operation &op : forOp.getBody()->getOperations()) { @@ -447,113 +449,6 @@ void transform::TakeAssumedBranchOp::getEffects( // LoopFuseSiblingOp //===----------------------------------------------------------------------===// -/// Check if `target` and `source` are siblings, in the context that `target` -/// is being fused into `source`. -/// -/// This is a simple check that just checks if both operations are in the same -/// block and some checks to ensure that the fused IR does not violate -/// dominance. -static DiagnosedSilenceableFailure isOpSibling(Operation *target, - Operation *source) { - // Check if both operations are same. - if (target == source) - return emitSilenceableFailure(source) - << "target and source need to be different loops"; - - // Check if both operations are in the same block. - if (target->getBlock() != source->getBlock()) - return emitSilenceableFailure(source) - << "target and source are not in the same block"; - - // Check if fusion will violate dominance. - DominanceInfo domInfo(source); - if (target->isBeforeInBlock(source)) { - // Since `target` is before `source`, all users of results of `target` - // need to be dominated by `source`. - for (Operation *user : target->getUsers()) { - if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) { - return emitSilenceableFailure(target) - << "user of results of target should be properly dominated by " - "source"; - } - } - } else { - // Since `target` is after `source`, all values used by `target` need - // to dominate `source`. - - // Check if operands of `target` are dominated by `source`. - for (Value operand : target->getOperands()) { - Operation *operandOp = operand.getDefiningOp(); - // Operands without defining operations are block arguments. When `target` - // and `source` occur in the same block, these operands dominate `source`. - if (!operandOp) - continue; - - // Operand's defining operation should properly dominate `source`. - if (!domInfo.properlyDominates(operandOp, source, - /*enclosingOpOk=*/false)) - return emitSilenceableFailure(target) - << "operands of target should be properly dominated by source"; - } - - // Check if values used by `target` are dominated by `source`. - bool failed = false; - OpOperand *failedValue = nullptr; - visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) { - Operation *operandOp = operand->get().getDefiningOp(); - if (operandOp && !domInfo.properlyDominates(operandOp, source, - /*enclosingOpOk=*/false)) { - // `operand` is not an argument of an enclosing block and the defining - // op of `operand` is outside `target` but does not dominate `source`. - failed = true; - failedValue = operand; - } - }); - - if (failed) - return emitSilenceableFailure(failedValue->getOwner()) - << "values used inside regions of target should be properly " - "dominated by source"; - } - - return DiagnosedSilenceableFailure::success(); -} - -/// Check if `target` scf.forall can be fused into `source` scf.forall. -/// -/// This simply checks if both loops have the same bounds, steps and mapping. -/// No attempt is made at checking that the side effects of `target` and -/// `source` are independent of each other. -static bool isForallWithIdenticalConfiguration(Operation *target, - Operation *source) { - auto targetOp = dyn_cast<scf::ForallOp>(target); - auto sourceOp = dyn_cast<scf::ForallOp>(source); - if (!targetOp || !sourceOp) - return false; - - return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() && - targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() && - targetOp.getMixedStep() == sourceOp.getMixedStep() && - targetOp.getMapping() == sourceOp.getMapping(); -} - -/// Check if `target` scf.for can be fused into `source` scf.for. -/// -/// This simply checks if both loops have the same bounds and steps. No attempt -/// is made at checking that the side effects of `target` and `source` are -/// independent of each other. -static bool isForWithIdenticalConfiguration(Operation *target, - Operation *source) { - auto targetOp = dyn_cast<scf::ForOp>(target); - auto sourceOp = dyn_cast<scf::ForOp>(source); - if (!targetOp || !sourceOp) - return false; - - return targetOp.getLowerBound() == sourceOp.getLowerBound() && - targetOp.getUpperBound() == sourceOp.getUpperBound() && - targetOp.getStep() == sourceOp.getStep(); -} - DiagnosedSilenceableFailure transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, @@ -569,25 +464,32 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter, << "source handle (got " << llvm::range_size(sourceOps) << ")"; } - Operation *target = *targetOps.begin(); - Operation *source = *sourceOps.begin(); + auto target = dyn_cast<LoopLikeOpInterface>(*targetOps.begin()); + auto source = dyn_cast<LoopLikeOpInterface>(*sourceOps.begin()); + if (!target || !source) + return emitSilenceableFailure(target->getLoc()) + << "target or source is not a loop op"; - // Check if the target and source are siblings. - DiagnosedSilenceableFailure diag = isOpSibling(target, source); - if (!diag.succeeded()) - return diag; + // Check if loops can be fused + Diagnostic diag(target.getLoc(), DiagnosticSeverity::Error); + if (!mlir::checkFusionStructuralLegality(target, source, diag)) + return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); Operation *fusedLoop; - /// TODO: Support fusion for loop-like ops besides scf.for and scf.forall. - if (isForWithIdenticalConfiguration(target, source)) { + // TODO: Support fusion for loop-like ops besides scf.for, scf.forall + // and scf.parallel. + if (isa<scf::ForOp>(target) && isa<scf::ForOp>(source)) { fusedLoop = fuseIndependentSiblingForLoops( cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter); - } else if (isForallWithIdenticalConfiguration(target, source)) { + } else if (isa<scf::ForallOp>(target) && isa<scf::ForallOp>(source)) { fusedLoop = fuseIndependentSiblingForallLoops( cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter); + } else if (isa<scf::ParallelOp>(target) && isa<scf::ParallelOp>(source)) { + fusedLoop = fuseIndependentSiblingParallelLoops( + cast<scf::ParallelOp>(target), cast<scf::ParallelOp>(source), rewriter); } else return emitSilenceableFailure(target->getLoc()) - << "operations cannot be fused"; + << "unsupported loop type for fusion"; assert(fusedLoop && "failed to fuse operations"); diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp index 5934d85..b775f98 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Transforms.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/OpDefinition.h" @@ -37,24 +38,6 @@ static bool hasNestedParallelOp(ParallelOp ploop) { return walkResult.wasInterrupted(); } -/// Verify equal iteration spaces. -static bool equalIterationSpaces(ParallelOp firstPloop, - ParallelOp secondPloop) { - if (firstPloop.getNumLoops() != secondPloop.getNumLoops()) - return false; - - auto matchOperands = [&](const OperandRange &lhs, - const OperandRange &rhs) -> bool { - // TODO: Extend this to support aliases and equal constants. - return std::equal(lhs.begin(), lhs.end(), rhs.begin()); - }; - return matchOperands(firstPloop.getLowerBound(), - secondPloop.getLowerBound()) && - matchOperands(firstPloop.getUpperBound(), - secondPloop.getUpperBound()) && - matchOperands(firstPloop.getStep(), secondPloop.getStep()); -} - /// Checks if the parallel loops have mixed access to the same buffers. Returns /// `true` if the first parallel loop writes to the same indices that the second /// loop reads. @@ -153,9 +136,10 @@ verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop, static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop, const IRMapping &firstToSecondPloopIndices, llvm::function_ref<bool(Value, Value)> mayAlias) { + Diagnostic diag(firstPloop.getLoc(), DiagnosticSeverity::Remark); return !hasNestedParallelOp(firstPloop) && !hasNestedParallelOp(secondPloop) && - equalIterationSpaces(firstPloop, secondPloop) && + checkFusionStructuralLegality(firstPloop, secondPloop, diag) && succeeded(verifyDependencies(firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias)); } @@ -174,61 +158,9 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop, mayAlias)) return; - DominanceInfo dom; - // We are fusing first loop into second, make sure there are no users of the - // first loop results between loops. - for (Operation *user : firstPloop->getUsers()) - if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false)) - return; - - ValueRange inits1 = firstPloop.getInitVals(); - ValueRange inits2 = secondPloop.getInitVals(); - - SmallVector<Value> newInitVars(inits1.begin(), inits1.end()); - newInitVars.append(inits2.begin(), inits2.end()); - - IRRewriter b(builder); - b.setInsertionPoint(secondPloop); - auto newSecondPloop = b.create<ParallelOp>( - secondPloop.getLoc(), secondPloop.getLowerBound(), - secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars); - - Block *newBlock = newSecondPloop.getBody(); - auto term1 = cast<ReduceOp>(block1->getTerminator()); - auto term2 = cast<ReduceOp>(block2->getTerminator()); - - b.inlineBlockBefore(block2, newBlock, newBlock->begin(), - newBlock->getArguments()); - b.inlineBlockBefore(block1, newBlock, newBlock->begin(), - newBlock->getArguments()); - - ValueRange results = newSecondPloop.getResults(); - if (!results.empty()) { - b.setInsertionPointToEnd(newBlock); - - ValueRange reduceArgs1 = term1.getOperands(); - ValueRange reduceArgs2 = term2.getOperands(); - SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end()); - newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end()); - - auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs); - - for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>( - term1.getReductions(), term2.getReductions()))) { - Block &oldRedBlock = reg.front(); - Block &newRedBlock = newReduceOp.getReductions()[i].front(); - b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(), - newRedBlock.getArguments()); - } - - firstPloop.replaceAllUsesWith(results.take_front(inits1.size())); - secondPloop.replaceAllUsesWith(results.take_back(inits2.size())); - } - term1->erase(); - term2->erase(); - firstPloop.erase(); - secondPloop.erase(); - secondPloop = newSecondPloop; + IRRewriter rewriter(builder); + secondPloop = mlir::fuseIndependentSiblingParallelLoops( + firstPloop, secondPloop, rewriter); } void mlir::scf::naivelyFuseParallelOps( diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index c0ee9d2..abfc9a1 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dominance.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" @@ -1262,54 +1263,131 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp, return tileLoops; } +//===----------------------------------------------------------------------===// +// Fusion related helpers +//===----------------------------------------------------------------------===// + +/// Check if `target` and `source` are siblings, in the context that `target` +/// is being fused into `source`. +/// +/// This is a simple check that just checks if both operations are in the same +/// block and some checks to ensure that the fused IR does not violate +/// dominance. +static bool isOpSibling(Operation *target, Operation *source, + Diagnostic &diag) { + // Check if both operations are same. + if (target == source) { + diag << "target and source need to be different loops"; + return false; + } + + // Check if both operations are in the same block. + if (target->getBlock() != source->getBlock()) { + diag << "target and source are not in the same block"; + return false; + } + + // Check if fusion will violate dominance. + DominanceInfo domInfo(source); + if (target->isBeforeInBlock(source)) { + // Since `target` is before `source`, all users of results of `target` + // need to be dominated by `source`. + for (Operation *user : target->getUsers()) { + if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) { + diag << "user of results of target should " + "be properly dominated by source"; + return false; + } + } + } else { + // Since `target` is after `source`, all values used by `target` need + // to dominate `source`. + + // Check if operands of `target` are dominated by `source`. + for (Value operand : target->getOperands()) { + Operation *operandOp = operand.getDefiningOp(); + // Operands without defining operations are block arguments. When `target` + // and `source` occur in the same block, these operands dominate `source`. + if (!operandOp) + continue; + + // Operand's defining operation should properly dominate `source`. + if (!domInfo.properlyDominates(operandOp, source, + /*enclosingOpOk=*/false)) { + diag << "operands of target should be properly dominated by source"; + return false; + } + } + + // Check if values used by `target` are dominated by `source`. + bool failed = false; + OpOperand *failedValue = nullptr; + visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) { + Operation *operandOp = operand->get().getDefiningOp(); + if (operandOp && !domInfo.properlyDominates(operandOp, source, + /*enclosingOpOk=*/false)) { + // `operand` is not an argument of an enclosing block and the defining + // op of `operand` is outside `target` but does not dominate `source`. + failed = true; + failedValue = operand; + } + }); + + if (failed) { + diag << "values used inside regions of target should be properly " + "dominated by source"; + diag.attachNote(failedValue->getOwner()->getLoc()) << "see operation"; + return false; + } + } + + return true; +} + +bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface target, + LoopLikeOpInterface source, + Diagnostic &diag) { + if (target->getName() != source->getName()) { + diag << "target and source must be same loop type"; + return false; + } + + bool iterSpaceEq = + target.getLoopLowerBounds() == source.getLoopLowerBounds() && + target.getLoopUpperBounds() == source.getLoopUpperBounds() && + target.getLoopSteps() == source.getLoopSteps(); + // TODO: Decouple checks on concrete loop types and move this function + // somewhere for general utility for `LoopLikeOpInterface` + if (auto forAllTarget = dyn_cast<scf::ForallOp>(*target)) + iterSpaceEq = iterSpaceEq && forAllTarget.getMapping() == + cast<scf::ForallOp>(*source).getMapping(); + if (!iterSpaceEq) { + diag << "target and source iteration spaces must be equal"; + return false; + } + return isOpSibling(target, source, diag); +} + scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForallOp source, RewriterBase &rewriter) { - unsigned numTargetOuts = target.getNumResults(); - unsigned numSourceOuts = source.getNumResults(); - - // Create fused shared_outs. - SmallVector<Value> fusedOuts; - llvm::append_range(fusedOuts, target.getOutputs()); - llvm::append_range(fusedOuts, source.getOutputs()); - - // Create a new scf.forall op after the source loop. - rewriter.setInsertionPointAfter(source); - scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>( - source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(), - source.getMixedStep(), fusedOuts, source.getMapping()); - - // Map control operands. - IRMapping mapping; - mapping.map(target.getInductionVars(), fusedLoop.getInductionVars()); - mapping.map(source.getInductionVars(), fusedLoop.getInductionVars()); - - // Map shared outs. - mapping.map(target.getRegionIterArgs(), - fusedLoop.getRegionIterArgs().take_front(numTargetOuts)); - mapping.map(source.getRegionIterArgs(), - fusedLoop.getRegionIterArgs().take_back(numSourceOuts)); - - // Append everything except the terminator into the fused operation. - rewriter.setInsertionPointToStart(fusedLoop.getBody()); - for (Operation &op : target.getBody()->without_terminator()) - rewriter.clone(op, mapping); - for (Operation &op : source.getBody()->without_terminator()) - rewriter.clone(op, mapping); - - // Fuse the old terminator in_parallel ops into the new one. - scf::InParallelOp targetTerm = target.getTerminator(); - scf::InParallelOp sourceTerm = source.getTerminator(); - scf::InParallelOp fusedTerm = fusedLoop.getTerminator(); - rewriter.setInsertionPointToStart(fusedTerm.getBody()); - for (Operation &op : targetTerm.getYieldingOps()) - rewriter.clone(op, mapping); - for (Operation &op : sourceTerm.getYieldingOps()) - rewriter.clone(op, mapping); - - // Replace old loops by substituting their uses by results of the fused loop. - rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts)); - rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts)); + scf::ForallOp fusedLoop = cast<scf::ForallOp>(createFused( + target, source, rewriter, + [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) { + // `ForallOp` does not have yields, rather an `InParallelOp` terminator. + return ValueRange{}; + }, + [&](RewriterBase &b, LoopLikeOpInterface source, + LoopLikeOpInterface &target, IRMapping mapping) { + auto sourceForall = cast<scf::ForallOp>(source); + auto targetForall = cast<scf::ForallOp>(target); + scf::InParallelOp fusedTerm = targetForall.getTerminator(); + b.setInsertionPointToEnd(fusedTerm.getBody()); + for (Operation &op : sourceForall.getTerminator().getYieldingOps()) + b.clone(op, mapping); + })); + rewriter.replaceOp(source, + fusedLoop.getResults().take_back(source.getNumResults())); return fusedLoop; } @@ -1317,49 +1395,74 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source, RewriterBase &rewriter) { - unsigned numTargetOuts = target.getNumResults(); - unsigned numSourceOuts = source.getNumResults(); - - // Create fused init_args, with target's init_args before source's init_args. - SmallVector<Value> fusedInitArgs; - llvm::append_range(fusedInitArgs, target.getInitArgs()); - llvm::append_range(fusedInitArgs, source.getInitArgs()); - - // Create a new scf.for op after the source loop (with scf.yield terminator - // (without arguments) only in case its init_args is empty). - rewriter.setInsertionPointAfter(source); - scf::ForOp fusedLoop = rewriter.create<scf::ForOp>( - source.getLoc(), source.getLowerBound(), source.getUpperBound(), - source.getStep(), fusedInitArgs); - - // Map original induction variables and operands to those of the fused loop. - IRMapping mapping; - mapping.map(target.getInductionVar(), fusedLoop.getInductionVar()); - mapping.map(target.getRegionIterArgs(), - fusedLoop.getRegionIterArgs().take_front(numTargetOuts)); - mapping.map(source.getInductionVar(), fusedLoop.getInductionVar()); - mapping.map(source.getRegionIterArgs(), - fusedLoop.getRegionIterArgs().take_back(numSourceOuts)); - - // Merge target's body into the new (fused) for loop and then source's body. - rewriter.setInsertionPointToStart(fusedLoop.getBody()); - for (Operation &op : target.getBody()->without_terminator()) - rewriter.clone(op, mapping); - for (Operation &op : source.getBody()->without_terminator()) - rewriter.clone(op, mapping); - - // Build fused yield results by appropriately mapping original yield operands. - SmallVector<Value> yieldResults; - for (Value operand : target.getBody()->getTerminator()->getOperands()) - yieldResults.push_back(mapping.lookupOrDefault(operand)); - for (Value operand : source.getBody()->getTerminator()->getOperands()) - yieldResults.push_back(mapping.lookupOrDefault(operand)); - if (!yieldResults.empty()) - rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults); - - // Replace old loops by substituting their uses by results of the fused loop. - rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts)); - rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts)); + scf::ForOp fusedLoop = cast<scf::ForOp>(createFused( + target, source, rewriter, + [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) { + return source.getYieldedValues(); + }, + [&](RewriterBase &b, LoopLikeOpInterface source, + LoopLikeOpInterface &target, IRMapping mapping) { + auto targetFor = cast<scf::ForOp>(target); + auto newTerm = b.clone(*targetFor.getBody()->getTerminator(), mapping); + b.replaceOp(targetFor.getBody()->getTerminator(), newTerm); + })); + rewriter.replaceOp(source, + fusedLoop.getResults().take_back(source.getNumResults())); + return fusedLoop; +} + +// TODO: Finish refactoring this a la the above, but likely requires additional +// interface methods. +scf::ParallelOp mlir::fuseIndependentSiblingParallelLoops( + scf::ParallelOp target, scf::ParallelOp source, RewriterBase &rewriter) { + OpBuilder::InsertionGuard guard(rewriter); + Block *block1 = target.getBody(); + Block *block2 = source.getBody(); + auto term1 = cast<scf::ReduceOp>(block1->getTerminator()); + auto term2 = cast<scf::ReduceOp>(block2->getTerminator()); + + ValueRange inits1 = target.getInitVals(); + ValueRange inits2 = source.getInitVals(); + + SmallVector<Value> newInitVars(inits1.begin(), inits1.end()); + newInitVars.append(inits2.begin(), inits2.end()); + + rewriter.setInsertionPoint(source); + auto fusedLoop = rewriter.create<scf::ParallelOp>( + rewriter.getFusedLoc(target.getLoc(), source.getLoc()), + source.getLowerBound(), source.getUpperBound(), source.getStep(), + newInitVars); + Block *newBlock = fusedLoop.getBody(); + rewriter.inlineBlockBefore(block2, newBlock, newBlock->begin(), + newBlock->getArguments()); + rewriter.inlineBlockBefore(block1, newBlock, newBlock->begin(), + newBlock->getArguments()); + + ValueRange results = fusedLoop.getResults(); + if (!results.empty()) { + rewriter.setInsertionPointToEnd(newBlock); + + ValueRange reduceArgs1 = term1.getOperands(); + ValueRange reduceArgs2 = term2.getOperands(); + SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end()); + newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end()); + + auto newReduceOp = rewriter.create<scf::ReduceOp>( + rewriter.getFusedLoc(term1.getLoc(), term2.getLoc()), newReduceArgs); + + for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>( + term1.getReductions(), term2.getReductions()))) { + Block &oldRedBlock = reg.front(); + Block &newRedBlock = newReduceOp.getReductions()[i].front(); + rewriter.inlineBlockBefore(&oldRedBlock, &newRedBlock, + newRedBlock.begin(), + newRedBlock.getArguments()); + } + } + rewriter.replaceOp(target, results.take_front(inits1.size())); + rewriter.replaceOp(source, results.take_back(inits2.size())); + rewriter.eraseOp(term1); + rewriter.eraseOp(term2); return fusedLoop; } diff --git a/mlir/lib/Interfaces/LoopLikeInterface.cpp b/mlir/lib/Interfaces/LoopLikeInterface.cpp index 1e0e87b..6f0ebec 100644 --- a/mlir/lib/Interfaces/LoopLikeInterface.cpp +++ b/mlir/lib/Interfaces/LoopLikeInterface.cpp @@ -8,6 +8,8 @@ #include "mlir/Interfaces/LoopLikeInterface.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "llvm/ADT/DenseSet.h" @@ -113,3 +115,56 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) { return success(); } + +LoopLikeOpInterface mlir::createFused(LoopLikeOpInterface target, + LoopLikeOpInterface source, + RewriterBase &rewriter, + NewYieldValuesFn newYieldValuesFn, + FuseTerminatorFn fuseTerminatorFn) { + auto targetIterArgs = target.getRegionIterArgs(); + std::optional<SmallVector<Value>> targetInductionVar = + target.getLoopInductionVars(); + SmallVector<Value> targetYieldOperands(target.getYieldedValues()); + auto sourceIterArgs = source.getRegionIterArgs(); + std::optional<SmallVector<Value>> sourceInductionVar = + *source.getLoopInductionVars(); + SmallVector<Value> sourceYieldOperands(source.getYieldedValues()); + auto sourceRegion = source.getLoopRegions().front(); + + FailureOr<LoopLikeOpInterface> maybeFusedLoop = + target.replaceWithAdditionalYields(rewriter, source.getInits(), + /*replaceInitOperandUsesInLoop=*/false, + newYieldValuesFn); + if (failed(maybeFusedLoop)) + llvm_unreachable("failed to replace loop"); + LoopLikeOpInterface fusedLoop = *maybeFusedLoop; + + // Map control operands. + IRMapping mapping; + std::optional<SmallVector<Value>> fusedInductionVar = + fusedLoop.getLoopInductionVars(); + if (fusedInductionVar) { + if (!targetInductionVar || !sourceInductionVar) + llvm_unreachable("expected target and source loops to have induction vars"); + mapping.map(*targetInductionVar, *fusedInductionVar); + mapping.map(*sourceInductionVar, *fusedInductionVar); + } + mapping.map(targetIterArgs, + fusedLoop.getRegionIterArgs().take_front(targetIterArgs.size())); + mapping.map(targetYieldOperands, + fusedLoop.getYieldedValues().take_front(targetIterArgs.size())); + mapping.map(sourceIterArgs, + fusedLoop.getRegionIterArgs().take_back(sourceIterArgs.size())); + mapping.map(sourceYieldOperands, + fusedLoop.getYieldedValues().take_back(sourceIterArgs.size())); + // Append everything except the terminator into the fused operation. + rewriter.setInsertionPoint( + fusedLoop.getLoopRegions().front()->front().getTerminator()); + for (Operation &op : sourceRegion->front().without_terminator()) + rewriter.clone(op, mapping); + + // TODO: Replace with corresponding interface method if added + fuseTerminatorFn(rewriter, source, fusedLoop, mapping); + + return fusedLoop; +} |