aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/lib/Dialect
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/test/lib/Dialect')
-rw-r--r--mlir/test/lib/Dialect/SCF/CMakeLists.txt1
-rw-r--r--mlir/test/lib/Dialect/SCF/TestParallelLoopUnrolling.cpp85
-rw-r--r--mlir/test/lib/Dialect/Test/TestOpDefs.cpp26
-rw-r--r--mlir/test/lib/Dialect/Test/TestOps.td2
4 files changed, 101 insertions, 13 deletions
diff --git a/mlir/test/lib/Dialect/SCF/CMakeLists.txt b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
index 791c2e6..d2f97e8 100644
--- a/mlir/test/lib/Dialect/SCF/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
@@ -2,6 +2,7 @@
add_mlir_library(MLIRSCFTestPasses
TestLoopParametricTiling.cpp
TestLoopUnrolling.cpp
+ TestParallelLoopUnrolling.cpp
TestSCFUtils.cpp
TestSCFWrapInZeroTripCheck.cpp
TestUpliftWhileToFor.cpp
diff --git a/mlir/test/lib/Dialect/SCF/TestParallelLoopUnrolling.cpp b/mlir/test/lib/Dialect/SCF/TestParallelLoopUnrolling.cpp
new file mode 100644
index 0000000..77a22a18
--- /dev/null
+++ b/mlir/test/lib/Dialect/SCF/TestParallelLoopUnrolling.cpp
@@ -0,0 +1,85 @@
+//=== TestParallelLoopUnrolling.cpp - loop unrolling test pass ===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to unroll loops by a specified unroll factor.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+
+static unsigned getNestingDepth(Operation *op) {
+ Operation *currOp = op;
+ unsigned depth = 0;
+ while ((currOp = currOp->getParentOp())) {
+ if (isa<scf::ParallelOp>(currOp))
+ depth++;
+ }
+ return depth;
+}
+
+struct TestParallelLoopUnrollingPass
+ : public PassWrapper<TestParallelLoopUnrollingPass, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestParallelLoopUnrollingPass)
+
+ StringRef getArgument() const final { return "test-parallel-loop-unrolling"; }
+ StringRef getDescription() const final {
+ return "Tests parallel loop unrolling transformation";
+ }
+ TestParallelLoopUnrollingPass() = default;
+ TestParallelLoopUnrollingPass(const TestParallelLoopUnrollingPass &) {}
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<arith::ArithDialect>();
+ }
+
+ void runOnOperation() override {
+ SmallVector<scf::ParallelOp, 4> loops;
+ getOperation()->walk([&](scf::ParallelOp parLoop) {
+ if (getNestingDepth(parLoop) == loopDepth)
+ loops.push_back(parLoop);
+ });
+ auto annotateFn = [this](unsigned i, Operation *op, OpBuilder b) {
+ if (annotateLoop) {
+ op->setAttr("unrolled_iteration", b.getUI32IntegerAttr(i));
+ }
+ };
+ PatternRewriter rewriter(getOperation()->getContext());
+ for (auto loop : loops) {
+ (void)parallelLoopUnrollByFactors(loop, unrollFactors, rewriter,
+ annotateFn);
+ }
+ }
+
+ ListOption<uint64_t> unrollFactors{
+ *this, "unroll-factors",
+ llvm::cl::desc(
+ "Unroll factors for each parallel loop dim. If fewer factors than "
+ "loop dims are provided, they are applied to the inner dims.")};
+ Option<unsigned> loopDepth{*this, "loop-depth", llvm::cl::desc("Loop depth."),
+ llvm::cl::init(0)};
+ Option<bool> annotateLoop{*this, "annotate",
+ llvm::cl::desc("Annotate unrolled iterations."),
+ llvm::cl::init(false)};
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestParallelLoopUnrollingPass() {
+ PassRegistration<TestParallelLoopUnrollingPass>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index b211e24..4d4ec02 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -633,8 +633,9 @@ ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
parser.getCurrentLocation(), result.operands);
}
-OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) {
- assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) &&
+OperandRange RegionIfOp::getEntrySuccessorOperands(RegionSuccessor successor) {
+ assert(llvm::is_contained({&getThenRegion(), &getElseRegion()},
+ successor.getSuccessor()) &&
"invalid region index");
return getOperands();
}
@@ -643,10 +644,11 @@ void RegionIfOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
// We always branch to the join region.
if (!point.isParent()) {
- if (point != getJoinRegion())
+ if (point.getTerminatorPredecessorOrNull()->getParentRegion() !=
+ &getJoinRegion())
regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
else
- regions.push_back(RegionSuccessor(getResults()));
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
return;
}
@@ -673,7 +675,7 @@ void AnyCondOp::getSuccessorRegions(RegionBranchPoint point,
if (point.isParent())
regions.emplace_back(&getRegion());
else
- regions.emplace_back(getResults());
+ regions.emplace_back(getOperation(), getResults());
}
void AnyCondOp::getRegionInvocationBounds(
@@ -1107,11 +1109,11 @@ void LoopBlockOp::getSuccessorRegions(
if (point.isParent())
return;
- regions.emplace_back((*this)->getResults());
+ regions.emplace_back(getOperation(), getOperation()->getResults());
}
-OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) {
- assert(point == getBody());
+OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionSuccessor successor) {
+ assert(successor.getSuccessor() == &getBody());
return MutableOperandRange(getInitMutable());
}
@@ -1120,8 +1122,8 @@ OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) {
//===----------------------------------------------------------------------===//
MutableOperandRange
-LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) {
- if (point.isParent())
+LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionSuccessor successor) {
+ if (successor.isParent())
return getExitArgMutable();
return getNextIterArgMutable();
}
@@ -1213,7 +1215,7 @@ void TestStoreWithARegion::getSuccessorRegions(
if (point.isParent())
regions.emplace_back(&getBody(), getBody().front().getArguments());
else
- regions.emplace_back();
+ regions.emplace_back(getOperation(), getOperation()->getResults());
}
//===----------------------------------------------------------------------===//
@@ -1227,7 +1229,7 @@ void TestStoreWithALoopRegion::getSuccessorRegions(
// enter the body.
regions.emplace_back(
RegionSuccessor(&getBody(), getBody().front().getArguments()));
- regions.emplace_back();
+ regions.emplace_back(getOperation(), getOperation()->getResults());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 05a33cf..a3430ba 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2581,7 +2581,7 @@ def LoopBlockTerminatorOp : TEST_Op<"loop_block_term",
def TestNoTerminatorOp : TEST_Op<"switch_with_no_break", [
NoTerminator,
- DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getSuccessorRegions"]>
+ DeclareOpInterfaceMethods<RegionBranchOpInterface>
]> {
let arguments = (ins Index:$arg, DenseI64ArrayAttr:$cases);
let regions = (region VariadicRegion<SizedRegion<1>>:$caseRegions);