aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTobias Gysi <gysit@google.com>2021-07-28 09:42:01 +0000
committerTobias Gysi <gysit@google.com>2021-07-28 11:18:22 +0000
commitca0d244e99f4325711638359eb69f8129b41a63a (patch)
tree18790c4a7cee95f03e958417d46abbb133080015
parent4fd42e2e803d8a532845f448fca4002ede3070f5 (diff)
downloadllvm-ca0d244e99f4325711638359eb69f8129b41a63a.zip
llvm-ca0d244e99f4325711638359eb69f8129b41a63a.tar.gz
llvm-ca0d244e99f4325711638359eb69f8129b41a63a.tar.bz2
[mlir][linalg] Introduce a separate EraseIdentityCopyOp Pattern.
Split out an EraseIdentityCopyOp from the existing RemoveIdentityLinalgOps pattern. Introduce an additional check to ensure the pattern checks the permutation maps match. This is a preparation step to specialize RemoveIdentityLinalgOps to GenericOp only. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D105622
-rw-r--r--mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td1
-rw-r--r--mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp34
-rw-r--r--mlir/test/Dialect/Linalg/canonicalize.mlir14
3 files changed, 40 insertions, 9 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 9055f3c..3ddc7b7 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -168,6 +168,7 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
custom<CopyOpRegion>($region, ref(type($input)), ref(type($input)))
}];
+ let hasCanonicalizer = 1;
let hasFolder = 1;
let skipDefaultBuilders = 1;
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9096e37..7931f9b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -426,6 +426,31 @@ void CopyOp::getEffects(
SideEffects::DefaultResource::get());
}
+namespace {
+/// Remove copy operations that copy data inplace. Requirements are:
+/// 1) The input and output values are identical.
+/// 2) The input and output permutation maps are identical.
+struct EraseIdentityCopyOp : public OpRewritePattern<CopyOp> {
+ using OpRewritePattern<CopyOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(CopyOp copyOp,
+ PatternRewriter &rewriter) const override {
+ assert(copyOp.hasBufferSemantics());
+ if (copyOp.input() == copyOp.output() &&
+ copyOp.inputPermutation() == copyOp.outputPermutation()) {
+ rewriter.eraseOp(copyOp);
+ return success();
+ }
+ return failure();
+ }
+};
+} // namespace
+
+void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<EraseIdentityCopyOp>(context);
+}
+
//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//
@@ -2615,15 +2640,6 @@ struct RemoveIdentityLinalgOps : public OpInterfaceRewritePattern<LinalgOp> {
LogicalResult matchAndRewrite(LinalgOp op,
PatternRewriter &rewriter) const override {
- if (auto copyOp = dyn_cast<CopyOp>(*op)) {
- assert(copyOp.hasBufferSemantics());
- if (copyOp.input() == copyOp.output() &&
- copyOp.inputPermutation() == copyOp.outputPermutation()) {
- rewriter.eraseOp(op);
- return success();
- }
- }
-
if (!isa<GenericOp>(op))
return failure();
if (!op.hasTensorSemantics())
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index b24876c..ee7edbe 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -661,6 +661,20 @@ func @self_copy(%arg0 : memref<2x3x?x4xf32>) {
// -----
+// CHECK-LABEL: @self_copy_with_permutation
+func @self_copy_with_permutation(%arg0 : memref<2x3x?x4xf32>) {
+
+// CHECK: linalg.copy
+ linalg.copy(%arg0, %arg0)
+ {inputPermutation = affine_map<(i, j, k, l) -> (j, k, i, l)>,
+ outputPermuation = affine_map<(i, j, k, l) -> (i, j, k, l)>} : memref<2x3x?x4xf32>, memref<2x3x?x4xf32>
+
+// CHECK: return
+ return
+}
+
+// -----
+
// CHECK-LABEL: func @fold_fill_reshape()
func @fold_fill_reshape() -> tensor<6x4xf32> {
%zero = constant 0.0 : f32