diff options
-rw-r--r-- | mlir/include/mlir/IR/Block.h | 4 | ||||
-rw-r--r-- | mlir/include/mlir/IR/PatternMatch.h | 15 | ||||
-rw-r--r-- | mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp | 1 | ||||
-rw-r--r-- | mlir/lib/IR/Block.cpp | 10 | ||||
-rw-r--r-- | mlir/lib/IR/Dominance.cpp | 18 | ||||
-rw-r--r-- | mlir/lib/IR/PatternMatch.cpp | 28 |
6 files changed, 57 insertions, 19 deletions
diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h index e486bb62..416e8e5 100644 --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -152,6 +152,10 @@ public: Operation &back() { return operations.back(); } Operation &front() { return operations.front(); } + /// Return if the iterator `a` is before `b`. Both iterators must point into + /// this block. + bool isBeforeInBlock(iterator a, iterator b); + /// Returns 'op' if 'op' lies in this block, or otherwise finds the /// ancestor operation of 'op' that lies in this block. Returns nullptr if /// the latter fails. diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index b5a93a0..fa87d69 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -576,24 +576,39 @@ public: /// Split the operations starting at "before" (inclusive) out of the given /// block into a new block, and return it. + /// + /// If the current insertion point is before the split point, the insertion + /// point is adjusted to the new block. Block *splitBlock(Block *block, Block::iterator before); /// Unlink this operation from its current block and insert it right before /// `existingOp` which may be in the same or another block in the same /// function. + /// + /// If the insertion point is before the moved operation, the insertion block + /// is adjusted to the block of `existingOp`. void moveOpBefore(Operation *op, Operation *existingOp); /// Unlink this operation from its current block and insert it right before /// `iterator` in the specified block. + /// + /// If the insertion point is before the moved operation, the insertion block + /// is adjusted to the specified block. void moveOpBefore(Operation *op, Block *block, Block::iterator iterator); /// Unlink this operation from its current block and insert it right after /// `existingOp` which may be in the same or another block in the same /// function. + /// + /// If the insertion point is before the moved operation, the insertion block + /// is adjusted to the block of `existingOp`. void moveOpAfter(Operation *op, Operation *existingOp); /// Unlink this operation from its current block and insert it right after /// `iterator` in the specified block. + /// + /// If the insertion point is before the moved operation, the insertion block + /// is adjusted to the specified block. void moveOpAfter(Operation *op, Block *block, Block::iterator iterator); /// Unlink this block and insert it right before `existingBlock`. diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp index 8c44914..0155b47 100644 --- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp @@ -184,6 +184,7 @@ private: return [&body, this](Value lhs, Value rhs) -> Value { Block *block = rewriter.getInsertionBlock(); Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint()); + rewriter.setInsertionPointToEnd(block); // Insert accumulator body between split block. IRMapping mapping; diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp index 57825d9..9dc4867 100644 --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -68,6 +68,16 @@ void Block::erase() { getParent()->getBlocks().erase(this); } +bool Block::isBeforeInBlock(iterator a, iterator b) { + if (a == b) + return false; + if (a == end()) + return false; + if (b == end()) + return true; + return a->isBeforeInBlock(&*b); +} + /// Returns 'op' if 'op' lies in this block, or otherwise finds the /// ancestor operation of 'op' that lies in this block. Returns nullptr if /// the latter fails. diff --git a/mlir/lib/IR/Dominance.cpp b/mlir/lib/IR/Dominance.cpp index 0e53b43..b256137 100644 --- a/mlir/lib/IR/Dominance.cpp +++ b/mlir/lib/IR/Dominance.cpp @@ -235,20 +235,6 @@ findAncestorIteratorInRegion(Region *r, Block *b, Block::iterator it) { return std::make_pair(op->getBlock(), op->getIterator()); } -/// Given two iterators into the same block, return "true" if `a` is before `b. -/// Note: This is a variant of Operation::isBeforeInBlock that operates on -/// block iterators instead of ops. -static bool isBeforeInBlock(Block *block, Block::iterator a, - Block::iterator b) { - if (a == b) - return false; - if (a == block->end()) - return false; - if (b == block->end()) - return true; - return a->isBeforeInBlock(&*b); -} - template <bool IsPostDom> bool DominanceInfoBase<IsPostDom>::properlyDominatesImpl( Block *aBlock, Block::iterator aIt, Block *bBlock, Block::iterator bIt, @@ -290,9 +276,9 @@ bool DominanceInfoBase<IsPostDom>::properlyDominatesImpl( if (!hasSSADominance(aBlock)) return true; if constexpr (IsPostDom) { - return isBeforeInBlock(aBlock, bIt, aIt); + return aBlock->isBeforeInBlock(bIt, aIt); } else { - return isBeforeInBlock(aBlock, aIt, bIt); + return aBlock->isBeforeInBlock(aIt, bIt); } } diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index 9332f55..2cb4541 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -9,6 +9,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Iterators.h" #include "mlir/IR/RegionKindInterface.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallPtrSet.h" using namespace mlir; @@ -348,14 +349,29 @@ void RewriterBase::mergeBlocks(Block *source, Block *dest, /// Split the operations starting at "before" (inclusive) out of the given /// block into a new block, and return it. Block *RewriterBase::splitBlock(Block *block, Block::iterator before) { + Block *newBlock; + + // If the current insertion point is at or after the split point, adjust the + // insertion point to the new block. + bool moveIpToNewBlock = getBlock() == block && + !block->isBeforeInBlock(getInsertionPoint(), before); + auto adjustInsertionPoint = llvm::make_scope_exit([&]() { + if (getInsertionPoint() == block->end()) { + // If the insertion point is at the end of the block, move it to the end + // of the new block. + setInsertionPointToEnd(newBlock); + } else if (moveIpToNewBlock) { + setInsertionPoint(newBlock, getInsertionPoint()); + } + }); + // Fast path: If no listener is attached, split the block directly. if (!listener) - return block->splitBlock(before); + return newBlock = block->splitBlock(before); // `createBlock` sets the insertion point at the beginning of the new block. InsertionGuard g(*this); - Block *newBlock = - createBlock(block->getParent(), std::next(block->getIterator())); + newBlock = createBlock(block->getParent(), std::next(block->getIterator())); // If `before` points to end of the block, no ops should be moved. if (before == block->end()) @@ -413,6 +429,12 @@ void RewriterBase::moveOpBefore(Operation *op, Block *block, Block *currentBlock = op->getBlock(); Block::iterator nextIterator = std::next(op->getIterator()); op->moveBefore(block, iterator); + + // If the current insertion point is before the moved operation, we may have + // to adjust the insertion block. + if (getInsertionPoint() == op->getIterator()) + setInsertionPoint(block, op->getIterator()); + if (listener) listener->notifyOperationInserted( op, /*previous=*/InsertPoint(currentBlock, nextIterator)); |