aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMaheshRavishankar <1663364+MaheshRavishankar@users.noreply.github.com>2024-08-23 13:43:33 -0700
committerGitHub <noreply@github.com>2024-08-23 13:43:33 -0700
commit4dbaef6d5ea71fb183114a82da4028960906c42b (patch)
tree6a8554aed50f60cc813878628a80dbea88e2aa2d
parenta2a5508bdae7d115b6c3ace461beb7a987a44407 (diff)
downloadllvm-4dbaef6d5ea71fb183114a82da4028960906c42b.zip
llvm-4dbaef6d5ea71fb183114a82da4028960906c42b.tar.gz
llvm-4dbaef6d5ea71fb183114a82da4028960906c42b.tar.bz2
[mlir][Linalg] Avoid doing op replacement in `linalg::dropUnitDims`. (#105749)
It is better to do the replacement in the caller. This avoids the footgun if the caller needs the original operation. Instead return the produced operation and replacement values. Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
-rw-r--r--mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h9
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp16
-rw-r--r--mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp8
3 files changed, 25 insertions, 8 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index bee3452eb..0208f85 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -488,8 +488,13 @@ struct ControlDropUnitDims {
return SmallVector<unsigned>{};
};
};
-LogicalResult dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
- const ControlDropUnitDims &options);
+struct DropUnitDimsResult {
+ linalg::GenericOp resultOp;
+ SmallVector<Value> replacements;
+};
+FailureOr<DropUnitDimsResult> dropUnitDims(RewriterBase &rewriter,
+ GenericOp genericOp,
+ const ControlDropUnitDims &options);
/// Fuse two `linalg.generic` operations that have a producer-consumer
/// relationship captured through `fusedOperand`. The method expects
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 36f8696..88ef82f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -386,8 +386,9 @@ static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
return info;
}
-LogicalResult linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
- const ControlDropUnitDims &options) {
+FailureOr<DropUnitDimsResult>
+linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
+ const ControlDropUnitDims &options) {
SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
if (indexingMaps.empty())
return failure();
@@ -545,8 +546,7 @@ LogicalResult linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
resultReplacements.push_back(expandedValue);
}
- rewriter.replaceOp(genericOp, resultReplacements);
- return success();
+ return DropUnitDimsResult{replacementOp, resultReplacements};
}
namespace {
@@ -557,7 +557,13 @@ struct DropUnitDims : public OpRewritePattern<GenericOp> {
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
- return dropUnitDims(rewriter, genericOp, options);
+ FailureOr<DropUnitDimsResult> result =
+ dropUnitDims(rewriter, genericOp, options);
+ if (failed(result)) {
+ return failure();
+ }
+ rewriter.replaceOp(genericOp, result->replacements);
+ return success();
}
private:
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp
index 85a6d5f..402ce15 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp
@@ -25,7 +25,13 @@ LogicalResult dropOutermostUnitDims(RewriterBase &rewriter,
linalg::GenericOp genericOp) {
linalg::ControlDropUnitDims options;
options.controlFn = [](Operation *op) { return SmallVector<unsigned>{0}; };
- return linalg::dropUnitDims(rewriter, genericOp, options);
+ FailureOr<linalg::DropUnitDimsResult> result =
+ linalg::dropUnitDims(rewriter, genericOp, options);
+ if (failed(result)) {
+ return failure();
+ }
+ rewriter.replaceOp(genericOp, result->replacements);
+ return success();
}
struct TestLinalgDropUnitDims