diff options
Diffstat (limited to 'mlir/lib/Dialect/SCF/IR/SCF.cpp')
| -rw-r--r-- | mlir/lib/Dialect/SCF/IR/SCF.cpp | 47 | 
1 files changed, 44 insertions, 3 deletions
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 2946b53..881e256 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -2565,6 +2565,39 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {  struct ConditionPropagation : public OpRewritePattern<IfOp> {    using OpRewritePattern<IfOp>::OpRewritePattern; +  /// Kind of parent region in the ancestor cache. +  enum class Parent { Then, Else, None }; + +  /// Returns the kind of region ("then", "else", or "none") of the +  /// IfOp that the given region is transitively nested in. Updates +  /// the cache accordingly. +  static Parent getParentType(Region *toCheck, IfOp op, +                              DenseMap<Region *, Parent> &cache, +                              Region *endRegion) { +    SmallVector<Region *> seen; +    while (toCheck != endRegion) { +      auto found = cache.find(toCheck); +      if (found != cache.end()) +        return found->second; +      seen.push_back(toCheck); +      if (&op.getThenRegion() == toCheck) { +        for (Region *region : seen) +          cache[region] = Parent::Then; +        return Parent::Then; +      } +      if (&op.getElseRegion() == toCheck) { +        for (Region *region : seen) +          cache[region] = Parent::Else; +        return Parent::Else; +      } +      toCheck = toCheck->getParentRegion(); +    } + +    for (Region *region : seen) +      cache[region] = Parent::None; +    return Parent::None; +  } +    LogicalResult matchAndRewrite(IfOp op,                                  PatternRewriter &rewriter) const override {      // Early exit if the condition is constant since replacing a constant @@ -2580,9 +2613,12 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {      Value constantTrue = nullptr;      Value constantFalse = nullptr; +    DenseMap<Region *, Parent> cache;      for (OpOperand &use :           llvm::make_early_inc_range(op.getCondition().getUses())) { -      if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) { +      switch (getParentType(use.getOwner()->getParentRegion(), op, cache, +                            op.getCondition().getParentRegion())) { +      case Parent::Then: {          changed = true;          if (!constantTrue) @@ -2591,8 +2627,9 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {          rewriter.modifyOpInPlace(use.getOwner(),                                   [&]() { use.set(constantTrue); }); -      } else if (op.getElseRegion().isAncestor( -                     use.getOwner()->getParentRegion())) { +        break; +      } +      case Parent::Else: {          changed = true;          if (!constantFalse) @@ -2601,6 +2638,10 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {          rewriter.modifyOpInPlace(use.getOwner(),                                   [&]() { use.set(constantFalse); }); +        break; +      } +      case Parent::None: +        break;        }      }  | 
