aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib
diff options
context:
space:
mode:
authorsrcarroll <50210727+srcarroll@users.noreply.github.com>2024-07-02 11:12:51 -0500
committerGitHub <noreply@github.com>2024-07-02 11:12:51 -0500
commit6820b0871807abff07df118659e0de2ca741cb0b (patch)
treee6d4b4637118cd50e3cf4dfc0a1a421e72ca1cc8 /mlir/lib
parent6c3897d90eda4c39789ac9f4efa51db46734a249 (diff)
downloadllvm-6820b0871807abff07df118659e0de2ca741cb0b.zip
llvm-6820b0871807abff07df118659e0de2ca741cb0b.tar.gz
llvm-6820b0871807abff07df118659e0de2ca741cb0b.tar.bz2
Refactor LoopFuseSiblingOp and support parallel fusion (#94391)
This patch refactors code related to the `LoopFuseSiblingOp` transform in an attempt to reduce duplicated common code. The aim is to refactor as much as possible into functions on `LoopLikeOpInterface`s, but this is still a work in progress. A full refactor will require more additions to the `LoopLikeOpInterface`. In addition, `scf.parallel` fusion support has been added.
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Dialect/SCF/IR/SCF.cpp38
-rw-r--r--mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp140
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp80
-rw-r--r--mlir/lib/Dialect/SCF/Utils/Utils.cpp279
-rw-r--r--mlir/lib/Interfaces/LoopLikeInterface.cpp55
5 files changed, 311 insertions, 281 deletions
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 907d7f7..cb15e0e 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -618,6 +618,44 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point,
SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
+FailureOr<LoopLikeOpInterface> ForallOp::replaceWithAdditionalYields(
+ RewriterBase &rewriter, ValueRange newInitOperands,
+ bool replaceInitOperandUsesInLoop,
+ const NewYieldValuesFn &newYieldValuesFn) {
+ // Build the replacement scf.forall right before this op; its outputs are the current outputs followed by `newInitOperands`. `newYieldValuesFn` is not used in this body.
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(getOperation());
+ SmallVector<Value> inits(getOutputs());
+ llvm::append_range(inits, newInitOperands);
+ scf::ForallOp newLoop = rewriter.create<scf::ForallOp>(
+ getLoc(), getMixedLowerBound(), getMixedUpperBound(), getMixedStep(),
+ inits, getMapping(),
+ /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {});
+
+ // Move the old body into the new op, rewiring the old block arguments to the leading arguments of the new block.
+ rewriter.mergeBlocks(getBody(), newLoop.getBody(),
+ newLoop.getBody()->getArguments().take_front(
+ getBody()->getNumArguments()));
+
+ if (replaceInitOperandUsesInLoop) {
+ // Replace uses of `newInitOperands` nested inside the new loop with the
+ // corresponding trailing block arguments; uses outside the loop are kept.
+ for (auto &&[newOperand, oldOperand] :
+ llvm::zip(newInitOperands, newLoop.getBody()->getArguments().take_back(
+ newInitOperands.size()))) {
+ rewriter.replaceUsesWithIf(newOperand, oldOperand, [&](OpOperand &use) {
+ Operation *user = use.getOwner();
+ return newLoop->isProperAncestor(user);
+ });
+ }
+ }
+
+ // Replace the old loop with the first `getNumResults()` results of the new loop.
+ rewriter.replaceOp(getOperation(),
+ newLoop->getResults().take_front(getNumResults()));
+ return cast<LoopLikeOpInterface>(newLoop.getOperation());
+}
+
/// Promotes the loop body of a forallOp to its containing block if it can be
/// determined that the loop has a single iteration.
LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 56ff270..41834fe 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -261,8 +261,10 @@ loopScheduling(scf::ForOp forOp,
return 1;
};
- std::optional<int64_t> ubConstant = getConstantIntValue(forOp.getUpperBound());
- std::optional<int64_t> lbConstant = getConstantIntValue(forOp.getLowerBound());
+ std::optional<int64_t> ubConstant =
+ getConstantIntValue(forOp.getUpperBound());
+ std::optional<int64_t> lbConstant =
+ getConstantIntValue(forOp.getLowerBound());
DenseMap<Operation *, unsigned> opCycles;
std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
for (Operation &op : forOp.getBody()->getOperations()) {
@@ -447,113 +449,6 @@ void transform::TakeAssumedBranchOp::getEffects(
// LoopFuseSiblingOp
//===----------------------------------------------------------------------===//
-/// Check if `target` and `source` are siblings, in the context that `target`
-/// is being fused into `source`.
-///
-/// This is a simple check that just checks if both operations are in the same
-/// block and some checks to ensure that the fused IR does not violate
-/// dominance.
-static DiagnosedSilenceableFailure isOpSibling(Operation *target,
- Operation *source) {
- // Check if both operations are same.
- if (target == source)
- return emitSilenceableFailure(source)
- << "target and source need to be different loops";
-
- // Check if both operations are in the same block.
- if (target->getBlock() != source->getBlock())
- return emitSilenceableFailure(source)
- << "target and source are not in the same block";
-
- // Check if fusion will violate dominance.
- DominanceInfo domInfo(source);
- if (target->isBeforeInBlock(source)) {
- // Since `target` is before `source`, all users of results of `target`
- // need to be dominated by `source`.
- for (Operation *user : target->getUsers()) {
- if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) {
- return emitSilenceableFailure(target)
- << "user of results of target should be properly dominated by "
- "source";
- }
- }
- } else {
- // Since `target` is after `source`, all values used by `target` need
- // to dominate `source`.
-
- // Check if operands of `target` are dominated by `source`.
- for (Value operand : target->getOperands()) {
- Operation *operandOp = operand.getDefiningOp();
- // Operands without defining operations are block arguments. When `target`
- // and `source` occur in the same block, these operands dominate `source`.
- if (!operandOp)
- continue;
-
- // Operand's defining operation should properly dominate `source`.
- if (!domInfo.properlyDominates(operandOp, source,
- /*enclosingOpOk=*/false))
- return emitSilenceableFailure(target)
- << "operands of target should be properly dominated by source";
- }
-
- // Check if values used by `target` are dominated by `source`.
- bool failed = false;
- OpOperand *failedValue = nullptr;
- visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
- Operation *operandOp = operand->get().getDefiningOp();
- if (operandOp && !domInfo.properlyDominates(operandOp, source,
- /*enclosingOpOk=*/false)) {
- // `operand` is not an argument of an enclosing block and the defining
- // op of `operand` is outside `target` but does not dominate `source`.
- failed = true;
- failedValue = operand;
- }
- });
-
- if (failed)
- return emitSilenceableFailure(failedValue->getOwner())
- << "values used inside regions of target should be properly "
- "dominated by source";
- }
-
- return DiagnosedSilenceableFailure::success();
-}
-
-/// Check if `target` scf.forall can be fused into `source` scf.forall.
-///
-/// This simply checks if both loops have the same bounds, steps and mapping.
-/// No attempt is made at checking that the side effects of `target` and
-/// `source` are independent of each other.
-static bool isForallWithIdenticalConfiguration(Operation *target,
- Operation *source) {
- auto targetOp = dyn_cast<scf::ForallOp>(target);
- auto sourceOp = dyn_cast<scf::ForallOp>(source);
- if (!targetOp || !sourceOp)
- return false;
-
- return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
- targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
- targetOp.getMixedStep() == sourceOp.getMixedStep() &&
- targetOp.getMapping() == sourceOp.getMapping();
-}
-
-/// Check if `target` scf.for can be fused into `source` scf.for.
-///
-/// This simply checks if both loops have the same bounds and steps. No attempt
-/// is made at checking that the side effects of `target` and `source` are
-/// independent of each other.
-static bool isForWithIdenticalConfiguration(Operation *target,
- Operation *source) {
- auto targetOp = dyn_cast<scf::ForOp>(target);
- auto sourceOp = dyn_cast<scf::ForOp>(source);
- if (!targetOp || !sourceOp)
- return false;
-
- return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
- targetOp.getUpperBound() == sourceOp.getUpperBound() &&
- targetOp.getStep() == sourceOp.getStep();
-}
-
DiagnosedSilenceableFailure
transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
@@ -569,25 +464,34 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
<< "source handle (got " << llvm::range_size(sourceOps) << ")";
}
- Operation *target = *targetOps.begin();
- Operation *source = *sourceOps.begin();
+ auto target = dyn_cast<LoopLikeOpInterface>(*targetOps.begin());
+ auto source = dyn_cast<LoopLikeOpInterface>(*sourceOps.begin());
+ // `target` is null when the payload op is not a loop, so report the failure
+ // at the payload op's location instead of dereferencing the failed cast.
+ if (!target || !source)
+ return emitSilenceableFailure((*targetOps.begin())->getLoc())
+ << "target or source is not a loop op";
- // Check if the target and source are siblings.
- DiagnosedSilenceableFailure diag = isOpSibling(target, source);
- if (!diag.succeeded())
- return diag;
+ // Check if loops can be fused
+ Diagnostic diag(target.getLoc(), DiagnosticSeverity::Error);
+ if (!mlir::checkFusionStructuralLegality(target, source, diag))
+ return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
Operation *fusedLoop;
- /// TODO: Support fusion for loop-like ops besides scf.for and scf.forall.
- if (isForWithIdenticalConfiguration(target, source)) {
+ // TODO: Support fusion for loop-like ops besides scf.for, scf.forall
+ // and scf.parallel.
+ if (isa<scf::ForOp>(target) && isa<scf::ForOp>(source)) {
fusedLoop = fuseIndependentSiblingForLoops(
cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
- } else if (isForallWithIdenticalConfiguration(target, source)) {
+ } else if (isa<scf::ForallOp>(target) && isa<scf::ForallOp>(source)) {
fusedLoop = fuseIndependentSiblingForallLoops(
cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
+ } else if (isa<scf::ParallelOp>(target) && isa<scf::ParallelOp>(source)) {
+ fusedLoop = fuseIndependentSiblingParallelLoops(
+ cast<scf::ParallelOp>(target), cast<scf::ParallelOp>(source), rewriter);
} else
return emitSilenceableFailure(target->getLoc())
- << "operations cannot be fused";
+ << "unsupported loop type for fusion";
assert(fusedLoop && "failed to fuse operations");
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index 5934d85..b775f98 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/OpDefinition.h"
@@ -37,24 +38,6 @@ static bool hasNestedParallelOp(ParallelOp ploop) {
return walkResult.wasInterrupted();
}
-/// Verify equal iteration spaces.
-static bool equalIterationSpaces(ParallelOp firstPloop,
- ParallelOp secondPloop) {
- if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
- return false;
-
- auto matchOperands = [&](const OperandRange &lhs,
- const OperandRange &rhs) -> bool {
- // TODO: Extend this to support aliases and equal constants.
- return std::equal(lhs.begin(), lhs.end(), rhs.begin());
- };
- return matchOperands(firstPloop.getLowerBound(),
- secondPloop.getLowerBound()) &&
- matchOperands(firstPloop.getUpperBound(),
- secondPloop.getUpperBound()) &&
- matchOperands(firstPloop.getStep(), secondPloop.getStep());
-}
-
/// Checks if the parallel loops have mixed access to the same buffers. Returns
/// `true` if the first parallel loop writes to the same indices that the second
/// loop reads.
@@ -153,9 +136,10 @@ verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
const IRMapping &firstToSecondPloopIndices,
llvm::function_ref<bool(Value, Value)> mayAlias) {
+ Diagnostic diag(firstPloop.getLoc(), DiagnosticSeverity::Remark);
return !hasNestedParallelOp(firstPloop) &&
!hasNestedParallelOp(secondPloop) &&
- equalIterationSpaces(firstPloop, secondPloop) &&
+ checkFusionStructuralLegality(firstPloop, secondPloop, diag) &&
succeeded(verifyDependencies(firstPloop, secondPloop,
firstToSecondPloopIndices, mayAlias));
}
@@ -174,61 +158,9 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
mayAlias))
return;
- DominanceInfo dom;
- // We are fusing first loop into second, make sure there are no users of the
- // first loop results between loops.
- for (Operation *user : firstPloop->getUsers())
- if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
- return;
-
- ValueRange inits1 = firstPloop.getInitVals();
- ValueRange inits2 = secondPloop.getInitVals();
-
- SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
- newInitVars.append(inits2.begin(), inits2.end());
-
- IRRewriter b(builder);
- b.setInsertionPoint(secondPloop);
- auto newSecondPloop = b.create<ParallelOp>(
- secondPloop.getLoc(), secondPloop.getLowerBound(),
- secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
-
- Block *newBlock = newSecondPloop.getBody();
- auto term1 = cast<ReduceOp>(block1->getTerminator());
- auto term2 = cast<ReduceOp>(block2->getTerminator());
-
- b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
- newBlock->getArguments());
- b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
- newBlock->getArguments());
-
- ValueRange results = newSecondPloop.getResults();
- if (!results.empty()) {
- b.setInsertionPointToEnd(newBlock);
-
- ValueRange reduceArgs1 = term1.getOperands();
- ValueRange reduceArgs2 = term2.getOperands();
- SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
- newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
-
- auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
-
- for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
- term1.getReductions(), term2.getReductions()))) {
- Block &oldRedBlock = reg.front();
- Block &newRedBlock = newReduceOp.getReductions()[i].front();
- b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
- newRedBlock.getArguments());
- }
-
- firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
- secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
- }
- term1->erase();
- term2->erase();
- firstPloop.erase();
- secondPloop.erase();
- secondPloop = newSecondPloop;
+ IRRewriter rewriter(builder);
+ secondPloop = mlir::fuseIndependentSiblingParallelLoops(
+ firstPloop, secondPloop, rewriter);
}
void mlir::scf::naivelyFuseParallelOps(
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index c0ee9d2..abfc9a1 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dominance.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
@@ -1262,54 +1263,131 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
return tileLoops;
}
+//===----------------------------------------------------------------------===//
+// Fusion related helpers
+//===----------------------------------------------------------------------===//
+
+/// Check if `target` and `source` are siblings, in the context that `target`
+/// is being fused into `source`.
+///
+/// This only verifies that both operations live in the same block and that
+/// fusing them would not violate dominance. On failure the reason is appended
+/// to `diag` and false is returned; otherwise true is returned.
+static bool isOpSibling(Operation *target, Operation *source,
+ Diagnostic &diag) {
+ // A loop cannot be fused with itself.
+ if (target == source) {
+ diag << "target and source need to be different loops";
+ return false;
+ }
+
+ // Check if both operations are in the same block.
+ if (target->getBlock() != source->getBlock()) {
+ diag << "target and source are not in the same block";
+ return false;
+ }
+
+ // Check if fusion will violate dominance.
+ DominanceInfo domInfo(source);
+ if (target->isBeforeInBlock(source)) {
+ // Since `target` is before `source`, all users of results of `target`
+ // need to be properly dominated by `source`.
+ for (Operation *user : target->getUsers()) {
+ if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) {
+ diag << "user of results of target should "
+ "be properly dominated by source";
+ return false;
+ }
+ }
+ } else {
+ // Since `target` is after `source`, all values used by `target` need
+ // to dominate `source`.
+
+ // Check that the defining ops of `target`'s operands properly dominate `source`.
+ for (Value operand : target->getOperands()) {
+ Operation *operandOp = operand.getDefiningOp();
+ // Operands without defining operations are block arguments. When `target`
+ // and `source` occur in the same block, these operands dominate `source`.
+ if (!operandOp)
+ continue;
+
+ // Operand's defining operation should properly dominate `source`.
+ if (!domInfo.properlyDominates(operandOp, source,
+ /*enclosingOpOk=*/false)) {
+ diag << "operands of target should be properly dominated by source";
+ return false;
+ }
+ }
+
+ // Check that values defined above and used inside `target`'s regions properly dominate `source`.
+ bool failed = false;
+ OpOperand *failedValue = nullptr;
+ visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
+ Operation *operandOp = operand->get().getDefiningOp();
+ if (operandOp && !domInfo.properlyDominates(operandOp, source,
+ /*enclosingOpOk=*/false)) {
+ // `operand` is not an argument of an enclosing block and the defining
+ // op of `operand` is outside `target` but does not dominate `source`.
+ failed = true;
+ failedValue = operand;
+ }
+ });
+
+ if (failed) {
+ diag << "values used inside regions of target should be properly "
+ "dominated by source";
+ diag.attachNote(failedValue->getOwner()->getLoc()) << "see operation";
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface target,
+ LoopLikeOpInterface source,
+ Diagnostic &diag) {
+ if (target->getName() != source->getName()) {
+ diag << "target and source must be same loop type";
+ return false;
+ }
+
+ bool iterSpaceEq =
+ target.getLoopLowerBounds() == source.getLoopLowerBounds() &&
+ target.getLoopUpperBounds() == source.getLoopUpperBounds() &&
+ target.getLoopSteps() == source.getLoopSteps();
+ // `scf.forall` loops must additionally agree on `getMapping()`; the cast of `source` relies on the op-name check above.
+ // TODO: Decouple checks on concrete loop types and move this function somewhere for general utility for `LoopLikeOpInterface`
+ if (auto forAllTarget = dyn_cast<scf::ForallOp>(*target))
+ iterSpaceEq = iterSpaceEq && forAllTarget.getMapping() ==
+ cast<scf::ForallOp>(*source).getMapping();
+ if (!iterSpaceEq) {
+ diag << "target and source iteration spaces must be equal";
+ return false;
+ }
+ return isOpSibling(target, source, diag);
+}
+
scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
scf::ForallOp source,
RewriterBase &rewriter) {
- unsigned numTargetOuts = target.getNumResults();
- unsigned numSourceOuts = source.getNumResults();
-
- // Create fused shared_outs.
- SmallVector<Value> fusedOuts;
- llvm::append_range(fusedOuts, target.getOutputs());
- llvm::append_range(fusedOuts, source.getOutputs());
-
- // Create a new scf.forall op after the source loop.
- rewriter.setInsertionPointAfter(source);
- scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>(
- source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
- source.getMixedStep(), fusedOuts, source.getMapping());
-
- // Map control operands.
- IRMapping mapping;
- mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
- mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
-
- // Map shared outs.
- mapping.map(target.getRegionIterArgs(),
- fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
- mapping.map(source.getRegionIterArgs(),
- fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
-
- // Append everything except the terminator into the fused operation.
- rewriter.setInsertionPointToStart(fusedLoop.getBody());
- for (Operation &op : target.getBody()->without_terminator())
- rewriter.clone(op, mapping);
- for (Operation &op : source.getBody()->without_terminator())
- rewriter.clone(op, mapping);
-
- // Fuse the old terminator in_parallel ops into the new one.
- scf::InParallelOp targetTerm = target.getTerminator();
- scf::InParallelOp sourceTerm = source.getTerminator();
- scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
- rewriter.setInsertionPointToStart(fusedTerm.getBody());
- for (Operation &op : targetTerm.getYieldingOps())
- rewriter.clone(op, mapping);
- for (Operation &op : sourceTerm.getYieldingOps())
- rewriter.clone(op, mapping);
-
- // Replace old loops by substituting their uses by results of the fused loop.
- rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
- rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
+ scf::ForallOp fusedLoop = cast<scf::ForallOp>(createFused(
+ target, source, rewriter,
+ [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
+ // `ForallOp` has no yielded values; it is terminated by an `InParallelOp` instead, so there is nothing to add here.
+ return ValueRange{};
+ },
+ [&](RewriterBase &b, LoopLikeOpInterface source,
+ LoopLikeOpInterface &target, IRMapping mapping) {
+ auto sourceForall = cast<scf::ForallOp>(source);
+ auto targetForall = cast<scf::ForallOp>(target);
+ scf::InParallelOp fusedTerm = targetForall.getTerminator();
+ b.setInsertionPointToEnd(fusedTerm.getBody());
+ for (Operation &op : sourceForall.getTerminator().getYieldingOps())
+ b.clone(op, mapping);
+ }));
+ rewriter.replaceOp(source,
+ fusedLoop.getResults().take_back(source.getNumResults()));
return fusedLoop;
}
@@ -1317,49 +1395,74 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
scf::ForOp source,
RewriterBase &rewriter) {
- unsigned numTargetOuts = target.getNumResults();
- unsigned numSourceOuts = source.getNumResults();
-
- // Create fused init_args, with target's init_args before source's init_args.
- SmallVector<Value> fusedInitArgs;
- llvm::append_range(fusedInitArgs, target.getInitArgs());
- llvm::append_range(fusedInitArgs, source.getInitArgs());
-
- // Create a new scf.for op after the source loop (with scf.yield terminator
- // (without arguments) only in case its init_args is empty).
- rewriter.setInsertionPointAfter(source);
- scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
- source.getLoc(), source.getLowerBound(), source.getUpperBound(),
- source.getStep(), fusedInitArgs);
-
- // Map original induction variables and operands to those of the fused loop.
- IRMapping mapping;
- mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
- mapping.map(target.getRegionIterArgs(),
- fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
- mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
- mapping.map(source.getRegionIterArgs(),
- fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
-
- // Merge target's body into the new (fused) for loop and then source's body.
- rewriter.setInsertionPointToStart(fusedLoop.getBody());
- for (Operation &op : target.getBody()->without_terminator())
- rewriter.clone(op, mapping);
- for (Operation &op : source.getBody()->without_terminator())
- rewriter.clone(op, mapping);
-
- // Build fused yield results by appropriately mapping original yield operands.
- SmallVector<Value> yieldResults;
- for (Value operand : target.getBody()->getTerminator()->getOperands())
- yieldResults.push_back(mapping.lookupOrDefault(operand));
- for (Value operand : source.getBody()->getTerminator()->getOperands())
- yieldResults.push_back(mapping.lookupOrDefault(operand));
- if (!yieldResults.empty())
- rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults);
-
- // Replace old loops by substituting their uses by results of the fused loop.
- rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
- rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
+ scf::ForOp fusedLoop = cast<scf::ForOp>(createFused(
+ target, source, rewriter,
+ [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
+ return source.getYieldedValues();
+ },
+ [&](RewriterBase &b, LoopLikeOpInterface source,
+ LoopLikeOpInterface &target, IRMapping mapping) {
+ auto targetFor = cast<scf::ForOp>(target);
+ auto newTerm = b.clone(*targetFor.getBody()->getTerminator(), mapping);
+ b.replaceOp(targetFor.getBody()->getTerminator(), newTerm);
+ }));
+ rewriter.replaceOp(source,
+ fusedLoop.getResults().take_back(source.getNumResults()));
+ return fusedLoop;
+}
+
+// Fuses `target` and `source` into one scf.parallel created before `source` (using `source`'s bounds and steps); init values, bodies and reductions are concatenated with `target`'s first.
+// TODO: Finish refactoring this a la the above, but likely requires additional interface methods.
+scf::ParallelOp mlir::fuseIndependentSiblingParallelLoops(
+ scf::ParallelOp target, scf::ParallelOp source, RewriterBase &rewriter) {
+ OpBuilder::InsertionGuard guard(rewriter);
+ Block *block1 = target.getBody();
+ Block *block2 = source.getBody();
+ auto term1 = cast<scf::ReduceOp>(block1->getTerminator());
+ auto term2 = cast<scf::ReduceOp>(block2->getTerminator());
+
+ ValueRange inits1 = target.getInitVals();
+ ValueRange inits2 = source.getInitVals();
+
+ SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
+ newInitVars.append(inits2.begin(), inits2.end());
+
+ rewriter.setInsertionPoint(source);
+ auto fusedLoop = rewriter.create<scf::ParallelOp>(
+ rewriter.getFusedLoc(target.getLoc(), source.getLoc()),
+ source.getLowerBound(), source.getUpperBound(), source.getStep(),
+ newInitVars);
+ Block *newBlock = fusedLoop.getBody();
+ rewriter.inlineBlockBefore(block2, newBlock, newBlock->begin(),
+ newBlock->getArguments());
+ rewriter.inlineBlockBefore(block1, newBlock, newBlock->begin(),
+ newBlock->getArguments());
+
+ ValueRange results = fusedLoop.getResults();
+ if (!results.empty()) {
+ rewriter.setInsertionPointToEnd(newBlock);
+
+ ValueRange reduceArgs1 = term1.getOperands();
+ ValueRange reduceArgs2 = term2.getOperands();
+ SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
+ newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
+
+ auto newReduceOp = rewriter.create<scf::ReduceOp>(
+ rewriter.getFusedLoc(term1.getLoc(), term2.getLoc()), newReduceArgs);
+
+ for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
+ term1.getReductions(), term2.getReductions()))) {
+ Block &oldRedBlock = reg.front();
+ Block &newRedBlock = newReduceOp.getReductions()[i].front();
+ rewriter.inlineBlockBefore(&oldRedBlock, &newRedBlock,
+ newRedBlock.begin(),
+ newRedBlock.getArguments());
+ }
+ }
+ rewriter.replaceOp(target, results.take_front(inits1.size()));
+ rewriter.replaceOp(source, results.take_back(inits2.size()));
+ rewriter.eraseOp(term1);
+ rewriter.eraseOp(term2);
return fusedLoop;
}
diff --git a/mlir/lib/Interfaces/LoopLikeInterface.cpp b/mlir/lib/Interfaces/LoopLikeInterface.cpp
index 1e0e87b..6f0ebec 100644
--- a/mlir/lib/Interfaces/LoopLikeInterface.cpp
+++ b/mlir/lib/Interfaces/LoopLikeInterface.cpp
@@ -8,6 +8,8 @@
#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "llvm/ADT/DenseSet.h"
@@ -113,3 +115,63 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
return success();
}
+
+/// Fuses `source` into `target`: `target` is replaced by a loop that also
+/// carries `source`'s inits (via `replaceWithAdditionalYields`), the
+/// non-terminator ops of the first block of `source`'s first loop region are
+/// cloned before the fused loop's terminator, and `fuseTerminatorFn` is then
+/// invoked to merge the terminators. `source` itself is not replaced here.
+LoopLikeOpInterface mlir::createFused(LoopLikeOpInterface target,
+ LoopLikeOpInterface source,
+ RewriterBase &rewriter,
+ NewYieldValuesFn newYieldValuesFn,
+ FuseTerminatorFn fuseTerminatorFn) {
+ auto targetIterArgs = target.getRegionIterArgs();
+ std::optional<SmallVector<Value>> targetInductionVar =
+ target.getLoopInductionVars();
+ SmallVector<Value> targetYieldOperands(target.getYieldedValues());
+ auto sourceIterArgs = source.getRegionIterArgs();
+ // Keep the optional as-is, mirroring `targetInductionVar`; its presence is
+ // checked before use below, so it must not be dereferenced unconditionally.
+ std::optional<SmallVector<Value>> sourceInductionVar =
+ source.getLoopInductionVars();
+ SmallVector<Value> sourceYieldOperands(source.getYieldedValues());
+ auto sourceRegion = source.getLoopRegions().front();
+
+ FailureOr<LoopLikeOpInterface> maybeFusedLoop =
+ target.replaceWithAdditionalYields(rewriter, source.getInits(),
+ /*replaceInitOperandUsesInLoop=*/false,
+ newYieldValuesFn);
+ if (failed(maybeFusedLoop))
+ llvm_unreachable("failed to replace loop");
+ LoopLikeOpInterface fusedLoop = *maybeFusedLoop;
+
+ // Map control operands: both loops' induction variables map to the fused ones.
+ IRMapping mapping;
+ std::optional<SmallVector<Value>> fusedInductionVar =
+ fusedLoop.getLoopInductionVars();
+ if (fusedInductionVar) {
+ if (!targetInductionVar || !sourceInductionVar)
+ llvm_unreachable("expected target and source loops to have induction vars");
+ mapping.map(*targetInductionVar, *fusedInductionVar);
+ mapping.map(*sourceInductionVar, *fusedInductionVar);
+ }
+ mapping.map(targetIterArgs,
+ fusedLoop.getRegionIterArgs().take_front(targetIterArgs.size()));
+ mapping.map(targetYieldOperands,
+ fusedLoop.getYieldedValues().take_front(targetIterArgs.size()));
+ mapping.map(sourceIterArgs,
+ fusedLoop.getRegionIterArgs().take_back(sourceIterArgs.size()));
+ mapping.map(sourceYieldOperands,
+ fusedLoop.getYieldedValues().take_back(sourceIterArgs.size()));
+ // Append everything except the terminator into the fused operation.
+ rewriter.setInsertionPoint(
+ fusedLoop.getLoopRegions().front()->front().getTerminator());
+ for (Operation &op : sourceRegion->front().without_terminator())
+ rewriter.clone(op, mapping);
+
+ // TODO: Replace with corresponding interface method if added
+ fuseTerminatorFn(rewriter, source, fusedLoop, mapping);
+
+ return fusedLoop;
+}