aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/SCF/Utils/Utils.cpp
diff options
context:
space:
mode:
author Rolf Morel <rolf.morel@huawei.com> 2024-03-28 13:13:08 +0000
committer GitHub <noreply@github.com> 2024-03-28 14:13:08 +0100
commit eacda36c7dd842cb15c0c954eda74b67d0c73814 (patch)
tree fadca8a65aec6793444ce3fa3bdb32614f5f49b6 /mlir/lib/Dialect/SCF/Utils/Utils.cpp
parent 91856b34e3eddf157ab4c6ea623483b49d149e62 (diff)
download llvm-eacda36c7dd842cb15c0c954eda74b67d0c73814.zip
llvm-eacda36c7dd842cb15c0c954eda74b67d0c73814.tar.gz
llvm-eacda36c7dd842cb15c0c954eda74b67d0c73814.tar.bz2
[SCF][Transform] Add support for scf.for in LoopFuseSibling op (#81495)
Adds support for fusing two scf.for loops occurring in the same block. Uses the rudimentary checks already in place for scf.forall (like the target loop's operands being dominated by the source loop).

- Fixes a bug in the dominance check whereby it was checked that values in the target loop themselves dominated the source loop, rather than the ops that define these operands.
- Renames the LoopFuseSibling op to LoopFuseSiblingOp.
- Updates LoopFuseSiblingOp's description.
- Adds tests for using LoopFuseSiblingOp on scf.for loops, including one which fails without the fix for the dominance check.
- Adds tests checking the different failure modes of the dominance checker.
- Adds a test for the case where scf.yield is automatically generated when there are no loop-carried variables.
Diffstat (limited to 'mlir/lib/Dialect/SCF/Utils/Utils.cpp')
-rw-r--r-- mlir/lib/Dialect/SCF/Utils/Utils.cpp | 99
1 files changed, 68 insertions, 31 deletions
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 502d7e1..914aeb4 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -910,61 +910,98 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
unsigned numTargetOuts = target.getNumResults();
unsigned numSourceOuts = source.getNumResults();
- OperandRange targetOuts = target.getOutputs();
- OperandRange sourceOuts = source.getOutputs();
-
// Create fused shared_outs.
SmallVector<Value> fusedOuts;
- fusedOuts.reserve(numTargetOuts + numSourceOuts);
- fusedOuts.append(targetOuts.begin(), targetOuts.end());
- fusedOuts.append(sourceOuts.begin(), sourceOuts.end());
+ llvm::append_range(fusedOuts, target.getOutputs());
+ llvm::append_range(fusedOuts, source.getOutputs());
- // Create a new scf::forall op after the source loop.
+ // Create a new scf.forall op after the source loop.
rewriter.setInsertionPointAfter(source);
scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>(
source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
source.getMixedStep(), fusedOuts, source.getMapping());
// Map control operands.
- IRMapping fusedMapping;
- fusedMapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
- fusedMapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
+ IRMapping mapping;
+ mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
+ mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
// Map shared outs.
- fusedMapping.map(target.getRegionIterArgs(),
- fusedLoop.getRegionIterArgs().slice(0, numTargetOuts));
- fusedMapping.map(
- source.getRegionIterArgs(),
- fusedLoop.getRegionIterArgs().slice(numTargetOuts, numSourceOuts));
+ mapping.map(target.getRegionIterArgs(),
+ fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
+ mapping.map(source.getRegionIterArgs(),
+ fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
// Append everything except the terminator into the fused operation.
rewriter.setInsertionPointToStart(fusedLoop.getBody());
for (Operation &op : target.getBody()->without_terminator())
- rewriter.clone(op, fusedMapping);
+ rewriter.clone(op, mapping);
for (Operation &op : source.getBody()->without_terminator())
- rewriter.clone(op, fusedMapping);
+ rewriter.clone(op, mapping);
// Fuse the old terminator in_parallel ops into the new one.
scf::InParallelOp targetTerm = target.getTerminator();
scf::InParallelOp sourceTerm = source.getTerminator();
scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
-
rewriter.setInsertionPointToStart(fusedTerm.getBody());
for (Operation &op : targetTerm.getYieldingOps())
- rewriter.clone(op, fusedMapping);
+ rewriter.clone(op, mapping);
for (Operation &op : sourceTerm.getYieldingOps())
- rewriter.clone(op, fusedMapping);
-
- // Replace all uses of the old loops with the fused loop.
- rewriter.replaceAllUsesWith(target.getResults(),
- fusedLoop.getResults().slice(0, numTargetOuts));
- rewriter.replaceAllUsesWith(
- source.getResults(),
- fusedLoop.getResults().slice(numTargetOuts, numSourceOuts));
-
- // Erase the old loops.
- rewriter.eraseOp(target);
- rewriter.eraseOp(source);
+ rewriter.clone(op, mapping);
+
+ // Replace old loops by substituting their uses by results of the fused loop.
+ rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
+ rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
+
+ return fusedLoop;
+}
+
+scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
+ scf::ForOp source,
+ RewriterBase &rewriter) {
+ unsigned numTargetOuts = target.getNumResults();
+ unsigned numSourceOuts = source.getNumResults();
+
+ // Create fused init_args, with target's init_args before source's init_args.
+ SmallVector<Value> fusedInitArgs;
+ llvm::append_range(fusedInitArgs, target.getInitArgs());
+ llvm::append_range(fusedInitArgs, source.getInitArgs());
+
+ // Create a new scf.for op after the source loop (with scf.yield terminator
+ // (without arguments) only in case its init_args is empty).
+ rewriter.setInsertionPointAfter(source);
+ scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
+ source.getLoc(), source.getLowerBound(), source.getUpperBound(),
+ source.getStep(), fusedInitArgs);
+
+ // Map original induction variables and operands to those of the fused loop.
+ IRMapping mapping;
+ mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
+ mapping.map(target.getRegionIterArgs(),
+ fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
+ mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
+ mapping.map(source.getRegionIterArgs(),
+ fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
+
+ // Merge target's body into the new (fused) for loop and then source's body.
+ rewriter.setInsertionPointToStart(fusedLoop.getBody());
+ for (Operation &op : target.getBody()->without_terminator())
+ rewriter.clone(op, mapping);
+ for (Operation &op : source.getBody()->without_terminator())
+ rewriter.clone(op, mapping);
+
+ // Build fused yield results by appropriately mapping original yield operands.
+ SmallVector<Value> yieldResults;
+ for (Value operand : target.getBody()->getTerminator()->getOperands())
+ yieldResults.push_back(mapping.lookupOrDefault(operand));
+ for (Value operand : source.getBody()->getTerminator()->getOperands())
+ yieldResults.push_back(mapping.lookupOrDefault(operand));
+ if (!yieldResults.empty())
+ rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults);
+
+ // Replace old loops by substituting their uses by results of the fused loop.
+ rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
+ rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
return fusedLoop;
}