aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp')
-rw-r--r--mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp20
1 files changed, 11 insertions, 9 deletions
diff --git a/mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp b/mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp
index ced0033..2470380 100644
--- a/mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp
@@ -42,11 +42,10 @@ struct TestLoopUnrollingPass
TestLoopUnrollingPass(const TestLoopUnrollingPass &) {}
explicit TestLoopUnrollingPass(uint64_t unrollFactorParam,
unsigned loopDepthParam,
- bool annotateLoopParam, bool unrollFullParam) {
+ bool annotateLoopParam) {
unrollFactor = unrollFactorParam;
loopDepth = loopDepthParam;
annotateLoop = annotateLoopParam;
- unrollFull = unrollFactorParam;
}
void getDependentDialects(DialectRegistry &registry) const override {
@@ -54,6 +53,12 @@ struct TestLoopUnrollingPass
}
void runOnOperation() override {
+ if (!(unrollFactor.getValue() > 0 || unrollFactor.getValue() == -1)) {
+ emitError(UnknownLoc::get(&getContext()),
+ "Invalid option: 'unroll-factor' should be greater than 0 or "
+ "equal to -1");
+ return signalPassFailure();
+ }
SmallVector<scf::ForOp, 4> loops;
getOperation()->walk([&](scf::ForOp forOp) {
if (getNestingDepth(forOp) == loopDepth)
@@ -65,15 +70,15 @@ struct TestLoopUnrollingPass
}
};
for (auto loop : loops) {
- if (unrollFull)
+ if (unrollFactor.getValue() == -1)
(void)loopUnrollFull(loop);
else
(void)loopUnrollByFactor(loop, unrollFactor, annotateFn);
}
}
- Option<uint64_t> unrollFactor{*this, "unroll-factor",
- llvm::cl::desc("Loop unroll factor."),
- llvm::cl::init(1)};
+ Option<int64_t> unrollFactor{*this, "unroll-factor",
+ llvm::cl::desc("Loop unroll factor."),
+ llvm::cl::init(1)};
Option<bool> annotateLoop{*this, "annotate",
llvm::cl::desc("Annotate unrolled iterations."),
llvm::cl::init(false)};
@@ -82,9 +87,6 @@ struct TestLoopUnrollingPass
llvm::cl::init(false)};
Option<unsigned> loopDepth{*this, "loop-depth", llvm::cl::desc("Loop depth."),
llvm::cl::init(0)};
- Option<bool> unrollFull{*this, "unroll-full",
- llvm::cl::desc("Full unroll loops."),
- llvm::cl::init(false)};
};
} // namespace