diff options
Diffstat (limited to 'mlir/lib/Dialect/Affine/IR/AffineOps.cpp')
-rw-r--r-- | mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 171 |
1 files changed, 80 insertions, 91 deletions
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 749e2ba..e0a53cd 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -2600,6 +2600,65 @@ static LogicalResult foldLoopBounds(AffineForOp forOp) { return success(folded); } +/// Returns constant trip count in trivial cases. +static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) { + int64_t step = forOp.getStepAsInt(); + if (!forOp.hasConstantBounds() || step <= 0) + return std::nullopt; + int64_t lb = forOp.getConstantLowerBound(); + int64_t ub = forOp.getConstantUpperBound(); + return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step; +} + +/// Fold the empty loop. +static SmallVector<OpFoldResult> AffineForEmptyLoopFolder(AffineForOp forOp) { + if (!llvm::hasSingleElement(*forOp.getBody())) + return {}; + if (forOp.getNumResults() == 0) + return {}; + std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp); + if (tripCount == 0) { + // The initial values of the iteration arguments would be the op's + // results. + return forOp.getInits(); + } + SmallVector<Value, 4> replacements; + auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator()); + auto iterArgs = forOp.getRegionIterArgs(); + bool hasValDefinedOutsideLoop = false; + bool iterArgsNotInOrder = false; + for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) { + Value val = yieldOp.getOperand(i); + BlockArgument *iterArgIt = llvm::find(iterArgs, val); + // TODO: It should be possible to perform a replacement by computing the + // last value of the IV based on the bounds and the step. + if (val == forOp.getInductionVar()) + return {}; + if (iterArgIt == iterArgs.end()) { + // `val` is defined outside of the loop. + assert(forOp.isDefinedOutsideOfLoop(val) && + "must be defined outside of the loop"); + hasValDefinedOutsideLoop = true; + replacements.push_back(val); + } else { + unsigned pos = std::distance(iterArgs.begin(), iterArgIt); + if (pos != i) + iterArgsNotInOrder = true; + replacements.push_back(forOp.getInits()[pos]); + } + } + // Bail out when the trip count is unknown and the loop returns any value + // defined outside of the loop or any iterArg out of order. + if (!tripCount.has_value() && + (hasValDefinedOutsideLoop || iterArgsNotInOrder)) + return {}; + // Bail out when the loop iterates more than once and it returns any iterArg + // out of order. + if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder) + return {}; + return llvm::to_vector_of<OpFoldResult>(replacements); +} + /// Canonicalize the bounds of the given loop. static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) { SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands()); @@ -2631,79 +2690,30 @@ static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) { return success(); } -namespace { -/// Returns constant trip count in trivial cases. -static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) { - int64_t step = forOp.getStepAsInt(); - if (!forOp.hasConstantBounds() || step <= 0) - return std::nullopt; - int64_t lb = forOp.getConstantLowerBound(); - int64_t ub = forOp.getConstantUpperBound(); - return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step; +/// Returns true if the affine.for has zero iterations in trivial cases. +static bool hasTrivialZeroTripCount(AffineForOp op) { + return getTrivialConstantTripCount(op) == 0; } -/// This is a pattern to fold trivially empty loop bodies. -/// TODO: This should be moved into the folding hook. -struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> { - using OpRewritePattern<AffineForOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(AffineForOp forOp, - PatternRewriter &rewriter) const override { - // Check that the body only contains a yield. - if (!llvm::hasSingleElement(*forOp.getBody())) - return failure(); - if (forOp.getNumResults() == 0) - return success(); - std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp); - if (tripCount == 0) { - // The initial values of the iteration arguments would be the op's - // results. - rewriter.replaceOp(forOp, forOp.getInits()); - return success(); - } - SmallVector<Value, 4> replacements; - auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator()); - auto iterArgs = forOp.getRegionIterArgs(); - bool hasValDefinedOutsideLoop = false; - bool iterArgsNotInOrder = false; - for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) { - Value val = yieldOp.getOperand(i); - auto *iterArgIt = llvm::find(iterArgs, val); - // TODO: It should be possible to perform a replacement by computing the - // last value of the IV based on the bounds and the step. - if (val == forOp.getInductionVar()) - return failure(); - if (iterArgIt == iterArgs.end()) { - // `val` is defined outside of the loop. - assert(forOp.isDefinedOutsideOfLoop(val) && - "must be defined outside of the loop"); - hasValDefinedOutsideLoop = true; - replacements.push_back(val); - } else { - unsigned pos = std::distance(iterArgs.begin(), iterArgIt); - if (pos != i) - iterArgsNotInOrder = true; - replacements.push_back(forOp.getInits()[pos]); - } - } - // Bail out when the trip count is unknown and the loop returns any value - // defined outside of the loop or any iterArg out of order. - if (!tripCount.has_value() && - (hasValDefinedOutsideLoop || iterArgsNotInOrder)) - return failure(); - // Bail out when the loop iterates more than once and it returns any iterArg - // out of order. - if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder) - return failure(); - rewriter.replaceOp(forOp, replacements); - return success(); +LogicalResult AffineForOp::fold(FoldAdaptor adaptor, + SmallVectorImpl<OpFoldResult> &results) { + bool folded = succeeded(foldLoopBounds(*this)); + folded |= succeeded(canonicalizeLoopBounds(*this)); + if (hasTrivialZeroTripCount(*this) && getNumResults() != 0) { + // The initial values of the loop-carried variables (iter_args) are the + // results of the op. But this must be avoided for an affine.for op that + // does not return any results. Since ops that do not return results cannot + // be folded away, we would enter an infinite loop of folds on the same + // affine.for op. + results.assign(getInits().begin(), getInits().end()); + folded = true; } -}; -} // namespace - -void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add<AffineForEmptyLoopFolder>(context); + SmallVector<OpFoldResult> foldResults = AffineForEmptyLoopFolder(*this); + if (!foldResults.empty()) { + results.assign(foldResults); + folded = true; + } + return success(folded); } OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) { @@ -2746,27 +2756,6 @@ void AffineForOp::getSuccessorRegions( regions.push_back(RegionSuccessor(getResults())); } -/// Returns true if the affine.for has zero iterations in trivial cases. -static bool hasTrivialZeroTripCount(AffineForOp op) { - return getTrivialConstantTripCount(op) == 0; -} - -LogicalResult AffineForOp::fold(FoldAdaptor adaptor, - SmallVectorImpl<OpFoldResult> &results) { - bool folded = succeeded(foldLoopBounds(*this)); - folded |= succeeded(canonicalizeLoopBounds(*this)); - if (hasTrivialZeroTripCount(*this) && getNumResults() != 0) { - // The initial values of the loop-carried variables (iter_args) are the - // results of the op. But this must be avoided for an affine.for op that - // does not return any results. Since ops that do not return results cannot - // be folded away, we would enter an infinite loop of folds on the same - // affine.for op. - results.assign(getInits().begin(), getInits().end()); - folded = true; - } - return success(folded); -} - AffineBound AffineForOp::getLowerBound() { return AffineBound(*this, getLowerBoundOperands(), getLowerBoundMap()); } |