diff options
author | Ryan Holt <ryanholt@mathworks.com> | 2024-05-30 10:41:29 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-30 10:41:29 -0400 |
commit | 1159e7645b7f345e662759d763b3e6fcde62d005 (patch) | |
tree | 3be1d3e55976d575eb3afdef9b91f8d7ca006e44 | |
parent | adc4e45f2ecce13cf4ed9b4ab119492342b86faf (diff) | |
download | llvm-1159e7645b7f345e662759d763b3e6fcde62d005.zip llvm-1159e7645b7f345e662759d763b3e6fcde62d005.tar.gz llvm-1159e7645b7f345e662759d763b3e6fcde62d005.tar.bz2 |
[mlir][linalg] Add folder for transpose(transpose) -> transpose (#93606)
Back-to-back `linalg.transpose` ops can be rewritten as a single transpose op with the composed permutation.
-rw-r--r-- | mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td | 1 | ||||
-rw-r--r-- | mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 28 | ||||
-rw-r--r-- | mlir/test/Dialect/Linalg/canonicalize.mlir | 45 |
3 files changed, 74 insertions, 0 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index 5ee363e..ac61117 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -458,6 +458,7 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [ }]; let hasFolder = 1; + let hasCanonicalizer = 1; let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 0b403e2..b79afeb 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1872,6 +1872,34 @@ LogicalResult TransposeOp::fold(FoldAdaptor adaptor, return failure(); } +/// Fold transpose with transpose. +struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> { + using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, + PatternRewriter &rewriter) const override { + auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>(); + if (!defTransposeOp) + return failure(); + ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation(); + ArrayRef<int64_t> perms = transposeOp.getPermutation(); + SmallVector<int64_t> foldedPerms; + foldedPerms.reserve(perms.size()); + for (int64_t perm : perms) + foldedPerms.push_back(defPerms[perm]); + + rewriter.replaceOpWithNewOp<TransposeOp>( + transposeOp, defTransposeOp.getInput(), transposeOp.getInit(), + foldedPerms); + return success(); + } +}; + +void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add<FoldTransposeWithTranspose>(context); +} + //===----------------------------------------------------------------------===// // BroadcastOp //===----------------------------------------------------------------------===// diff --git 
a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index 19cea6c..928030a 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -1051,3 +1051,48 @@ func.func @transpose_identity_perm(%input: tensor<16x32x64xf32>, // CHECK-NOT: linalg.transpose // CHECK: return %[[INPUT]] : tensor<16x32x64xf32> +// ----- + +func.func @transpose_transpose_cancel(%input: tensor<5x4x3xf32>, + %init1: tensor<4x3x5xf32>, + %init2: tensor<5x4x3xf32>) -> tensor<5x4x3xf32> { + // CHECK-LABEL: @transpose_transpose_cancel + // CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<5x4x3xf32> + // CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<4x3x5xf32> + // CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<5x4x3xf32> + // CHECK-NOT: linalg.transpose + // CHECK: return %[[INPUT]] : tensor<5x4x3xf32> + %transpose1 = linalg.transpose + ins(%input:tensor<5x4x3xf32>) + outs(%init1:tensor<4x3x5xf32>) + permutation = [1, 2, 0] + %transpose2 = linalg.transpose + ins(%transpose1:tensor<4x3x5xf32>) + outs(%init2:tensor<5x4x3xf32>) + permutation = [2, 0, 1] + func.return %transpose2 : tensor<5x4x3xf32> +} + +// ----- + +func.func @transpose_transpose_fold(%input: tensor<5x4x3xf32>, + %init1: tensor<4x3x5xf32>, + %init2: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> { + // CHECK-LABEL: @transpose_transpose_fold + // CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<5x4x3xf32> + // CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<4x3x5xf32> + // CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<3x4x5xf32> + // CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[INPUT]] : tensor<5x4x3xf32>) outs(%[[INIT2]] : tensor<3x4x5xf32>) permutation = [2, 1, 0] + // CHECK-NOT: linalg.transpose + // CHECK: return %[[TRANSPOSE]] : tensor<3x4x5xf32> + %transpose1 = linalg.transpose + ins(%input:tensor<5x4x3xf32>) + outs(%init1:tensor<4x3x5xf32>) + permutation = [1, 2, 0] + %transpose2 = linalg.transpose + ins(%transpose1:tensor<4x3x5xf32>) + 
outs(%init2:tensor<3x4x5xf32>) + permutation = [1, 0, 2] + func.return %transpose2 : tensor<3x4x5xf32> +} + |