aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorMatthias Springer <me@m-sp.org>2024-01-17 11:08:59 +0100
committerGitHub <noreply@github.com>2024-01-17 11:08:59 +0100
commit5fcf907b34355980f77d7665a175b05fea7a6b7b (patch)
tree755f08efb850704d5b50ddf9d571e83e3400aef5 /mlir
parent57b50ef0174b6fdf7b554c4ae2691e2fa4f78f96 (diff)
downloadllvm-5fcf907b34355980f77d7665a175b05fea7a6b7b.zip
llvm-5fcf907b34355980f77d7665a175b05fea7a6b7b.tar.gz
llvm-5fcf907b34355980f77d7665a175b05fea7a6b7b.tar.bz2
[mlir][IR] Rename "update root" to "modify op" in rewriter API (#78260)
This commit renames 4 pattern rewriter API functions: * `updateRootInPlace` -> `modifyOpInPlace` * `startRootUpdate` -> `startOpModification` * `finalizeRootUpdate` -> `finalizeOpModification` * `cancelRootUpdate` -> `cancelOpModification` The term "root" is a misnomer. The root is the op that a rewrite pattern matches against (https://mlir.llvm.org/docs/PatternRewriter/#root-operation-name-optional). A rewriter must be notified of all in-place op modifications, not just in-place modifications of the root (https://mlir.llvm.org/docs/PatternRewriter/#pattern-rewriter). The old function names were confusing and have contributed to various broken rewrite patterns. Note: The new function names use the term "modify" instead of "update" for consistency with the `RewriterBase::Listener` terminology (`notifyOperationModified`).
Diffstat (limited to 'mlir')
-rw-r--r--mlir/docs/PatternRewriter.md10
-rw-r--r--mlir/examples/standalone/lib/Standalone/StandalonePasses.cpp2
-rw-r--r--mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp4
-rw-r--r--mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp4
-rw-r--r--mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp4
-rw-r--r--mlir/include/mlir/IR/PatternMatch.h44
-rw-r--r--mlir/include/mlir/Transforms/DialectConversion.h14
-rw-r--r--mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp2
-rw-r--r--mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp3
-rw-r--r--mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp5
-rw-r--r--mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp6
-rw-r--r--mlir/lib/Dialect/Affine/IR/AffineOps.cpp6
-rw-r--r--mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp4
-rw-r--r--mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp6
-rw-r--r--mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp4
-rw-r--r--mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp4
-rw-r--r--mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp4
-rw-r--r--mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp2
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp2
-rw-r--r--mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp8
-rw-r--r--mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp2
-rw-r--r--mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp6
-rw-r--r--mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp4
-rw-r--r--mlir/lib/Dialect/GPU/IR/GPUDialect.cpp2
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp4
-rw-r--r--mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp14
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp4
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp8
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp6
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp6
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp2
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp4
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp4
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp6
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Split.cpp2
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp2
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp2
-rw-r--r--mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp4
-rw-r--r--mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp8
-rw-r--r--mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp2
-rw-r--r--mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp2
-rw-r--r--mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp4
-rw-r--r--mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp2
-rw-r--r--mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp4
-rw-r--r--mlir/lib/Dialect/SCF/IR/SCF.cpp28
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp4
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp4
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp4
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp10
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp4
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp2
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp2
-rw-r--r--mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp4
-rw-r--r--mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp2
-rw-r--r--mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp2
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp20
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp6
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp2
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp2
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp2
-rw-r--r--mlir/lib/Dialect/Tensor/IR/TensorOps.cpp14
-rw-r--r--mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp2
-rw-r--r--mlir/lib/Dialect/Vector/IR/VectorOps.cpp4
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp2
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp14
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp4
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp2
-rw-r--r--mlir/lib/IR/PatternMatch.cpp4
-rw-r--r--mlir/lib/Transforms/Mem2Reg.cpp4
-rw-r--r--mlir/lib/Transforms/Utils/CommutativityUtils.cpp2
-rw-r--r--mlir/lib/Transforms/Utils/DialectConversion.cpp10
-rw-r--r--mlir/test/lib/Dialect/Test/TestPatterns.cpp30
72 files changed, 216 insertions, 216 deletions
diff --git a/mlir/docs/PatternRewriter.md b/mlir/docs/PatternRewriter.md
index 8fe5ef3..011cd14 100644
--- a/mlir/docs/PatternRewriter.md
+++ b/mlir/docs/PatternRewriter.md
@@ -213,15 +213,15 @@ user is determined by the specific pattern driver.
This method replaces an operation's results with a set of provided values, and
erases the operation.
-* Update an Operation in-place : `(start|cancel|finalize)RootUpdate`
+* Update an Operation in-place : `(start|cancel|finalize)OpModification`
This is a collection of methods that provide a transaction-like API for updating
the attributes, location, operands, or successors of an operation in-place
within a pattern. An in-place update transaction is started with
-`startRootUpdate`, and may either be canceled or finalized with
-`cancelRootUpdate` and `finalizeRootUpdate` respectively. A convenience wrapper,
-`updateRootInPlace`, is provided that wraps a `start` and `finalize` around a
-callback.
+`startOpModification`, and may either be canceled or finalized with
+`cancelOpModification` and `finalizeOpModification` respectively. A convenience
+wrapper, `modifyOpInPlace`, is provided that wraps a `start` and `finalize`
+around a callback.
* OpBuilder API
diff --git a/mlir/examples/standalone/lib/Standalone/StandalonePasses.cpp b/mlir/examples/standalone/lib/Standalone/StandalonePasses.cpp
index d438cb4..a23d042 100644
--- a/mlir/examples/standalone/lib/Standalone/StandalonePasses.cpp
+++ b/mlir/examples/standalone/lib/Standalone/StandalonePasses.cpp
@@ -24,7 +24,7 @@ public:
LogicalResult matchAndRewrite(func::FuncOp op,
PatternRewriter &rewriter) const final {
if (op.getSymName() == "bar") {
- rewriter.updateRootInPlace(op, [&op]() { op.setSymName("foo"); });
+ rewriter.modifyOpInPlace(op, [&op]() { op.setSymName("foo"); });
return success();
}
return failure();
diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
index 240b9f9..ae4bd98 100644
--- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
@@ -260,8 +260,8 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
ConversionPatternRewriter &rewriter) const final {
// We don't lower "toy.print" in this pass, but we need to update its
// operands.
- rewriter.updateRootInPlace(op,
- [&] { op->setOperands(adaptor.getOperands()); });
+ rewriter.modifyOpInPlace(op,
+ [&] { op->setOperands(adaptor.getOperands()); });
return success();
}
};
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
index 240b9f9..ae4bd98 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
@@ -260,8 +260,8 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
ConversionPatternRewriter &rewriter) const final {
// We don't lower "toy.print" in this pass, but we need to update its
// operands.
- rewriter.updateRootInPlace(op,
- [&] { op->setOperands(adaptor.getOperands()); });
+ rewriter.modifyOpInPlace(op,
+ [&] { op->setOperands(adaptor.getOperands()); });
return success();
}
};
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
index 240b9f9..ae4bd98 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
@@ -260,8 +260,8 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
ConversionPatternRewriter &rewriter) const final {
// We don't lower "toy.print" in this pass, but we need to update its
// operands.
- rewriter.updateRootInPlace(op,
- [&] { op->setOperands(adaptor.getOperands()); });
+ rewriter.modifyOpInPlace(op,
+ [&] { op->setOperands(adaptor.getOperands()); });
return success();
}
};
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 9b4fa65..b065d4e 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -585,28 +585,30 @@ public:
/// This method is used to notify the rewriter that an in-place operation
/// modification is about to happen. A call to this function *must* be
- /// followed by a call to either `finalizeRootUpdate` or `cancelRootUpdate`.
- /// This is a minor efficiency win (it avoids creating a new operation and
- /// removing the old one) but also often allows simpler code in the client.
- virtual void startRootUpdate(Operation *op) {}
-
- /// This method is used to signal the end of a root update on the given
- /// operation. This can only be called on operations that were provided to a
- /// call to `startRootUpdate`.
- virtual void finalizeRootUpdate(Operation *op);
-
- /// This method cancels a pending root update. This can only be called on
- /// operations that were provided to a call to `startRootUpdate`.
- virtual void cancelRootUpdate(Operation *op) {}
-
- /// This method is a utility wrapper around a root update of an operation. It
- /// wraps calls to `startRootUpdate` and `finalizeRootUpdate` around the given
- /// callable.
+ /// followed by a call to either `finalizeOpModification` or
+ /// `cancelOpModification`. This is a minor efficiency win (it avoids creating
+ /// a new operation and removing the old one) but also often allows simpler
+ /// code in the client.
+ virtual void startOpModification(Operation *op) {}
+
+ /// This method is used to signal the end of an in-place modification of the
+ /// given operation. This can only be called on operations that were provided
+ /// to a call to `startOpModification`.
+ virtual void finalizeOpModification(Operation *op);
+
+ /// This method cancels a pending in-place modification. This can only be
+ /// called on operations that were provided to a call to
+ /// `startOpModification`.
+ virtual void cancelOpModification(Operation *op) {}
+
+ /// This method is a utility wrapper around an in-place modification of an
+ /// operation. It wraps calls to `startOpModification` and
+ /// `finalizeOpModification` around the given callable.
template <typename CallableT>
- void updateRootInPlace(Operation *root, CallableT &&callable) {
- startRootUpdate(root);
+ void modifyOpInPlace(Operation *root, CallableT &&callable) {
+ startOpModification(root);
callable();
- finalizeRootUpdate(root);
+ finalizeOpModification(root);
}
/// Find uses of `from` and replace them with `to`. It also marks every
@@ -619,7 +621,7 @@ public:
void replaceAllUsesWith(IRObjectWithUseList<OperandType> *from, ValueT &&to) {
for (OperandType &operand : llvm::make_early_inc_range(from->getUses())) {
Operation *op = operand.getOwner();
- updateRootInPlace(op, [&]() { operand.set(to); });
+ modifyOpInPlace(op, [&]() { operand.set(to); });
}
}
void replaceAllUsesWith(ValueRange from, ValueRange to) {
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index c5725e9..9568540 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -739,17 +739,17 @@ public:
/// PatternRewriter hook for inserting a new operation.
void notifyOperationInserted(Operation *op) override;
- /// PatternRewriter hook for updating the root operation in-place.
- /// Note: These methods only track updates to the top-level operation itself,
+ /// PatternRewriter hook for updating the given operation in-place.
+ /// Note: These methods only track updates to the given operation itself,
/// and not nested regions. Updates to regions will still require notification
/// through other more specific hooks above.
- void startRootUpdate(Operation *op) override;
+ void startOpModification(Operation *op) override;
- /// PatternRewriter hook for updating the root operation in-place.
- void finalizeRootUpdate(Operation *op) override;
+ /// PatternRewriter hook for updating the given operation in-place.
+ void finalizeOpModification(Operation *op) override;
- /// PatternRewriter hook for updating the root operation in-place.
- void cancelRootUpdate(Operation *op) override;
+ /// PatternRewriter hook for updating the given operation in-place.
+ void cancelOpModification(Operation *op) override;
/// PatternRewriter hook for notifying match failure reasons.
LogicalResult
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 16214d7..bbef3b9 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -255,7 +255,7 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
// Step 2. Assign the op a real tile ID.
// For simplicity, we always use tile 0 (which always exists).
auto zeroTileId = rewriter.getI32IntegerAttr(0);
- rewriter.updateRootInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); });
+ rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); });
VectorType tileVectorType = tileOp.getTileType();
auto sliceType = VectorType::Builder(tileVectorType).dropDim(0);
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 94df376..f853d5c 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -918,8 +918,7 @@ LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
for (auto stream : streams)
streamDestroyCallBuilder.create(loc, rewriter, {stream});
- rewriter.updateRootInPlace(yieldOp,
- [&] { yieldOp->setOperands(newOperands); });
+ rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); });
return success();
}
diff --git a/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp b/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp
index 8c1a7d9..54e6bec 100644
--- a/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp
+++ b/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp
@@ -43,14 +43,13 @@ class ExpandIfCondition : public OpRewritePattern<OpTy> {
if (!matchPattern(op.getIfCond(), m_Constant(&constAttr))) {
auto ifOp = rewriter.create<scf::IfOp>(op.getLoc(), TypeRange(),
op.getIfCond(), false);
- rewriter.updateRootInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
+ rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
auto thenBodyBuilder = ifOp.getThenBodyBuilder(rewriter.getListener());
thenBodyBuilder.clone(*op.getOperation());
rewriter.eraseOp(op);
} else {
if (constAttr.getInt())
- rewriter.updateRootInPlace(op,
- [&]() { op.getIfCondMutable().erase(0); });
+ rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
else
rewriter.eraseOp(op);
}
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 44fbac1..f8485e0 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -645,13 +645,13 @@ struct PrepareTransferWriteConversion
rewriter.create<memref::StoreOp>(loc, xferOp.getVector(),
buffers.dataBuffer);
auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
- rewriter.updateRootInPlace(xferOp, [&]() {
+ rewriter.modifyOpInPlace(xferOp, [&]() {
xferOp.getVectorMutable().assign(loadedVec);
xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
});
if (xferOp.getMask()) {
- rewriter.updateRootInPlace(xferOp, [&]() {
+ rewriter.modifyOpInPlace(xferOp, [&]() {
xferOp.getMaskMutable().assign(buffers.maskBuffer);
});
}
@@ -966,7 +966,7 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
loadIndices, iv);
auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
loadIndices);
- rewriter.updateRootInPlace(newXfer, [&]() {
+ rewriter.modifyOpInPlace(newXfer, [&]() {
newXfer.getMaskMutable().assign(mask);
});
}
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index d5be2e9..c260e68 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2493,7 +2493,7 @@ FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
newYieldValuesFn(rewriter, getLoc(), newIterArgs);
assert(newInitOperands.size() == newYieldedValues.size() &&
"expected as many new yield values as new iter operands");
- rewriter.updateRootInPlace(yieldOp, [&]() {
+ rewriter.modifyOpInPlace(yieldOp, [&]() {
yieldOp.getOperandsMutable().append(newYieldedValues);
});
}
@@ -2686,9 +2686,9 @@ struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> {
!llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
return failure();
- rewriter.startRootUpdate(ifOp);
+ rewriter.startOpModification(ifOp);
rewriter.eraseBlock(ifOp.getElseBlock());
- rewriter.finalizeRootUpdate(ifOp);
+ rewriter.finalizeOpModification(ifOp);
return success();
}
};
diff --git a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
index e5501e8..f28fb3a 100644
--- a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
@@ -71,10 +71,10 @@ void mlir::affine::reorderOperandsByHoistability(RewriterBase &rewriter,
op->getContext());
canonicalizeMapAndOperands(&map, &operands);
- rewriter.startRootUpdate(op);
+ rewriter.startOpModification(op);
op.setMap(map);
op->setOperands(operands);
- rewriter.finalizeRootUpdate(op);
+ rewriter.finalizeOpModification(op);
}
/// Build an affine.apply that is a subexpression `expr` of `originalOp`s affine
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index 4d49efe..4acb2a8 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -218,7 +218,7 @@ struct AssignTileIDsPattern
return defaultVal;
};
auto setDiscardableIntAttr = [&](StringRef name, auto value) {
- rewriter.updateRootInPlace(tileOp, [&] {
+ rewriter.modifyOpInPlace(tileOp, [&] {
func->setDiscardableAttr(name,
rewriter.getI32IntegerAttr((unsigned)value));
});
@@ -274,10 +274,10 @@ struct AssignTileIDsPattern
setDiscardableIntAttr(kTilesInUseAttr, tilesInUse);
else
setDiscardableIntAttr(kNextInMemoryTileIdAttr, *tileId + 1);
- rewriter.updateRootInPlace(tileOp, [&] { tileOp.setTileId(tileIDAttr); });
+ rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIDAttr); });
for (auto *op : dependantOps) {
if (auto dependantTileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
dependantTileOp, [&] { dependantTileOp.setTileId(tileIDAttr); });
}
}
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index 92278c0..32c87c1 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -30,8 +30,8 @@ class ForwardOperands : public OpConversionPattern<OpTy> {
if (adaptor.getOperands().getTypes() == op->getOperands().getTypes())
return rewriter.notifyMatchFailure(op, "operand types already match");
- rewriter.updateRootInPlace(
- op, [&]() { op->setOperands(adaptor.getOperands()); });
+ rewriter.modifyOpInPlace(op,
+ [&]() { op->setOperands(adaptor.getOperands()); });
return success();
}
};
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
index bf627d9..8b4bacd 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
@@ -106,8 +106,8 @@ struct RelaxScalableVectorAllocaAlignment
// Set alignment based on the defaults for SVE vectors and predicates.
unsigned aligment = vectorType.getElementType().isInteger(1) ? 2 : 16;
- rewriter.updateRootInPlace(allocaOp,
- [&] { allocaOp.setAlignment(aligment); });
+ rewriter.modifyOpInPlace(allocaOp,
+ [&] { allocaOp.setAlignment(aligment); });
return success();
}
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index a0bb8715..4b1dfee 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -253,7 +253,7 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
copiedOpOperands.contains(opOperand));
if (failed(copy))
return failure();
- rewriter.updateRootInPlace(op, [&]() { opOperand->set(*copy); });
+ rewriter.modifyOpInPlace(op, [&]() { opOperand->set(*copy); });
}
// Insert copies of Values.
@@ -274,7 +274,7 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
// dynamic extents. Do not update these either.
if (isa<tensor::DimOp>(use->getOwner()))
continue;
- rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(*copy); });
+ rewriter.modifyOpInPlace(use->getOwner(), [&]() { use->set(*copy); });
}
}
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 94bc2bc..253fcf2 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -895,7 +895,7 @@ static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
deallocOp.getConditions() == conditions)
return failure();
- rewriter.updateRootInPlace(deallocOp, [&]() {
+ rewriter.modifyOpInPlace(deallocOp, [&]() {
deallocOp.getMemrefsMutable().assign(memrefs);
deallocOp.getConditionsMutable().assign(conditions);
});
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index 4265351..75d6519 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -42,7 +42,7 @@ static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
deallocOp.getConditions() == conditions)
return failure();
- rewriter.updateRootInPlace(deallocOp, [&]() {
+ rewriter.modifyOpInPlace(deallocOp, [&]() {
deallocOp.getMemrefsMutable().assign(memrefs);
deallocOp.getConditionsMutable().assign(conditions);
});
diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index 999c04e..d242d75 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -403,8 +403,8 @@ struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
constantTrue = rewriter.create<arith::ConstantOp>(
condbr.getLoc(), ty, rewriter.getBoolAttr(true));
- rewriter.updateRootInPlace(use.getOwner(),
- [&] { use.set(constantTrue); });
+ rewriter.modifyOpInPlace(use.getOwner(),
+ [&] { use.set(constantTrue); });
}
}
}
@@ -418,8 +418,8 @@ struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
constantFalse = rewriter.create<arith::ConstantOp>(
condbr.getLoc(), ty, rewriter.getBoolAttr(false));
- rewriter.updateRootInPlace(use.getOwner(),
- [&] { use.set(constantFalse); });
+ rewriter.modifyOpInPlace(use.getOwner(),
+ [&] { use.set(constantFalse); });
}
}
}
diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
index 98ae826..fa030cb 100644
--- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
@@ -86,7 +86,7 @@ struct DecomposeCallGraphTypesForFuncArgs
if (failed(typeConverter->convertTypes(functionType.getResults(),
newResultTypes)))
return failure();
- rewriter.updateRootInPlace(op, [&] {
+ rewriter.modifyOpInPlace(op, [&] {
op.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
newResultTypes));
});
diff --git a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
index 742830e..d1f3b56 100644
--- a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
@@ -84,7 +84,7 @@ public:
newOperands[idx] = operands[idx];
}
}
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
op, [newOperands, op]() { op->setOperands(newOperands); });
return success();
}
@@ -107,8 +107,8 @@ public:
ConversionPatternRewriter &rewriter) const final {
// For a return, all operands go to the results of the parent, so
// rewrite them all.
- rewriter.updateRootInPlace(op,
- [&] { op->setOperands(adaptor.getOperands()); });
+ rewriter.modifyOpInPlace(op,
+ [&] { op->setOperands(adaptor.getOperands()); });
return success();
}
};
diff --git a/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp
index 7005693..c04986c 100644
--- a/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp
@@ -80,7 +80,7 @@ public:
auto newType = FunctionType::get(rewriter.getContext(),
argumentMapping.getConvertedTypes(),
funcResultMapping.getConvertedTypes());
- rewriter.updateRootInPlace(op, [&] { op.setType(newType); });
+ rewriter.modifyOpInPlace(op, [&] { op.setType(newType); });
// Update block signatures.
if (!op.isExternal()) {
@@ -105,7 +105,7 @@ public:
return failure();
// Convert operands.
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
op, [&] { op->setOperands(adaptor.getFlatOperands()); });
return success();
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 514b3e9..30b6cd7 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2030,7 +2030,7 @@ public:
continue;
validOperands.push_back(operand);
}
- rewriter.updateRootInPlace(op, [&]() { op->setOperands(validOperands); });
+ rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
return success();
}
};
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 96a0ef5..bf24194 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -301,7 +301,7 @@ DeletionKind LLVM::DbgValueOp::removeBlockingUses(
// the variable has been optimized out.
auto undef =
rewriter.create<UndefOp>(getValue().getLoc(), getValue().getType());
- rewriter.updateRootInPlace(*this, [&] { getValueMutable().assign(undef); });
+ rewriter.modifyOpInPlace(*this, [&] { getValueMutable().assign(undef); });
return DeletionKind::Keep;
}
@@ -394,7 +394,7 @@ DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot,
return DeletionKind::Delete;
}
- rewriter.updateRootInPlace(*this, [&]() {
+ rewriter.modifyOpInPlace(*this, [&]() {
// Rewire the indices by popping off the second index.
// Start with a single zero, then add the indices beyond the second.
SmallVector<int32_t> newIndices(1);
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
index cf900ac..72f9295 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
@@ -83,8 +83,8 @@ static void insertFieldIndirection(MemOp op, PatternRewriter &rewriter,
op->getLoc(), LLVM::LLVMPointerType::get(op.getContext()), elemType,
op.getAddr(), firstTypeIndices);
- rewriter.updateRootInPlace(op,
- [&]() { op.getAddrMutable().assign(properPtr); });
+ rewriter.modifyOpInPlace(op,
+ [&]() { op.getAddrMutable().assign(properPtr); });
}
template <>
@@ -111,8 +111,8 @@ LogicalResult AddFieldGetterToStructDirectUse<LoadOp>::matchAndRewrite(
rewriter.setInsertionPointAfterValue(load.getResult());
BitcastOp bitcast = rewriter.create<BitcastOp>(
load->getLoc(), load.getResult().getType(), load.getResult());
- rewriter.updateRootInPlace(load,
- [&]() { load.getResult().setType(firstType); });
+ rewriter.modifyOpInPlace(load,
+ [&]() { load.getResult().setType(firstType); });
rewriter.replaceAllUsesExcept(load.getResult(), bitcast.getResult(),
bitcast);
}
@@ -141,7 +141,7 @@ LogicalResult AddFieldGetterToStructDirectUse<StoreOp>::matchAndRewrite(
insertFieldIndirection<StoreOp>(store, rewriter, inconsistentElementType);
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
store, [&]() { store.getValueMutable().assign(store.getValue()); });
return success();
@@ -630,8 +630,8 @@ LogicalResult BitcastStores::matchAndRewrite(StoreOp store,
auto bitcastOp =
rewriter.create<BitcastOp>(store.getLoc(), typeHint, store.getValue());
- rewriter.updateRootInPlace(
- store, [&] { store.getValueMutable().assign(bitcastOp); });
+ rewriter.modifyOpInPlace(store,
+ [&] { store.getValueMutable().assign(bitcastOp); });
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 139566d..f7cfe8a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -785,7 +785,7 @@ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
// Replace the use in containingOp.
- rewriter.updateRootInPlace(containingOp, [&]() {
+ rewriter.modifyOpInPlace(containingOp, [&]() {
containingOp->setOperand(pUse->getOperandNumber(),
destinationTensors.front());
});
@@ -835,7 +835,7 @@ static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(use->getOwner());
fusedOp = rewriter.clone(*producerOp);
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
return fusedOp;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index d8df5d8..ff13aaf 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -311,7 +311,7 @@ Value linalg::bufferizeToAllocation(
auto toTensorOp =
resultUse->get().getDefiningOp<bufferization::ToTensorOp>();
assert(toTensorOp && "expected to_tensor op");
- rewriter.updateRootInPlace(toTensorOp, [&]() {
+ rewriter.modifyOpInPlace(toTensorOp, [&]() {
toTensorOp.setRestrict(true);
toTensorOp.setWritable(true);
});
@@ -559,11 +559,11 @@ Value linalg::bufferizeToAllocation(
// tensor is uninitialized.
createMemcpy(rewriter, op->getLoc(), operand->get(), alloc, options);
}
- rewriter.updateRootInPlace(op, [&]() {
+ rewriter.modifyOpInPlace(op, [&]() {
auto toTensorOp = rewriter.create<ToTensorOp>(op->getLoc(), alloc);
operand->set(toTensorOp);
if (options.bufferizeDestinationOnly) {
- rewriter.updateRootInPlace(toTensorOp, [&]() {
+ rewriter.modifyOpInPlace(toTensorOp, [&]() {
toTensorOp.setRestrict(true);
toTensorOp.setWritable(true);
});
@@ -584,7 +584,7 @@ Value linalg::bufferizeToAllocation(
for (OpOperand *resultUse : resultUses) {
auto toTensorOp = resultUse->get().getDefiningOp<ToTensorOp>();
assert(toTensorOp && "expected to_tensor op");
- rewriter.updateRootInPlace(toTensorOp, [&]() {
+ rewriter.modifyOpInPlace(toTensorOp, [&]() {
toTensorOp.setRestrict(true);
toTensorOp.setWritable(true);
});
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index bf91a70..98cd044 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -104,7 +104,7 @@ struct FunctionNonEntryBlockConversion
LogicalResult
matchAndRewrite(FunctionOpInterface op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- rewriter.startRootUpdate(op);
+ rewriter.startOpModification(op);
Region &region = op.getFunctionBody();
SmallVector<TypeConverter::SignatureConversion, 2> conversions;
@@ -125,11 +125,11 @@ struct FunctionNonEntryBlockConversion
if (failed(rewriter.convertNonEntryRegionTypes(&region, *typeConverter,
conversions))) {
- rewriter.cancelRootUpdate(op);
+ rewriter.cancelOpModification(op);
return failure();
}
- rewriter.finalizeRootUpdate(op);
+ rewriter.finalizeOpModification(op);
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 031f5c7..e4cb2f2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1816,7 +1816,7 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
LogicalResult matchAndRewrite(GenericOp op,
PatternRewriter &rewriter) const override {
- rewriter.startRootUpdate(op);
+ rewriter.startOpModification(op);
bool modifiedOutput = false;
Location loc = op.getLoc();
for (OpOperand &opOperand : op.getDpsInitsMutable()) {
@@ -1843,10 +1843,10 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
}
}
if (!modifiedOutput) {
- rewriter.cancelRootUpdate(op);
+ rewriter.cancelOpModification(op);
return failure();
}
- rewriter.finalizeRootUpdate(op);
+ rewriter.finalizeOpModification(op);
return success();
}
};
diff --git a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
index f28f8f0..81669a1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
@@ -87,7 +87,7 @@ LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep(
}
// Turn the "in" into an "out".
- rewriter.updateRootInPlace(op, [&]() {
+ rewriter.modifyOpInPlace(op, [&]() {
out->set(in->get());
// The original "in" could be removed entirely here (because it will no
// longer have any uses in the payload), but we delegate this to
diff --git a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
index 3378eda..16ab45e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
@@ -354,7 +354,7 @@ struct RemoveUnusedCycleInGenericOp : public OpRewritePattern<GenericOp> {
// Directly replace the cycle with the blockArg such that
// Deduplicate pattern can eliminate it along with unused yield.
rewriter.replaceOp(cycleOp, outputArg);
- rewriter.updateRootInPlace(genericOp, [] {});
+ rewriter.modifyOpInPlace(genericOp, [] {});
hasRemovedCycles = true;
}
@@ -404,7 +404,7 @@ struct FoldDuplicateInputBbArgs : public OpRewritePattern<GenericOp> {
return failure();
// Rewrite the op.
- rewriter.updateRootInPlace(genericOp, [&]() {
+ rewriter.modifyOpInPlace(genericOp, [&]() {
for (auto [before, after] : replacements) {
BlockArgument bbArg = genericOp.getBody()->getArgument(before);
BlockArgument replacement = genericOp.getBody()->getArgument(after);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index 805c9d4..b32ea8e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -854,10 +854,10 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
LLVM_DEBUG(DBGS() << "with result #"
<< numOriginalForOpResults + iterArgNumber
<< " of forOp, giving us: " << extracted << "\n");
- rewriter.startRootUpdate(extracted);
+ rewriter.startOpModification(extracted);
extracted.getSourceMutable().assign(
newForOp.getResult(numOriginalForOpResults + iterArgNumber));
- rewriter.finalizeRootUpdate(extracted);
+ rewriter.finalizeOpModification(extracted);
LLVM_DEBUG(DBGS() << "replace uses of: " << paddedValueBeforeHoisting
<< "\n");
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
index f46ba71..a0faeb5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
@@ -60,9 +60,9 @@ mlir::linalg::interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp,
assert(permutationMap && "unexpected null map");
// Start a guarded inplace update.
- rewriter.startRootUpdate(genericOp);
- auto guard =
- llvm::make_scope_exit([&]() { rewriter.finalizeRootUpdate(genericOp); });
+ rewriter.startOpModification(genericOp);
+ auto guard = llvm::make_scope_exit(
+ [&]() { rewriter.finalizeOpModification(genericOp); });
// 2. Compute the interchanged indexing maps.
SmallVector<AffineMap> newIndexingMaps;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
index bbe3a54..0174db4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
@@ -113,7 +113,7 @@ linalg::splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension,
// Need to pretend that the original op now takes as operands firstResults,
// otherwise tiling interface implementation will take the wrong value to
// produce data tiles.
- rewriter.updateRootInPlace(op, [&]() {
+ rewriter.modifyOpInPlace(op, [&]() {
unsigned numTotalOperands = op->getNumOperands();
unsigned numOutputOperands = firstResults.size();
op->setOperands(numTotalOperands - numOutputOperands, numOutputOperands,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 7f3ab1f..ebf80e3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -722,7 +722,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
// We cannot use a IRMapping here because it can replace
// different OpOperands with the same value.
Operation *clonedOp = b.clone(*op.getOperation());
- b.updateRootInPlace(clonedOp, [&]() {
+ b.modifyOpInPlace(clonedOp, [&]() {
for (auto [initOperandPtr, tiledInitValue] : llvm::zip_equal(
cast<DestinationStyleOpInterface>(clonedOp).getDpsInitsMutable(),
tiledDpsInitOperands)) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index dc348ea..0610f24 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1952,7 +1952,7 @@ struct PadOpVectorizationWithTransferReadPattern
if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
return failure();
- rewriter.updateRootInPlace(xferOp, [&]() {
+ rewriter.modifyOpInPlace(xferOp, [&]() {
SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
xferOp->setAttr(xferOp.getInBoundsAttrName(),
rewriter.getBoolArrayAttr(inBounds));
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index be301c1..561b861 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -227,7 +227,7 @@ DeletionKind memref::LoadOp::rewire(const DestructurableMemorySlot &slot,
Attribute index = getAttributeIndexFromIndexOperands(
getContext(), getIndices(), getMemRefType());
const MemorySlot &memorySlot = subslots.at(index);
- rewriter.updateRootInPlace(*this, [&]() {
+ rewriter.modifyOpInPlace(*this, [&]() {
setMemRef(memorySlot.ptr);
getIndicesMutable().clear();
});
@@ -280,7 +280,7 @@ DeletionKind memref::StoreOp::rewire(const DestructurableMemorySlot &slot,
Attribute index = getAttributeIndexFromIndexOperands(
getContext(), getIndices(), getMemRefType());
const MemorySlot &memorySlot = subslots.at(index);
- rewriter.updateRootInPlace(*this, [&]() {
+ rewriter.modifyOpInPlace(*this, [&]() {
setMemRef(memorySlot.ptr);
getIndicesMutable().clear();
});
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 394640f..b79ab8f 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -792,7 +792,7 @@ struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
if (fromType && toType) {
if (fromType.getShape() == toType.getShape() &&
fromType.getElementType() == toType.getElementType()) {
- rewriter.updateRootInPlace(copyOp, [&] {
+ rewriter.modifyOpInPlace(copyOp, [&] {
copyOp.getSourceMutable().assign(castOp.getSource());
});
modified = true;
@@ -808,7 +808,7 @@ struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
if (fromType && toType) {
if (fromType.getShape() == toType.getShape() &&
fromType.getElementType() == toType.getElementType()) {
- rewriter.updateRootInPlace(copyOp, [&] {
+ rewriter.modifyOpInPlace(copyOp, [&] {
copyOp.getTargetMutable().assign(castOp.getSource());
});
modified = true;
@@ -1366,7 +1366,7 @@ static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
loc, llvm::cast<IntegerAttr>(maybeConstant.template get<Attribute>())
.getInt());
for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
- // updateRootInplace: lambda cannot capture structured bindings in C++17
+ // modifyOpInPlace: lambda cannot capture structured bindings in C++17
// yet.
op->replaceUsesOfWith(result, constantVal);
atLeastOneReplacement = true;
@@ -2436,7 +2436,7 @@ public:
op.getReassociationIndices());
if (newResultType == op.getResultType()) {
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
} else {
Value newOp = rewriter.create<CollapseShapeOp>(
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index 101e099..8047c60 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -797,7 +797,7 @@ class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
extractOp.getSource().getDefiningOp<ViewLikeOpInterface>();
if (!viewLikeOp)
return rewriter.notifyMatchFailure(extractOp, "not a ViewLike source");
- rewriter.updateRootInPlace(extractOp, [&]() {
+ rewriter.modifyOpInPlace(extractOp, [&]() {
extractOp.getSourceMutable().assign(viewLikeOp.getViewSource());
});
return success();
diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
index 03765e9..10ba508 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
@@ -154,7 +154,7 @@ static void replaceAndPropagateMemRefType(RewriterBase &rewriter,
for (OpOperand &operand : user->getOpOperands()) {
if ([[maybe_unused]] auto castOp =
operand.get().getDefiningOp<UnrealizedConversionCastOp>()) {
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
user, [&]() { operand.set(conversion->getOperand(0)); });
}
}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
index 397bd585..bc0dd03 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
@@ -79,9 +79,9 @@ static void replaceUsesAndPropagateType(RewriterBase &rewriter,
// TODO: can we use an early_inc iterator?
for (OpOperand *operand : operandsToReplace) {
Operation *op = operand->getOwner();
- rewriter.startRootUpdate(op);
+ rewriter.startOpModification(op);
operand->set(val);
- rewriter.finalizeRootUpdate(op);
+ rewriter.finalizeOpModification(op);
}
// Perform late op erasure.
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
index 8bfb4be..8163f42 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
@@ -54,7 +54,7 @@ struct MmaSyncF32ToTF32Pattern : public OpRewritePattern<nvgpu::MmaSyncOp> {
"for nvgpu.mma.sync on f32 datatype");
if (precision == MmaSyncF32Lowering::TF32) {
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
op, [&]() { op.setTf32EnabledAttr(rewriter.getUnitAttr()); });
}
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index bf3264b..8698c00 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -359,7 +359,7 @@ struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
if (!matchPattern(ifCond, m_Constant(&constAttr)))
return failure();
if (constAttr.getInt())
- rewriter.updateRootInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
+ rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
else
rewriter.eraseOp(op);
@@ -398,7 +398,7 @@ struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
if (!matchPattern(ifCond, m_Constant(&constAttr)))
return failure();
if (constAttr.getInt())
- rewriter.updateRootInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
+ rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
else
replaceOpWithRegion(rewriter, op, op.getRegion());
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index cdc0b6f..45cc747 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -552,7 +552,7 @@ ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
newYieldValuesFn(rewriter, getLoc(), newIterArgs);
assert(newInitOperands.size() == newYieldedValues.size() &&
"expected as many new yield values as new iter operands");
- rewriter.updateRootInPlace(yieldOp, [&]() {
+ rewriter.modifyOpInPlace(yieldOp, [&]() {
yieldOp.getResultsMutable().append(newYieldedValues);
});
}
@@ -1444,7 +1444,7 @@ struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
Value sharedOut =
forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
->get();
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
return success();
}
@@ -1464,7 +1464,7 @@ public:
failed(foldDynamicIndexList(mixedStep)))
return failure();
- rewriter.updateRootInPlace(op, [&]() {
+ rewriter.modifyOpInPlace(op, [&]() {
SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
@@ -1556,7 +1556,7 @@ struct ForallOpSingleOrZeroIterationDimsFolder
for (const auto &namedAttr : op->getAttrs()) {
if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
continue;
- rewriter.updateRootInPlace(newOp, [&]() {
+ rewriter.modifyOpInPlace(newOp, [&]() {
newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
});
}
@@ -2023,8 +2023,8 @@ struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
[&](OpResult result) {
return yieldOp.getOperand(result.getResultNumber());
});
- rewriter.updateRootInPlace(yieldOp,
- [&]() { yieldOp->setOperands(usedOperands); });
+ rewriter.modifyOpInPlace(yieldOp,
+ [&]() { yieldOp->setOperands(usedOperands); });
}
LogicalResult matchAndRewrite(IfOp op,
@@ -2189,8 +2189,8 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {
constantTrue = rewriter.create<arith::ConstantOp>(
op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
- rewriter.updateRootInPlace(use.getOwner(),
- [&]() { use.set(constantTrue); });
+ rewriter.modifyOpInPlace(use.getOwner(),
+ [&]() { use.set(constantTrue); });
} else if (op.getElseRegion().isAncestor(
use.getOwner()->getParentRegion())) {
changed = true;
@@ -2199,8 +2199,8 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {
constantFalse = rewriter.create<arith::ConstantOp>(
op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
- rewriter.updateRootInPlace(use.getOwner(),
- [&]() { use.set(constantFalse); });
+ rewriter.modifyOpInPlace(use.getOwner(),
+ [&]() { use.set(constantFalse); });
}
}
@@ -2383,14 +2383,14 @@ struct CombineIfs : public OpRewritePattern<IfOp> {
llvm::make_early_inc_range(std::get<0>(it).getUses())) {
if (nextThen && nextThen->getParent()->isAncestor(
use.getOwner()->getParentRegion())) {
- rewriter.startRootUpdate(use.getOwner());
+ rewriter.startOpModification(use.getOwner());
use.set(std::get<1>(it));
- rewriter.finalizeRootUpdate(use.getOwner());
+ rewriter.finalizeOpModification(use.getOwner());
} else if (nextElse && nextElse->getParent()->isAncestor(
use.getOwner()->getParentRegion())) {
- rewriter.startRootUpdate(use.getOwner());
+ rewriter.startOpModification(use.getOwner());
use.set(std::get<2>(it));
- rewriter.finalizeRootUpdate(use.getOwner());
+ rewriter.finalizeOpModification(use.getOwner());
}
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index dc3c46b..90f935d 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -688,7 +688,7 @@ struct ForOpInterface
yieldValues.push_back(*alloc);
}
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
return success();
}
@@ -928,7 +928,7 @@ struct WhileOpInterface
return failure();
beforeYieldValues.push_back(*alloc);
}
- rewriter.updateRootInPlace(conditionOp, [&]() {
+ rewriter.modifyOpInPlace(conditionOp, [&]() {
conditionOp.getArgsMutable().assign(beforeYieldValues);
});
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
index 7b6b07e..cda561b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
@@ -89,8 +89,8 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {
SmallVector<Value> yieldOperands = yieldOp.getOperands();
yieldOperands.insert(yieldOperands.begin(), ivIncOp.getResult());
- rewriter.updateRootInPlace(
- yieldOp, [&]() { yieldOp->setOperands(yieldOperands); });
+ rewriter.modifyOpInPlace(yieldOp,
+ [&]() { yieldOp->setOperands(yieldOperands); });
}
// We cannot do a direct replacement of the forOp since the while op returns
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
index eee0791..c6d024c 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
@@ -99,7 +99,7 @@ struct DimOfIterArgFolder : public OpRewritePattern<OpTy> {
return failure();
Value initArg = forOp.getTiedLoopInit(blockArg)->get();
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
return success();
@@ -141,7 +141,7 @@ struct DimOfLoopResultFolder : public OpRewritePattern<OpTy> {
unsigned resultNumber = opResult.getResultNumber();
if (!isShapePreserving(forOp, resultNumber))
return failure();
- rewriter.updateRootInPlace(dimOp, [&]() {
+ rewriter.modifyOpInPlace(dimOp, [&]() {
dimOp.getSourceMutable().assign(forOp.getInitArgs()[resultNumber]);
});
return success();
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index 3422135..a5bff0a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -160,8 +160,8 @@ static LogicalResult peelForLoop(RewriterBase &b, ForOp forOp,
partialIteration.getInitArgsMutable().assign(forOp->getResults());
// Set new upper loop bound.
- b.updateRootInPlace(
- forOp, [&]() { forOp.getUpperBoundMutable().assign(splitBound); });
+ b.modifyOpInPlace(forOp,
+ [&]() { forOp.getUpperBoundMutable().assign(splitBound); });
return success();
}
@@ -239,7 +239,7 @@ LogicalResult mlir::scf::peelForLoopFirstIteration(RewriterBase &b, ForOp forOp,
firstIteration = cast<ForOp>(b.clone(*forOp.getOperation(), map));
// Update main loop with new lower bound.
- b.updateRootInPlace(forOp, [&]() {
+ b.modifyOpInPlace(forOp, [&]() {
forOp.getInitArgsMutable().assign(firstIteration->getResults());
forOp.getLowerBoundMutable().assign(splitBound);
});
@@ -286,11 +286,11 @@ struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> {
}
// Apply label, so that the same loop is not rewritten a second time.
- rewriter.updateRootInPlace(partialIteration, [&]() {
+ rewriter.modifyOpInPlace(partialIteration, [&]() {
partialIteration->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
partialIteration->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
});
- rewriter.updateRootInPlace(forOp, [&]() {
+ rewriter.modifyOpInPlace(forOp, [&]() {
forOp->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
});
return success();
diff --git a/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp b/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp
index 8c2c544..5aa35e7 100644
--- a/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp
@@ -111,7 +111,7 @@ public:
return failure();
// Convert operands.
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
op, [&] { op->setOperands(adaptor.getFlatOperands()); });
return success();
@@ -131,7 +131,7 @@ public:
return failure();
// Convert operands.
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
op, [&] { op->setOperands(adaptor.getFlatOperands()); });
return success();
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index 7932c38..e2cc5b4 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -241,7 +241,7 @@ public:
for (Value operand : adaptor.getOperands())
unpackUnrealizedConversionCast(operand, unpackedYield);
- rewriter.updateRootInPlace(op, [&]() { op->setOperands(unpackedYield); });
+ rewriter.modifyOpInPlace(op, [&]() { op->setOperands(unpackedYield); });
return success();
}
};
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 38e0625..5c9b528 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -692,7 +692,7 @@ void mlir::scf::yieldReplacementForFusedProducer(
sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(),
sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
unsigned resultNumber = fusableProducer.getResultNumber();
- rewriter.updateRootInPlace(tiledDestStyleOp, [&]() {
+ rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
});
}
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
index c22cb67..354db64 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
@@ -91,8 +91,8 @@ public:
LogicalResult
matchAndRewrite(OpT op, typename OpT::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- rewriter.updateRootInPlace(op,
- [&] { op->setOperands(adaptor.getOperands()); });
+ rewriter.modifyOpInPlace(op,
+ [&] { op->setOperands(adaptor.getOperands()); });
return success();
}
};
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index 9f2755da0..6150b5e 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -261,7 +261,7 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
return failure();
// Creates a new function with the update signature.
- rewriter.updateRootInPlace(funcOp, [&] {
+ rewriter.modifyOpInPlace(funcOp, [&] {
funcOp.setType(rewriter.getFunctionType(
signatureConverter.getConvertedTypes(), std::nullopt));
});
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
index c8e77f7..d33eb9d 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
@@ -29,7 +29,7 @@ sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
// Clones the original operation but changing the output to an unordered COO.
Operation *cloned = rewriter.clone(*op.getOperation());
- rewriter.updateRootInPlace(cloned, [cloned, srcCOOTp]() {
+ rewriter.modifyOpInPlace(cloned, [cloned, srcCOOTp]() {
cloned->getOpResult(0).setType(srcCOOTp);
});
Value srcCOO = cloned->getOpResult(0);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 50713be..a0f7b55 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -389,14 +389,14 @@ public:
auto stt = tryGetSparseTensorType(res);
auto [idxMap, itTp] = *transMap;
- rewriter.startRootUpdate(linalgOp);
+ rewriter.startOpModification(linalgOp);
linalgOp.setIndexingMapsAttr(idxMap);
linalgOp.setIteratorTypesAttr(itTp);
// Use demapped arguments.
linalgOp.getInputsMutable().assign(adaptor.getInputs());
linalgOp.getDpsInitsMutable().assign(adaptor.getOutputs());
res.setType(adaptor.getOutputs()[0].getType());
- rewriter.finalizeRootUpdate(linalgOp);
+ rewriter.finalizeOpModification(linalgOp);
rewriter.setInsertionPointAfter(linalgOp);
if (stt && stt->hasEncoding()) {
@@ -458,7 +458,7 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
}
// Marks the GenericOp to avoid recursive matching.
- rewriter.updateRootInPlace(linalgOp, [&]() {
+ rewriter.modifyOpInPlace(linalgOp, [&]() {
linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
});
@@ -482,10 +482,10 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
for (AffineMap &idxMap : idxMaps)
idxMap = idxMap.compose(order); // sorted loop -> lvl map
- rewriter.startRootUpdate(linalgOp);
+ rewriter.startOpModification(linalgOp);
linalgOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(idxMaps));
linalgOp.setIteratorTypesAttr(rewriter.getArrayAttr(curItTypes));
- rewriter.finalizeRootUpdate(linalgOp);
+ rewriter.finalizeOpModification(linalgOp);
return success();
}
@@ -570,7 +570,7 @@ private:
rewriter.setInsertionPoint(linalgOp);
RankedTensorType dstTp = stt.withDimToLvl(dimToLvl).getRankedTensorType();
Value dst = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
- rewriter.updateRootInPlace(linalgOp, [&]() {
+ rewriter.modifyOpInPlace(linalgOp, [&]() {
linalgOp->setOperand(t->getOperandNumber(), dst);
});
return success();
@@ -623,10 +623,10 @@ struct TensorAllocDemapper : public OpRewritePattern<AllocOp> {
}
assert(dynSz.empty()); // should have consumed all.
- rewriter.startRootUpdate(op);
+ rewriter.startOpModification(op);
op->setOperands(dynLvlSzs);
op.getResult().setType(stt.getDemappedType());
- rewriter.finalizeRootUpdate(op);
+ rewriter.finalizeOpModification(op);
rewriter.setInsertionPointAfter(op);
Value t = genRemap(rewriter, stt.getEncoding(), op.getResult());
@@ -676,7 +676,7 @@ struct ForeachOpDemapper
auto srcStt = getSparseTensorType(op.getTensor());
SmallVector<Type> prevRetTps(op.getResultTypes());
- rewriter.startRootUpdate(op);
+ rewriter.startOpModification(op);
op.getTensorMutable().assign(adaptor.getTensor());
op.getInitArgsMutable().assign(adaptor.getInitArgs());
// Update results' types.
@@ -731,7 +731,7 @@ struct ForeachOpDemapper
rewriter.eraseOp(yield);
}
}
- rewriter.finalizeRootUpdate(op);
+ rewriter.finalizeOpModification(op);
rewriter.setInsertionPointAfter(op);
SmallVector<Value> outs =
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index fa97e405..b1b8b76 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -329,7 +329,7 @@ public:
.getCopy();
AllocTensorOp a =
op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>();
- rewriter.updateRootInPlace(a, [&]() { a.getCopyMutable().assign(init); });
+ rewriter.modifyOpInPlace(a, [&]() { a.getCopyMutable().assign(init); });
}
// Replace consumer with fused operation. Old producer
// and consumer ops will be removed by DCE.
@@ -366,7 +366,7 @@ public:
if (tensor::isSameTypeWithoutEncoding(srcType, dstType)) {
if (Operation *def = op.getSource().getDefiningOp()) {
if (def->hasOneUse() && isa<tensor::ExtractSliceOp>(def)) {
- rewriter.updateRootInPlace(def, [&]() {
+ rewriter.modifyOpInPlace(def, [&]() {
def->getResult(0).setType(op->getResultTypes()[0]);
});
rewriter.replaceOp(op, def->getResult(0));
@@ -804,7 +804,7 @@ public:
auto denseTp =
RankedTensorType::get(rtp.getShape(), rtp.getElementType());
auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
- rewriter.updateRootInPlace(op, [&]() { op->setOperand(0, convert); });
+ rewriter.modifyOpInPlace(op, [&]() { op->setOperand(0, convert); });
return success();
}
if (encDst) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index 7710a44..3a487a3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -545,7 +545,7 @@ static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()));
rewriter.setInsertionPointToStart(forOpNew.getBody());
} else {
- rewriter.updateRootInPlace(forOp, [&]() { forOp.setStep(step); });
+ rewriter.modifyOpInPlace(forOp, [&]() { forOp.setStep(step); });
rewriter.setInsertionPoint(yield);
}
vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 5834426..fec23d2 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -583,7 +583,7 @@ static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
if (def->getBlock() == block) {
rewriter.setInsertionPoint(def);
for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
- rewriter.updateRootInPlace(def, [&]() {
+ rewriter.modifyOpInPlace(def, [&]() {
def->setOperand(
i, relinkBranch(env, rewriter, block, def->getOperand(i)));
});
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index 80dad06..3d8cc52 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -1416,7 +1416,7 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
Operation *newRed = rewriter.clone(*redExp);
// Replaces arguments of the reduction expression by using the block
// arguments from scf.reduce.
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
newRed, [&]() { newRed->setOperands(redBlock->getArguments()); });
// Erases the out-dated reduction expression.
rewriter.eraseOp(redExp);
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 816e6ba..b2fe580 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -819,7 +819,7 @@ struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
auto resultIndex = source.cast<OpResult>().getResultNumber();
auto initOperand = destOp.getDpsInitOperand(resultIndex);
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
return success();
}
@@ -1752,7 +1752,7 @@ struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
srcType, collapseShapeOp.getReassociationMaps());
if (newResultType == collapseShapeOp.getResultType()) {
- rewriter.updateRootInPlace(collapseShapeOp, [&]() {
+ rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
collapseShapeOp.getSrcMutable().assign(castOp.getSource());
});
} else {
@@ -2930,7 +2930,7 @@ struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
padTensorOp.getResultType().getShape());
if (newResultType == padTensorOp.getResultType()) {
- rewriter.updateRootInPlace(padTensorOp, [&]() {
+ rewriter.modifyOpInPlace(padTensorOp, [&]() {
padTensorOp.getSourceMutable().assign(castOp.getSource());
});
} else {
@@ -3994,9 +3994,9 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
// Fold optional PaddingValue operand away if padding is not needed.
if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
- rewriter.startRootUpdate(packOp);
+ rewriter.startOpModification(packOp);
packOp.getPaddingValueMutable().clear();
- rewriter.finalizeRootUpdate(packOp);
+ rewriter.finalizeOpModification(packOp);
return success();
}
return failure();
@@ -4166,8 +4166,8 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
auto destValue = unPackOp.getDest().cast<OpResult>();
Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
- rewriter.updateRootInPlace(
- unPackOp, [&]() { unPackOp.setDpsInitOperand(0, newDest); });
+ rewriter.modifyOpInPlace(unPackOp,
+ [&]() { unPackOp.setDpsInitOperand(0, newDest); });
return success();
}
return failure();
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 3f683f7..4c50aae 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -68,7 +68,7 @@ LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
if (!notOp)
return failure();
- rewriter.updateRootInPlace(op, [&]() {
+ rewriter.modifyOpInPlace(op, [&]() {
op.getOperation()->setOperands(
{notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
});
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f257728..749eb56 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4416,7 +4416,7 @@ public:
writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
while (defWrite) {
if (checkSameValueWAW(writeOp, defWrite)) {
- rewriter.updateRootInPlace(writeToModify, [&]() {
+ rewriter.modifyOpInPlace(writeToModify, [&]() {
writeToModify.getSourceMutable().assign(defWrite.getSource());
});
return success();
@@ -4533,7 +4533,7 @@ public:
transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
transferOp.getIndices(), transferOp.getPermutationMapAttr(),
rewriter.getBoolArrayAttr(newInBounds));
- rewriter.updateRootInPlace(insertOp, [&]() {
+ rewriter.modifyOpInPlace(insertOp, [&]() {
insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
});
return success();
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index 5782ee1..1caec5b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -225,7 +225,7 @@ struct MaskOpInterface
newReturnValues[it.index()] = it.value();
}
}
- rewriter.updateRootInPlace(yieldOp, [&]() {
+ rewriter.modifyOpInPlace(yieldOp, [&]() {
yieldOp.getOperandsMutable().assign(newYieldedValues);
});
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 9d5ad20..620ceee 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -182,7 +182,7 @@ static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
auto yield =
cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
return newWarpOp;
}
@@ -724,7 +724,7 @@ struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
return failure();
// Notify the rewriter that the warp op is changing (see the comment on
// the WarpOpTransferRead pattern).
- rewriter.startRootUpdate(warpOp);
+ rewriter.startOpModification(warpOp);
unsigned operandIndex = yieldOperand->getOperandNumber();
Attribute scalarAttr = dense.getSplatValue<Attribute>();
auto newAttr = DenseElementsAttr::get(
@@ -733,7 +733,7 @@ struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
rewriter.setInsertionPointAfter(warpOp);
Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
- rewriter.finalizeRootUpdate(warpOp);
+ rewriter.finalizeOpModification(warpOp);
return success();
}
};
@@ -1017,9 +1017,9 @@ struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
return failure();
// Notify the rewriter that the warp op is changing (see the comment on
// the WarpOpTransferRead pattern).
- rewriter.startRootUpdate(warpOp);
+ rewriter.startOpModification(warpOp);
rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
- rewriter.finalizeRootUpdate(warpOp);
+ rewriter.finalizeOpModification(warpOp);
return success();
}
};
@@ -1159,7 +1159,7 @@ struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
// Notify the rewriter that the warp op is changing (see the comment on
// the WarpOpTransferRead pattern).
- rewriter.startRootUpdate(warpOp);
+ rewriter.startOpModification(warpOp);
AffineExpr s0, s1;
bindSymbols(rewriter.getContext(), s0, s1);
@@ -1179,7 +1179,7 @@ struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
auto newMask =
rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands);
rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
- rewriter.finalizeRootUpdate(warpOp);
+ rewriter.finalizeOpModification(warpOp);
return success();
}
};
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
index ea33453..f1a2716 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
@@ -525,7 +525,7 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
auto inBoundsAttr = b.getBoolArrayAttr(bools);
if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
- b.updateRootInPlace(xferOp, [&]() {
+ b.modifyOpInPlace(xferOp, [&]() {
xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
});
return success();
@@ -598,7 +598,7 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));
- b.updateRootInPlace(xferOp, [&]() {
+ b.modifyOpInPlace(xferOp, [&]() {
xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
});
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 661674d..bd02c07 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1050,7 +1050,7 @@ public:
mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
}
- rewriter.updateRootInPlace(xferOp, [&]() {
+ rewriter.modifyOpInPlace(xferOp, [&]() {
xferOp.getMaskMutable().assign(mask);
xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
});
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 5e788cd..73f232f 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -263,7 +263,7 @@ void RewriterBase::eraseBlock(Block *block) {
block->erase();
}
-void RewriterBase::finalizeRootUpdate(Operation *op) {
+void RewriterBase::finalizeOpModification(Operation *op) {
// Notify the listener that the operation was modified.
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
rewriteListener->notifyOperationModified(op);
@@ -276,7 +276,7 @@ void RewriterBase::replaceUsesWithIf(Value from, Value to,
function_ref<bool(OpOperand &)> functor) {
for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
if (functor(operand))
- updateRootInPlace(operand.getOwner(), [&]() { operand.set(to); });
+ modifyOpInPlace(operand.getOwner(), [&]() { operand.set(to); });
}
}
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 26a7ea5..f3a973d 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -506,7 +506,7 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
if (info.mergePoints.contains(blockOperand.get())) {
if (!job.reachingDef)
job.reachingDef = getLazyDefaultValue();
- rewriter.updateRootInPlace(terminator, [&]() {
+ rewriter.modifyOpInPlace(terminator, [&]() {
terminator.getSuccessorOperands(blockOperand.getOperandNumber())
.append(job.reachingDef);
});
@@ -596,7 +596,7 @@ void MemorySlotPromoter::promoteSlot() {
assert(succOperands.size() == mergePoint->getNumArguments() ||
succOperands.size() + 1 == mergePoint->getNumArguments());
if (succOperands.size() + 1 == mergePoint->getNumArguments())
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
user, [&]() { succOperands.append(getLazyDefaultValue()); });
}
}
diff --git a/mlir/lib/Transforms/Utils/CommutativityUtils.cpp b/mlir/lib/Transforms/Utils/CommutativityUtils.cpp
index 6034366..5ba6e47 100644
--- a/mlir/lib/Transforms/Utils/CommutativityUtils.cpp
+++ b/mlir/lib/Transforms/Utils/CommutativityUtils.cpp
@@ -304,7 +304,7 @@ public:
sortedOperands.push_back(commOperand->operand);
if (sortedOperands == operands)
return failure();
- rewriter.updateRootInPlace(op, [&] { op->setOperands(sortedOperands); });
+ rewriter.modifyOpInPlace(op, [&] { op->setOperands(sortedOperands); });
return success();
}
};
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 85433d0..ef6a494 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1614,15 +1614,15 @@ void ConversionPatternRewriter::notifyOperationInserted(Operation *op) {
impl->createdOps.push_back(op);
}
-void ConversionPatternRewriter::startRootUpdate(Operation *op) {
+void ConversionPatternRewriter::startOpModification(Operation *op) {
#ifndef NDEBUG
impl->pendingRootUpdates.insert(op);
#endif
impl->rootUpdates.emplace_back(op);
}
-void ConversionPatternRewriter::finalizeRootUpdate(Operation *op) {
- PatternRewriter::finalizeRootUpdate(op);
+void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
+ PatternRewriter::finalizeOpModification(op);
// There is nothing to do here, we only need to track the operation at the
// start of the update.
#ifndef NDEBUG
@@ -1631,7 +1631,7 @@ void ConversionPatternRewriter::finalizeRootUpdate(Operation *op) {
#endif
}
-void ConversionPatternRewriter::cancelRootUpdate(Operation *op) {
+void ConversionPatternRewriter::cancelOpModification(Operation *op) {
#ifndef NDEBUG
assert(impl->pendingRootUpdates.erase(op) &&
"operation did not have a pending in-place update");
@@ -3115,7 +3115,7 @@ static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
auto newType = FunctionType::get(rewriter.getContext(),
result.getConvertedTypes(), newResults);
- rewriter.updateRootInPlace(funcOp, [&] { funcOp.setType(newType); });
+ rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
return success();
}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 2e3bc76..d1ac5e8 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -160,7 +160,7 @@ struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> {
int64_t val = intAttr.getInt();
if (val >= MaxVal)
return failure();
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(val + 1)); });
return success();
}
@@ -175,7 +175,7 @@ struct MakeOpEligible : public RewritePattern {
PatternRewriter &rewriter) const override {
if (op->hasAttr("eligible"))
return failure();
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
op, [&]() { op->setAttr("eligible", rewriter.getUnitAttr()); });
return success();
}
@@ -195,7 +195,7 @@ struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
return failure();
// Hoisting means removing an op from the enclosing op. I.e., the enclosing
// op is modified.
- rewriter.updateRootInPlace(op, [&]() { toBeHoisted->moveBefore(op); });
+ rewriter.modifyOpInPlace(op, [&]() { toBeHoisted->moveBefore(op); });
return success();
}
};
@@ -327,7 +327,7 @@ private:
Operation *newOp =
rewriter.create(op->getLoc(), op->getName().getIdentifier(),
op->getOperands(), op->getResultTypes());
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
op, [&]() { op->setAttr("skip", rewriter.getBoolAttr(true)); });
newOp->setAttr("skip", rewriter.getBoolAttr(true));
@@ -415,8 +415,8 @@ private:
PatternRewriter &rewriter) const override {
if (op->getNumSuccessors() < 1 || op->getSuccessor(0) == op->getBlock())
return failure();
- rewriter.updateRootInPlace(
- op, [&]() { op->setSuccessor(op->getBlock(), 0); });
+ rewriter.modifyOpInPlace(op,
+ [&]() { op->setSuccessor(op->getBlock(), 0); });
return success();
}
};
@@ -650,7 +650,7 @@ struct TestUndoBlockArgReplace : public ConversionPattern {
rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
illegalOp->getResult(0));
- rewriter.updateRootInPlace(op, [] {});
+ rewriter.modifyOpInPlace(op, [] {});
return success();
}
};
@@ -667,7 +667,7 @@ struct TestUndoBlockErase : public ConversionPattern {
rewriter.setInsertionPointToStart(secondBlock);
rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
rewriter.eraseBlock(secondBlock);
- rewriter.updateRootInPlace(op, [] {});
+ rewriter.modifyOpInPlace(op, [] {});
return success();
}
};
@@ -827,7 +827,7 @@ struct TestBoundedRecursiveRewrite
LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
PatternRewriter &rewriter) const final {
// Decrement the depth of the op in-place.
- rewriter.updateRootInPlace(op, [&] {
+ rewriter.modifyOpInPlace(op, [&] {
op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1));
});
return success();
@@ -1333,7 +1333,7 @@ struct TestTestSignatureConversionNoConverter
if (failed(
converter.convertSignatureArgs(entry->getArgumentTypes(), result)))
return failure();
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
op, [&] { rewriter.applySignatureConversion(&region, result); });
return success();
}
@@ -1350,8 +1350,8 @@ struct TestTypeConsumerForward
LogicalResult
matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
- rewriter.updateRootInPlace(op,
- [&] { op->setOperands(adaptor.getOperands()); });
+ rewriter.modifyOpInPlace(op,
+ [&] { op->setOperands(adaptor.getOperands()); });
return success();
}
};
@@ -1567,7 +1567,7 @@ struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
SmallVector<Value, 2> replacements(succOperands);
rewriter.eraseOp(branchOp);
rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
- rewriter.updateRootInPlace(op, [] {});
+ rewriter.modifyOpInPlace(op, [] {});
return success();
}
};
@@ -1588,7 +1588,7 @@ struct TestUndoBlocksMerge : public ConversionPattern {
SmallVector<Value, 2> replacements(succOperands);
rewriter.eraseOp(branchOp);
rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
- rewriter.updateRootInPlace(op, [] {});
+ rewriter.modifyOpInPlace(op, [] {});
return success();
}
};
@@ -1613,7 +1613,7 @@ struct TestMergeSingleBlockOps
rewriter.inlineBlockBefore(&innerBlock, op);
rewriter.eraseOp(innerTerminator);
rewriter.eraseOp(op);
- rewriter.updateRootInPlace(op, [] {});
+ rewriter.modifyOpInPlace(op, [] {});
return success();
}
};