diff options
Diffstat (limited to 'mlir/test/lib/Dialect')
| -rw-r--r-- | mlir/test/lib/Dialect/SCF/CMakeLists.txt | 1 | ||||
| -rw-r--r-- | mlir/test/lib/Dialect/SCF/TestParallelLoopUnrolling.cpp | 85 | ||||
| -rw-r--r-- | mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 26 | ||||
| -rw-r--r-- | mlir/test/lib/Dialect/Test/TestOps.td | 2 |
4 files changed, 101 insertions, 13 deletions
diff --git a/mlir/test/lib/Dialect/SCF/CMakeLists.txt b/mlir/test/lib/Dialect/SCF/CMakeLists.txt index 791c2e6..d2f97e8 100644 --- a/mlir/test/lib/Dialect/SCF/CMakeLists.txt +++ b/mlir/test/lib/Dialect/SCF/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_library(MLIRSCFTestPasses TestLoopParametricTiling.cpp TestLoopUnrolling.cpp + TestParallelLoopUnrolling.cpp TestSCFUtils.cpp TestSCFWrapInZeroTripCheck.cpp TestUpliftWhileToFor.cpp diff --git a/mlir/test/lib/Dialect/SCF/TestParallelLoopUnrolling.cpp b/mlir/test/lib/Dialect/SCF/TestParallelLoopUnrolling.cpp new file mode 100644 index 0000000..77a22a18 --- /dev/null +++ b/mlir/test/lib/Dialect/SCF/TestParallelLoopUnrolling.cpp @@ -0,0 +1,85 @@ +//=== TestParallelLoopUnrolling.cpp - loop unrolling test pass ===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to unroll loops by a specified unroll factor. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/IR/Builders.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace { + +static unsigned getNestingDepth(Operation *op) { + Operation *currOp = op; + unsigned depth = 0; + while ((currOp = currOp->getParentOp())) { + if (isa<scf::ParallelOp>(currOp)) + depth++; + } + return depth; +} + +struct TestParallelLoopUnrollingPass + : public PassWrapper<TestParallelLoopUnrollingPass, OperationPass<>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestParallelLoopUnrollingPass) + + StringRef getArgument() const final { return "test-parallel-loop-unrolling"; } + StringRef getDescription() const final { + return "Tests parallel loop unrolling transformation"; + } + TestParallelLoopUnrollingPass() = default; + TestParallelLoopUnrollingPass(const TestParallelLoopUnrollingPass &) {} + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<arith::ArithDialect>(); + } + + void runOnOperation() override { + SmallVector<scf::ParallelOp, 4> loops; + getOperation()->walk([&](scf::ParallelOp parLoop) { + if (getNestingDepth(parLoop) == loopDepth) + loops.push_back(parLoop); + }); + auto annotateFn = [this](unsigned i, Operation *op, OpBuilder b) { + if (annotateLoop) { + op->setAttr("unrolled_iteration", b.getUI32IntegerAttr(i)); + } + }; + PatternRewriter rewriter(getOperation()->getContext()); + for (auto loop : loops) { + (void)parallelLoopUnrollByFactors(loop, unrollFactors, rewriter, + annotateFn); + } + } + + ListOption<uint64_t> unrollFactors{ + *this, "unroll-factors", + llvm::cl::desc( + "Unroll factors for each parallel loop dim. If fewer factors than " + "loop dims are provided, they are applied to the inner dims.")}; + Option<unsigned> loopDepth{*this, "loop-depth", llvm::cl::desc("Loop depth."), + llvm::cl::init(0)}; + Option<bool> annotateLoop{*this, "annotate", + llvm::cl::desc("Annotate unrolled iterations."), + llvm::cl::init(false)}; +}; +} // namespace + +namespace mlir { +namespace test { +void registerTestParallelLoopUnrollingPass() { + PassRegistration<TestParallelLoopUnrollingPass>(); +} +} // namespace test +} // namespace mlir diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp index b211e24..4d4ec02 100644 --- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp +++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp @@ -633,8 +633,9 @@ ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) { parser.getCurrentLocation(), result.operands); } -OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) { - assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) && +OperandRange RegionIfOp::getEntrySuccessorOperands(RegionSuccessor successor) { + assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, + successor.getSuccessor()) && "invalid region index"); return getOperands(); } @@ -643,10 +644,11 @@ void RegionIfOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { // We always branch to the join region. if (!point.isParent()) { - if (point != getJoinRegion()) + if (point.getTerminatorPredecessorOrNull()->getParentRegion() != + &getJoinRegion()) regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs())); else - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); return; } @@ -673,7 +675,7 @@ void AnyCondOp::getSuccessorRegions(RegionBranchPoint point, if (point.isParent()) regions.emplace_back(&getRegion()); else - regions.emplace_back(getResults()); + regions.emplace_back(getOperation(), getResults()); } void AnyCondOp::getRegionInvocationBounds( @@ -1107,11 +1109,11 @@ void LoopBlockOp::getSuccessorRegions( if (point.isParent()) return; - regions.emplace_back((*this)->getResults()); + regions.emplace_back(getOperation(), getOperation()->getResults()); } -OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) { - assert(point == getBody()); +OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionSuccessor successor) { + assert(successor.getSuccessor() == &getBody()); return MutableOperandRange(getInitMutable()); } @@ -1120,8 +1122,8 @@ OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) { //===----------------------------------------------------------------------===// MutableOperandRange -LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) { - if (point.isParent()) +LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionSuccessor successor) { + if (successor.isParent()) return getExitArgMutable(); return getNextIterArgMutable(); } @@ -1213,7 +1215,7 @@ void TestStoreWithARegion::getSuccessorRegions( if (point.isParent()) regions.emplace_back(&getBody(), getBody().front().getArguments()); else - regions.emplace_back(); + regions.emplace_back(getOperation(), getOperation()->getResults()); } //===----------------------------------------------------------------------===// @@ -1227,7 +1229,7 @@ void TestStoreWithALoopRegion::getSuccessorRegions( // enter the body. regions.emplace_back( RegionSuccessor(&getBody(), getBody().front().getArguments())); - regions.emplace_back(); + regions.emplace_back(getOperation(), getOperation()->getResults()); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 05a33cf..a3430ba 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2581,7 +2581,7 @@ def LoopBlockTerminatorOp : TEST_Op<"loop_block_term", def TestNoTerminatorOp : TEST_Op<"switch_with_no_break", [ NoTerminator, - DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getSuccessorRegions"]> + DeclareOpInterfaceMethods<RegionBranchOpInterface> ]> { let arguments = (ins Index:$arg, DenseI64ArrayAttr:$cases); let regions = (region VariadicRegion<SizedRegion<1>>:$caseRegions); |
