diff options
author | MaheshRavishankar <1663364+MaheshRavishankar@users.noreply.github.com> | 2024-08-23 13:43:33 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-08-23 13:43:33 -0700 |
commit | 4dbaef6d5ea71fb183114a82da4028960906c42b (patch) | |
tree | 6a8554aed50f60cc813878628a80dbea88e2aa2d | |
parent | a2a5508bdae7d115b6c3ace461beb7a987a44407 (diff) | |
download | llvm-4dbaef6d5ea71fb183114a82da4028960906c42b.zip llvm-4dbaef6d5ea71fb183114a82da4028960906c42b.tar.gz llvm-4dbaef6d5ea71fb183114a82da4028960906c42b.tar.bz2 |
[mlir][Linalg] Avoid doing op replacement in `linalg::dropUnitDims`. (#105749)
It is better to do the replacement in the caller. This avoids the
footgun if the caller needs the original operation. Instead return the
produced operation and replacement values.
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
-rw-r--r-- | mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h | 9 | ||||
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp | 16 | ||||
-rw-r--r-- | mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp | 8 |
3 files changed, 25 insertions, 8 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index bee3452eb..0208f85 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -488,8 +488,13 @@ struct ControlDropUnitDims { return SmallVector<unsigned>{}; }; }; -LogicalResult dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, - const ControlDropUnitDims &options); +struct DropUnitDimsResult { + linalg::GenericOp resultOp; + SmallVector<Value> replacements; +}; +FailureOr<DropUnitDimsResult> dropUnitDims(RewriterBase &rewriter, + GenericOp genericOp, + const ControlDropUnitDims &options); /// Fuse two `linalg.generic` operations that have a producer-consumer /// relationship captured through `fusedOperand`. The method expects diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index 36f8696..88ef82f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -386,8 +386,9 @@ static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata( return info; } -LogicalResult linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, - const ControlDropUnitDims &options) { +FailureOr<DropUnitDimsResult> +linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, + const ControlDropUnitDims &options) { SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray(); if (indexingMaps.empty()) return failure(); @@ -545,8 +546,7 @@ LogicalResult linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, resultReplacements.push_back(expandedValue); } - rewriter.replaceOp(genericOp, resultReplacements); - return success(); + return DropUnitDimsResult{replacementOp, resultReplacements}; } namespace { @@ -557,7 +557,13 @@ struct DropUnitDims : public OpRewritePattern<GenericOp> { LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { - return dropUnitDims(rewriter, genericOp, options); + FailureOr<DropUnitDimsResult> result = + dropUnitDims(rewriter, genericOp, options); + if (failed(result)) { + return failure(); + } + rewriter.replaceOp(genericOp, result->replacements); + return success(); } private: diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp index 85a6d5f..402ce15 100644 --- a/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp @@ -25,7 +25,13 @@ LogicalResult dropOutermostUnitDims(RewriterBase &rewriter, linalg::GenericOp genericOp) { linalg::ControlDropUnitDims options; options.controlFn = [](Operation *op) { return SmallVector<unsigned>{0}; }; - return linalg::dropUnitDims(rewriter, genericOp, options); + FailureOr<linalg::DropUnitDimsResult> result = + linalg::dropUnitDims(rewriter, genericOp, options); + if (failed(result)) { + return failure(); + } + rewriter.replaceOp(genericOp, result->replacements); + return success(); } struct TestLinalgDropUnitDims |