aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Affine/IR/AffineOps.cpp')
-rw-r--r--mlir/lib/Dialect/Affine/IR/AffineOps.cpp171
1 files changed, 80 insertions, 91 deletions
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 749e2ba..e0a53cd 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2600,6 +2600,65 @@ static LogicalResult foldLoopBounds(AffineForOp forOp) {
return success(folded);
}
+/// Returns constant trip count in trivial cases.
+static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
+ int64_t step = forOp.getStepAsInt();
+ if (!forOp.hasConstantBounds() || step <= 0)
+ return std::nullopt;
+ int64_t lb = forOp.getConstantLowerBound();
+ int64_t ub = forOp.getConstantUpperBound();
+ return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
+}
+
+/// Fold the empty loop.
+static SmallVector<OpFoldResult> AffineForEmptyLoopFolder(AffineForOp forOp) {
+ if (!llvm::hasSingleElement(*forOp.getBody()))
+ return {};
+ if (forOp.getNumResults() == 0)
+ return {};
+ std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
+ if (tripCount == 0) {
+ // The initial values of the iteration arguments would be the op's
+ // results.
+ return forOp.getInits();
+ }
+ SmallVector<Value, 4> replacements;
+ auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
+ auto iterArgs = forOp.getRegionIterArgs();
+ bool hasValDefinedOutsideLoop = false;
+ bool iterArgsNotInOrder = false;
+ for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
+ Value val = yieldOp.getOperand(i);
+ BlockArgument *iterArgIt = llvm::find(iterArgs, val);
+ // TODO: It should be possible to perform a replacement by computing the
+ // last value of the IV based on the bounds and the step.
+ if (val == forOp.getInductionVar())
+ return {};
+ if (iterArgIt == iterArgs.end()) {
+ // `val` is defined outside of the loop.
+ assert(forOp.isDefinedOutsideOfLoop(val) &&
+ "must be defined outside of the loop");
+ hasValDefinedOutsideLoop = true;
+ replacements.push_back(val);
+ } else {
+ unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
+ if (pos != i)
+ iterArgsNotInOrder = true;
+ replacements.push_back(forOp.getInits()[pos]);
+ }
+ }
+ // Bail out when the trip count is unknown and the loop returns any value
+ // defined outside of the loop or any iterArg out of order.
+ if (!tripCount.has_value() &&
+ (hasValDefinedOutsideLoop || iterArgsNotInOrder))
+ return {};
+ // Bail out when the loop iterates more than once and it returns any iterArg
+ // out of order.
+ if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
+ return {};
+ return llvm::to_vector_of<OpFoldResult>(replacements);
+}
+
/// Canonicalize the bounds of the given loop.
static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
@@ -2631,79 +2690,30 @@ static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
return success();
}
-namespace {
-/// Returns constant trip count in trivial cases.
-static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
- int64_t step = forOp.getStepAsInt();
- if (!forOp.hasConstantBounds() || step <= 0)
- return std::nullopt;
- int64_t lb = forOp.getConstantLowerBound();
- int64_t ub = forOp.getConstantUpperBound();
- return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
+/// Returns true if the affine.for has zero iterations in trivial cases.
+static bool hasTrivialZeroTripCount(AffineForOp op) {
+ return getTrivialConstantTripCount(op) == 0;
}
-/// This is a pattern to fold trivially empty loop bodies.
-/// TODO: This should be moved into the folding hook.
-struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
- using OpRewritePattern<AffineForOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(AffineForOp forOp,
- PatternRewriter &rewriter) const override {
- // Check that the body only contains a yield.
- if (!llvm::hasSingleElement(*forOp.getBody()))
- return failure();
- if (forOp.getNumResults() == 0)
- return success();
- std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
- if (tripCount == 0) {
- // The initial values of the iteration arguments would be the op's
- // results.
- rewriter.replaceOp(forOp, forOp.getInits());
- return success();
- }
- SmallVector<Value, 4> replacements;
- auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
- auto iterArgs = forOp.getRegionIterArgs();
- bool hasValDefinedOutsideLoop = false;
- bool iterArgsNotInOrder = false;
- for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
- Value val = yieldOp.getOperand(i);
- auto *iterArgIt = llvm::find(iterArgs, val);
- // TODO: It should be possible to perform a replacement by computing the
- // last value of the IV based on the bounds and the step.
- if (val == forOp.getInductionVar())
- return failure();
- if (iterArgIt == iterArgs.end()) {
- // `val` is defined outside of the loop.
- assert(forOp.isDefinedOutsideOfLoop(val) &&
- "must be defined outside of the loop");
- hasValDefinedOutsideLoop = true;
- replacements.push_back(val);
- } else {
- unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
- if (pos != i)
- iterArgsNotInOrder = true;
- replacements.push_back(forOp.getInits()[pos]);
- }
- }
- // Bail out when the trip count is unknown and the loop returns any value
- // defined outside of the loop or any iterArg out of order.
- if (!tripCount.has_value() &&
- (hasValDefinedOutsideLoop || iterArgsNotInOrder))
- return failure();
- // Bail out when the loop iterates more than once and it returns any iterArg
- // out of order.
- if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
- return failure();
- rewriter.replaceOp(forOp, replacements);
- return success();
+LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
+ SmallVectorImpl<OpFoldResult> &results) {
+ bool folded = succeeded(foldLoopBounds(*this));
+ folded |= succeeded(canonicalizeLoopBounds(*this));
+ if (hasTrivialZeroTripCount(*this) && getNumResults() != 0) {
+ // The initial values of the loop-carried variables (iter_args) are the
+ // results of the op. But this must be avoided for an affine.for op that
+ // does not return any results. Since ops that do not return results cannot
+ // be folded away, we would enter an infinite loop of folds on the same
+ // affine.for op.
+ results.assign(getInits().begin(), getInits().end());
+ folded = true;
}
-};
-} // namespace
-
-void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<AffineForEmptyLoopFolder>(context);
+ SmallVector<OpFoldResult> foldResults = AffineForEmptyLoopFolder(*this);
+ if (!foldResults.empty()) {
+ results.assign(foldResults);
+ folded = true;
+ }
+ return success(folded);
}
OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
@@ -2746,27 +2756,6 @@ void AffineForOp::getSuccessorRegions(
regions.push_back(RegionSuccessor(getResults()));
}
-/// Returns true if the affine.for has zero iterations in trivial cases.
-static bool hasTrivialZeroTripCount(AffineForOp op) {
- return getTrivialConstantTripCount(op) == 0;
-}
-
-LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
- SmallVectorImpl<OpFoldResult> &results) {
- bool folded = succeeded(foldLoopBounds(*this));
- folded |= succeeded(canonicalizeLoopBounds(*this));
- if (hasTrivialZeroTripCount(*this) && getNumResults() != 0) {
- // The initial values of the loop-carried variables (iter_args) are the
- // results of the op. But this must be avoided for an affine.for op that
- // does not return any results. Since ops that do not return results cannot
- // be folded away, we would enter an infinite loop of folds on the same
- // affine.for op.
- results.assign(getInits().begin(), getInits().end());
- folded = true;
- }
- return success(folded);
-}
-
AffineBound AffineForOp::getLowerBound() {
return AffineBound(*this, getLowerBoundOperands(), getLowerBoundMap());
}