1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
|
//===- ForallToParallel.cpp - scf.forall to scf.parallel loop conversion --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Transforms scf.forall ops into scf.parallel ops.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
namespace mlir {
#define GEN_PASS_DEF_SCFFORALLTOPARALLELLOOP
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
} // namespace mlir
using namespace mlir;
/// Converts an output-free ("fully bufferized") `scf.forall` op into an
/// equivalent `scf.parallel` op.
///
/// The body region of `forallOp` is moved (not cloned) into the new
/// `scf.parallel` op. Its terminator is replaced by an `scf.reduce`, and the
/// `mapping` attribute, if present, is copied onto the new op. Finally
/// `forallOp` is replaced by the new op through `rewriter`.
///
/// \param rewriter  Rewriter used for every IR mutation. Its insertion point
///                  is restored on return (InsertionGuard).
/// \param forallOp  The op to convert. It must have no outputs.
/// \param result    Optional out-parameter. When non-null, it receives the
///                  created `scf.parallel` op on success.
/// \returns failure (reported via notifyMatchFailure) when `forallOp` has
///          outputs, success otherwise.
LogicalResult mlir::scf::forallToParallelLoop(RewriterBase &rewriter,
                                              scf::ForallOp forallOp,
                                              scf::ParallelOp *result) {
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(forallOp);
  Location loc = forallOp.getLoc();

  // The scf.parallel op built below has no init values. Ops that still carry
  // outputs therefore cannot be mapped onto it.
  if (!forallOp.getOutputs().empty())
    return rewriter.notifyMatchFailure(
        forallOp,
        "only fully bufferized scf.forall ops can be lowered to scf.parallel");

  // Convert the mixed (static/dynamic) bounds and steps to SSA values. Any IR
  // needed for this is created through `rewriter`, i.e. before `forallOp`.
  SmallVector<Value> lbs = forallOp.getLowerBound(rewriter);
  SmallVector<Value> ubs = forallOp.getUpperBound(rewriter);
  SmallVector<Value> steps = forallOp.getStep(rewriter);

  // Create the scf.parallel op, then discard the body block its builder made
  // so that the scf.forall body can be moved in its place.
  auto parallelOp = scf::ParallelOp::create(rewriter, loc, lbs, ubs, steps);
  Region &parallelRegion = parallelOp.getRegion();
  rewriter.eraseBlock(&parallelRegion.front());
  rewriter.inlineRegionBefore(forallOp.getRegion(), parallelRegion,
                              parallelRegion.begin());

  // Swap the moved scf.forall terminator for the scf.reduce terminator that
  // scf.parallel expects.
  Block &body = parallelRegion.front();
  rewriter.setInsertionPointToEnd(&body);
  rewriter.replaceOpWithNewOp<scf::ReduceOp>(body.getTerminator());

  // If the mapping attribute is present, propagate it to the new op.
  if (auto mapping = forallOp.getMapping())
    parallelOp->setAttr("mapping", *mapping);

  // Outputs were rejected above, and the new op was built without init
  // values. This replacement is therefore expected to simply erase
  // `forallOp`.
  rewriter.replaceOp(forallOp, parallelOp);

  if (result)
    *result = parallelOp;
  return success();
}
namespace {
/// Pass that rewrites every `scf.forall` op nested under the anchor op into an
/// `scf.parallel` op via `scf::forallToParallelLoop`.
///
/// If any op cannot be converted, the pass signals failure. The walk still
/// continues over the remaining ops.
struct ForallToParallelLoop final
    : public impl::SCFForallToParallelLoopBase<ForallToParallelLoop> {
  void runOnOperation() override {
    Operation *root = getOperation();
    IRRewriter rewriter(root->getContext());

    // Default (post-order) walk: nested scf.forall ops are visited before
    // their parents, which lets the callback replace the op it is visiting.
    root->walk([&](scf::ForallOp candidate) {
      if (succeeded(scf::forallToParallelLoop(rewriter, candidate)))
        return;
      signalPassFailure();
    });
  }
};
} // namespace
/// Creates a new instance of the scf.forall -> scf.parallel conversion pass.
std::unique_ptr<Pass> mlir::createForallToParallelLoopPass() {
  std::unique_ptr<Pass> pass = std::make_unique<ForallToParallelLoop>();
  return pass;
}
|