diff options
Diffstat (limited to 'mlir/lib/Transforms/LoopInvariantCodeMotion.cpp')
-rw-r--r-- | mlir/lib/Transforms/LoopInvariantCodeMotion.cpp | 49 |
1 files changed, 20 insertions, 29 deletions
diff --git a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp index 9ed95a9..ccac42f 100644 --- a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp +++ b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp @@ -16,7 +16,6 @@ #include "mlir/IR/Function.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/LoopLikeInterface.h" -#include "mlir/Transforms/SideEffectsInterface.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -27,8 +26,6 @@ using namespace mlir; namespace { -using SideEffecting = SideEffectsInterface::SideEffecting; - /// Loop invariant code motion (LICM) pass. struct LoopInvariantCodeMotion : public OperationPass<LoopInvariantCodeMotion> { public: @@ -41,40 +38,39 @@ public: // - the op has no side-effects. If sideEffecting is Never, sideeffects of this // op and its nested ops are ignored. static bool canBeHoisted(Operation *op, - function_ref<bool(Value)> definedOutside, - SideEffecting sideEffecting, - SideEffectsInterface &interface) { + function_ref<bool(Value)> definedOutside) { // Check that dependencies are defined outside of loop. if (!llvm::all_of(op->getOperands(), definedOutside)) return false; // Check whether this op is side-effect free. If we already know that there // can be no side-effects because the surrounding op has claimed so, we can // (and have to) skip this step. - auto thisOpIsSideEffecting = sideEffecting; - if (thisOpIsSideEffecting != SideEffecting::Never) { - thisOpIsSideEffecting = interface.isSideEffecting(op); - // If the op always has sideeffects, we cannot hoist. - if (thisOpIsSideEffecting == SideEffecting::Always) + if (auto memInterface = dyn_cast<MemoryEffectOpInterface>(op)) { + if (!memInterface.hasNoEffect()) return false; + } else if (!op->hasNoSideEffect() && + !op->hasTrait<OpTrait::HasRecursiveSideEffects>()) { + return false; } + + // If the operation doesn't have side effects and it doesn't recursively + // have side effects, it can always be hoisted. + if (!op->hasTrait<OpTrait::HasRecursiveSideEffects>()) + return true; + // Recurse into the regions for this op and check whether the contained ops // can be hoisted. for (auto ®ion : op->getRegions()) { for (auto &block : region.getBlocks()) { - for (auto &innerOp : block) { - if (innerOp.isKnownTerminator()) - continue; - if (!canBeHoisted(&innerOp, definedOutside, thisOpIsSideEffecting, - interface)) + for (auto &innerOp : block.without_terminator()) + if (!canBeHoisted(&innerOp, definedOutside)) return false; - } } } return true; } -static LogicalResult moveLoopInvariantCode(LoopLikeOpInterface looplike, - SideEffectsInterface &interface) { +static LogicalResult moveLoopInvariantCode(LoopLikeOpInterface looplike) { auto &loopBody = looplike.getLoopBody(); // We use two collections here as we need to preserve the order for insertion @@ -94,9 +90,7 @@ static LogicalResult moveLoopInvariantCode(LoopLikeOpInterface looplike, // rewriting. If the nested regions are loops, they will have been processed. for (auto &block : loopBody) { for (auto &op : block.without_terminator()) { - if (canBeHoisted(&op, isDefinedOutsideOfBody, - mlir::SideEffectsDialectInterface::Recursive, - interface)) { + if (canBeHoisted(&op, isDefinedOutsideOfBody)) { opsToMove.push_back(&op); willBeMovedSet.insert(&op); } @@ -113,16 +107,13 @@ static LogicalResult moveLoopInvariantCode(LoopLikeOpInterface looplike, } // end anonymous namespace void LoopInvariantCodeMotion::runOnOperation() { - SideEffectsInterface interface(&getContext()); // Walk through all loops in a function in innermost-loop-first order. This // way, we first LICM from the inner loop, and place the ops in // the outer loop, which in turn can be further LICM'ed. - getOperation()->walk([&](Operation *op) { - if (auto looplike = dyn_cast<LoopLikeOpInterface>(op)) { - LLVM_DEBUG(op->print(llvm::dbgs() << "\nOriginal loop\n")); - if (failed(moveLoopInvariantCode(looplike, interface))) - signalPassFailure(); - } + getOperation()->walk([&](LoopLikeOpInterface loopLike) { + LLVM_DEBUG(loopLike.print(llvm::dbgs() << "\nOriginal loop\n")); + if (failed(moveLoopInvariantCode(loopLike))) + signalPassFailure(); }); } |