diff options
Diffstat (limited to 'mlir/lib/Dialect/SCF/Utils/Utils.cpp')
-rw-r--r-- | mlir/lib/Dialect/SCF/Utils/Utils.cpp | 99 |
1 files changed, 68 insertions, 31 deletions
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index 502d7e1..914aeb4 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -910,61 +910,98 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target, unsigned numTargetOuts = target.getNumResults(); unsigned numSourceOuts = source.getNumResults(); - OperandRange targetOuts = target.getOutputs(); - OperandRange sourceOuts = source.getOutputs(); - // Create fused shared_outs. SmallVector<Value> fusedOuts; - fusedOuts.reserve(numTargetOuts + numSourceOuts); - fusedOuts.append(targetOuts.begin(), targetOuts.end()); - fusedOuts.append(sourceOuts.begin(), sourceOuts.end()); + llvm::append_range(fusedOuts, target.getOutputs()); + llvm::append_range(fusedOuts, source.getOutputs()); - // Create a new scf::forall op after the source loop. + // Create a new scf.forall op after the source loop. rewriter.setInsertionPointAfter(source); scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>( source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(), source.getMixedStep(), fusedOuts, source.getMapping()); // Map control operands. - IRMapping fusedMapping; - fusedMapping.map(target.getInductionVars(), fusedLoop.getInductionVars()); - fusedMapping.map(source.getInductionVars(), fusedLoop.getInductionVars()); + IRMapping mapping; + mapping.map(target.getInductionVars(), fusedLoop.getInductionVars()); + mapping.map(source.getInductionVars(), fusedLoop.getInductionVars()); // Map shared outs. - fusedMapping.map(target.getRegionIterArgs(), - fusedLoop.getRegionIterArgs().slice(0, numTargetOuts)); - fusedMapping.map( - source.getRegionIterArgs(), - fusedLoop.getRegionIterArgs().slice(numTargetOuts, numSourceOuts)); + mapping.map(target.getRegionIterArgs(), + fusedLoop.getRegionIterArgs().take_front(numTargetOuts)); + mapping.map(source.getRegionIterArgs(), + fusedLoop.getRegionIterArgs().take_back(numSourceOuts)); // Append everything except the terminator into the fused operation. rewriter.setInsertionPointToStart(fusedLoop.getBody()); for (Operation &op : target.getBody()->without_terminator()) - rewriter.clone(op, fusedMapping); + rewriter.clone(op, mapping); for (Operation &op : source.getBody()->without_terminator()) - rewriter.clone(op, fusedMapping); + rewriter.clone(op, mapping); // Fuse the old terminator in_parallel ops into the new one. scf::InParallelOp targetTerm = target.getTerminator(); scf::InParallelOp sourceTerm = source.getTerminator(); scf::InParallelOp fusedTerm = fusedLoop.getTerminator(); - rewriter.setInsertionPointToStart(fusedTerm.getBody()); for (Operation &op : targetTerm.getYieldingOps()) - rewriter.clone(op, fusedMapping); + rewriter.clone(op, mapping); for (Operation &op : sourceTerm.getYieldingOps()) - rewriter.clone(op, fusedMapping); - - // Replace all uses of the old loops with the fused loop. - rewriter.replaceAllUsesWith(target.getResults(), - fusedLoop.getResults().slice(0, numTargetOuts)); - rewriter.replaceAllUsesWith( - source.getResults(), - fusedLoop.getResults().slice(numTargetOuts, numSourceOuts)); - - // Erase the old loops. - rewriter.eraseOp(target); - rewriter.eraseOp(source); + rewriter.clone(op, mapping); + + // Replace old loops by substituting their uses by results of the fused loop. + rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts)); + rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts)); + + return fusedLoop; +} + +scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target, + scf::ForOp source, + RewriterBase &rewriter) { + unsigned numTargetOuts = target.getNumResults(); + unsigned numSourceOuts = source.getNumResults(); + + // Create fused init_args, with target's init_args before source's init_args. + SmallVector<Value> fusedInitArgs; + llvm::append_range(fusedInitArgs, target.getInitArgs()); + llvm::append_range(fusedInitArgs, source.getInitArgs()); + + // Create a new scf.for op after the source loop (with scf.yield terminator + // (without arguments) only in case its init_args is empty). + rewriter.setInsertionPointAfter(source); + scf::ForOp fusedLoop = rewriter.create<scf::ForOp>( + source.getLoc(), source.getLowerBound(), source.getUpperBound(), + source.getStep(), fusedInitArgs); + + // Map original induction variables and operands to those of the fused loop. + IRMapping mapping; + mapping.map(target.getInductionVar(), fusedLoop.getInductionVar()); + mapping.map(target.getRegionIterArgs(), + fusedLoop.getRegionIterArgs().take_front(numTargetOuts)); + mapping.map(source.getInductionVar(), fusedLoop.getInductionVar()); + mapping.map(source.getRegionIterArgs(), + fusedLoop.getRegionIterArgs().take_back(numSourceOuts)); + + // Merge target's body into the new (fused) for loop and then source's body. + rewriter.setInsertionPointToStart(fusedLoop.getBody()); + for (Operation &op : target.getBody()->without_terminator()) + rewriter.clone(op, mapping); + for (Operation &op : source.getBody()->without_terminator()) + rewriter.clone(op, mapping); + + // Build fused yield results by appropriately mapping original yield operands. + SmallVector<Value> yieldResults; + for (Value operand : target.getBody()->getTerminator()->getOperands()) + yieldResults.push_back(mapping.lookupOrDefault(operand)); + for (Value operand : source.getBody()->getTerminator()->getOperands()) + yieldResults.push_back(mapping.lookupOrDefault(operand)); + if (!yieldResults.empty()) + rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults); + + // Replace old loops by substituting their uses by results of the fused loop. + rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts)); + rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts)); return fusedLoop; } |