aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/SCF
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/SCF')
-rw-r--r--mlir/lib/Dialect/SCF/IR/SCF.cpp47
1 files changed, 44 insertions, 3 deletions
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 2946b53..881e256 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -2565,6 +2565,39 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
struct ConditionPropagation : public OpRewritePattern<IfOp> {
using OpRewritePattern<IfOp>::OpRewritePattern;
+ /// Kind of parent region in the ancestor cache.
+ enum class Parent { Then, Else, None };
+
+ /// Returns the kind of region ("then", "else", or "none") of the
+ /// IfOp that the given region is transitively nested in. Updates
+ /// the cache accordingly.
+ static Parent getParentType(Region *toCheck, IfOp op,
+ DenseMap<Region *, Parent> &cache,
+ Region *endRegion) {
+ SmallVector<Region *> seen;
+ while (toCheck != endRegion) {
+ auto found = cache.find(toCheck);
+ if (found != cache.end())
+ return found->second;
+ seen.push_back(toCheck);
+ if (&op.getThenRegion() == toCheck) {
+ for (Region *region : seen)
+ cache[region] = Parent::Then;
+ return Parent::Then;
+ }
+ if (&op.getElseRegion() == toCheck) {
+ for (Region *region : seen)
+ cache[region] = Parent::Else;
+ return Parent::Else;
+ }
+ toCheck = toCheck->getParentRegion();
+ }
+
+ for (Region *region : seen)
+ cache[region] = Parent::None;
+ return Parent::None;
+ }
+
LogicalResult matchAndRewrite(IfOp op,
PatternRewriter &rewriter) const override {
// Early exit if the condition is constant since replacing a constant
@@ -2580,9 +2613,12 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {
Value constantTrue = nullptr;
Value constantFalse = nullptr;
+ DenseMap<Region *, Parent> cache;
for (OpOperand &use :
llvm::make_early_inc_range(op.getCondition().getUses())) {
- if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
+ switch (getParentType(use.getOwner()->getParentRegion(), op, cache,
+ op.getCondition().getParentRegion())) {
+ case Parent::Then: {
changed = true;
if (!constantTrue)
@@ -2591,8 +2627,9 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {
rewriter.modifyOpInPlace(use.getOwner(),
[&]() { use.set(constantTrue); });
- } else if (op.getElseRegion().isAncestor(
- use.getOwner()->getParentRegion())) {
+ break;
+ }
+ case Parent::Else: {
changed = true;
if (!constantFalse)
@@ -2601,6 +2638,10 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {
rewriter.modifyOpInPlace(use.getOwner(),
[&]() { use.set(constantFalse); });
+ break;
+ }
+ case Parent::None:
+ break;
}
}