aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRyan Holt <ryanholt@mathworks.com>2024-05-30 10:41:29 -0400
committerGitHub <noreply@github.com>2024-05-30 10:41:29 -0400
commit1159e7645b7f345e662759d763b3e6fcde62d005 (patch)
tree3be1d3e55976d575eb3afdef9b91f8d7ca006e44
parentadc4e45f2ecce13cf4ed9b4ab119492342b86faf (diff)
downloadllvm-1159e7645b7f345e662759d763b3e6fcde62d005.zip
llvm-1159e7645b7f345e662759d763b3e6fcde62d005.tar.gz
llvm-1159e7645b7f345e662759d763b3e6fcde62d005.tar.bz2
[mlir][linalg] Add folder for transpose(transpose) -> transpose (#93606)
Back-to-back `linalg.transpose` ops can be rewritten to a single transpose.
-rw-r--r--mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td1
-rw-r--r--mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp28
-rw-r--r--mlir/test/Dialect/Linalg/canonicalize.mlir45
3 files changed, 74 insertions, 0 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 5ee363e..ac61117 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -458,6 +458,7 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
}];
let hasFolder = 1;
+ let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 0b403e2..b79afeb 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1872,6 +1872,34 @@ LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
return failure();
}
+/// Fold transpose with transpose: when the input of a `linalg.transpose` is itself produced by a `linalg.transpose`, replace the outer op with one transpose of the producer's input using the composed permutation.
+struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
+ using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+ PatternRewriter &rewriter) const override {
+ auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>(); // producer of the input, if that producer is a transpose
+ if (!defTransposeOp)
+ return failure(); // no match: the input does not come from a transpose
+ ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation(); // inner (producer) permutation
+ ArrayRef<int64_t> perms = transposeOp.getPermutation(); // outer (consumer) permutation
+ SmallVector<int64_t> foldedPerms;
+ foldedPerms.reserve(perms.size());
+ for (int64_t perm : perms)
+ foldedPerms.push_back(defPerms[perm]); // composed entry: index the inner permutation by the outer one
+
+ rewriter.replaceOpWithNewOp<TransposeOp>( // new op reads the producer's input and keeps the outer op's init
+ transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
+ foldedPerms);
+ return success();
+ }
+};
+
+void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<FoldTransposeWithTranspose>(context); // register the transpose-of-transpose rewrite with the pattern set
+}
+
//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 19cea6c..928030a 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1051,3 +1051,48 @@ func.func @transpose_identity_perm(%input: tensor<16x32x64xf32>,
// CHECK-NOT: linalg.transpose
// CHECK: return %[[INPUT]] : tensor<16x32x64xf32>
+// -----
+
+func.func @transpose_transpose_cancel(%input: tensor<5x4x3xf32>,
+ %init1: tensor<4x3x5xf32>,
+ %init2: tensor<5x4x3xf32>) -> tensor<5x4x3xf32> {
+ // CHECK-LABEL: @transpose_transpose_cancel
+ // CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<5x4x3xf32>
+ // CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<4x3x5xf32>
+ // CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<5x4x3xf32>
+ // CHECK-NOT: linalg.transpose
+ // CHECK: return %[[INPUT]] : tensor<5x4x3xf32>
+ %transpose1 = linalg.transpose
+ ins(%input:tensor<5x4x3xf32>)
+ outs(%init1:tensor<4x3x5xf32>)
+ permutation = [1, 2, 0] // inner permutation
+ %transpose2 = linalg.transpose
+ ins(%transpose1:tensor<4x3x5xf32>)
+ outs(%init2:tensor<5x4x3xf32>)
+ permutation = [2, 0, 1] // outer permutation; indexing [1, 2, 0] by it yields the identity [0, 1, 2], so the function is expected to return %input directly
+ func.return %transpose2 : tensor<5x4x3xf32>
+}
+
+// -----
+
+func.func @transpose_transpose_fold(%input: tensor<5x4x3xf32>,
+ %init1: tensor<4x3x5xf32>,
+ %init2: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> {
+ // CHECK-LABEL: @transpose_transpose_fold
+ // CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<5x4x3xf32>
+ // CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<4x3x5xf32>
+ // CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<3x4x5xf32>
+ // CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[INPUT]] : tensor<5x4x3xf32>) outs(%[[INIT2]] : tensor<3x4x5xf32>) permutation = [2, 1, 0]
+ // CHECK-NOT: linalg.transpose
+ // CHECK: return %[[TRANSPOSE]] : tensor<3x4x5xf32>
+ %transpose1 = linalg.transpose
+ ins(%input:tensor<5x4x3xf32>)
+ outs(%init1:tensor<4x3x5xf32>)
+ permutation = [1, 2, 0] // inner permutation
+ %transpose2 = linalg.transpose
+ ins(%transpose1:tensor<4x3x5xf32>)
+ outs(%init2:tensor<3x4x5xf32>)
+ permutation = [1, 0, 2] // outer permutation; indexing [1, 2, 0] by it yields [2, 1, 0], the single transpose expected on %input
+ func.return %transpose2 : tensor<3x4x5xf32>
+}
+