diff options
author | Tobias Gysi <gysit@google.com> | 2021-07-28 09:42:01 +0000 |
---|---|---|
committer | Tobias Gysi <gysit@google.com> | 2021-07-28 11:18:22 +0000 |
commit | ca0d244e99f4325711638359eb69f8129b41a63a (patch) | |
tree | 18790c4a7cee95f03e958417d46abbb133080015 | |
parent | 4fd42e2e803d8a532845f448fca4002ede3070f5 (diff) | |
download | llvm-ca0d244e99f4325711638359eb69f8129b41a63a.zip llvm-ca0d244e99f4325711638359eb69f8129b41a63a.tar.gz llvm-ca0d244e99f4325711638359eb69f8129b41a63a.tar.bz2 |
[mlir][linalg] Introduce a separate EraseIdentityCopyOp Pattern.
Split out an EraseIdentityCopyOp from the existing RemoveIdentityLinalgOps pattern. Introduce an additional check to ensure the pattern checks the permutation maps match. This is a preparation step to specialize RemoveIdentityLinalgOps to GenericOp only.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D105622
-rw-r--r-- | mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td | 1 | ||||
-rw-r--r-- | mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 34 | ||||
-rw-r--r-- | mlir/test/Dialect/Linalg/canonicalize.mlir | 14 |
3 files changed, 40 insertions, 9 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index 9055f3c..3ddc7b7 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -168,6 +168,7 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> { custom<CopyOpRegion>($region, ref(type($input)), ref(type($input))) }]; + let hasCanonicalizer = 1; let hasFolder = 1; let skipDefaultBuilders = 1; } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 9096e37..7931f9b 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -426,6 +426,31 @@ void CopyOp::getEffects( SideEffects::DefaultResource::get()); } +namespace { +/// Remove copy operations that copy data inplace. Requirements are: +/// 1) The input and output values are identical. +/// 2) The input and output permutation maps are identical. +struct EraseIdentityCopyOp : public OpRewritePattern<CopyOp> { + using OpRewritePattern<CopyOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(CopyOp copyOp, + PatternRewriter &rewriter) const override { + assert(copyOp.hasBufferSemantics()); + if (copyOp.input() == copyOp.output() && + copyOp.inputPermutation() == copyOp.outputPermutation()) { + rewriter.eraseOp(copyOp); + return success(); + } + return failure(); + } +}; +} // namespace + +void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add<EraseIdentityCopyOp>(context); +} + //===----------------------------------------------------------------------===// // FillOp //===----------------------------------------------------------------------===// @@ -2615,15 +2640,6 @@ struct RemoveIdentityLinalgOps : public OpInterfaceRewritePattern<LinalgOp> { LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override { - if (auto copyOp = dyn_cast<CopyOp>(*op)) { - assert(copyOp.hasBufferSemantics()); - if (copyOp.input() == copyOp.output() && - copyOp.inputPermutation() == copyOp.outputPermutation()) { - rewriter.eraseOp(op); - return success(); - } - } - if (!isa<GenericOp>(op)) return failure(); if (!op.hasTensorSemantics()) diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index b24876c..ee7edbe 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -661,6 +661,20 @@ func @self_copy(%arg0 : memref<2x3x?x4xf32>) { // ----- +// CHECK-LABEL: @self_copy_with_permutation +func @self_copy_with_permutation(%arg0 : memref<2x3x?x4xf32>) { + +// CHECK: linalg.copy + linalg.copy(%arg0, %arg0) + {inputPermutation = affine_map<(i, j, k, l) -> (j, k, i, l)>, + outputPermuation = affine_map<(i, j, k, l) -> (i, j, k, l)>} : memref<2x3x?x4xf32>, memref<2x3x?x4xf32> + +// CHECK: return + return +} + +// ----- + // CHECK-LABEL: func @fold_fill_reshape() func @fold_fill_reshape() -> tensor<6x4xf32> { %zero = constant 0.0 : f32 |