author    MaheshRavishankar <ravishankarm@google.com>  2021-01-26 23:21:33 -0800
committer MaheshRavishankar <ravishankarm@google.com>  2021-01-26 23:22:28 -0800
commit    7c15e0f64ccc79a53ed2db258f1cb58ec452a957 (patch)
tree      3b1d5a0ede8c2c33ccf5bdf47dd41c6161e6925e
parent    48bdd676a1d1338c10541460bf5beb69ac17e451 (diff)
[mlir][Linalg] Add canonicalization for init_tensor -> subtensor op.
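
In IR terms (this sketch mirrors the added test case below; the SSA names are
illustrative), a subtensor whose source is a linalg.init_tensor

  %0 = linalg.init_tensor [%d, 10, 40] : tensor<?x10x40xf32>
  %1 = subtensor %0[0, 0, 0] [5, %s, 20] [1, 1, 1]
      : tensor<?x10x40xf32> to tensor<5x?x20xf32>

now folds to an init_tensor of the subtensor's result shape, reusing the
subtensor's dynamic sizes:

  %1 = linalg.init_tensor [5, %s, 20] : tensor<5x?x20xf32>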
Differential Revision: https://reviews.llvm.org/D95305
-rw-r--r--  mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp    | 29
-rw-r--r--  mlir/test/Dialect/Linalg/canonicalize.mlir  | 16
2 files changed, 42 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index a6f3576..2982132 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -896,7 +896,29 @@ static Value getExpandedInitTensor(OpBuilder &builder,
 }
 
 namespace {
-struct FoldWithTensorReshapeOp : public OpRewritePattern<TensorReshapeOp> {
+/// Since `init_tensor` operation creates a tensor needed only for its shape, a
+/// subtensor of this is also needed only for its shape. The result can be
+/// replaced by a new init_tensor operation of the same size as the subtensor
+/// op.
+struct FoldInitTensorWithSubTensorOp : public OpRewritePattern<SubTensorOp> {
+  using OpRewritePattern<SubTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SubTensorOp subtensorOp,
+                                PatternRewriter &rewriter) const override {
+    if (!subtensorOp.source().getDefiningOp<linalg::InitTensorOp>())
+      return failure();
+    rewriter.replaceOpWithNewOp<linalg::InitTensorOp>(
+        subtensorOp, subtensorOp.sizes(),
+        llvm::to_vector<4>(llvm::map_range(
+            subtensorOp.static_sizes(),
+            [](Attribute attr) { return attr.cast<IntegerAttr>().getInt(); })),
+        subtensorOp.getSourceType().getElementType());
+    return success();
+  }
+};
+
+struct FoldInitTensorWithTensorReshapeOp
+    : public OpRewritePattern<TensorReshapeOp> {
   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
@@ -921,8 +943,9 @@ struct FoldWithTensorReshapeOp : public OpRewritePattern<TensorReshapeOp> {
 
 void InitTensorOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
-  results.insert<FoldWithTensorReshapeOp, ReplaceDimOfInitTensorOp,
-                 ReplaceStaticShapeDims>(context);
+  results
+      .insert<FoldInitTensorWithSubTensorOp, FoldInitTensorWithTensorReshapeOp,
+              ReplaceDimOfInitTensorOp, ReplaceStaticShapeDims>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index cc00b98..418d9d2 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -668,3 +668,19 @@ func @keep_not_noop(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>)
 // CHECK-LABEL: func @keep_not_noop
 // CHECK: %[[RESULT:.+]]:2 = linalg.generic
 // CHECK: return %[[RESULT]]#0, %[[RESULT]]#1
+
+// -----
+
+func @fold_init_tensor_with_subtensor
+  (%arg0 : index, %arg1 : index) -> tensor<5x?x20xf32>
+{
+  %0 = linalg.init_tensor[%arg0, 10, 40] : tensor<?x10x40xf32>
+  %1 = subtensor %0[0, 0, 0] [5, %arg1, 20] [1, 1, 1]
+    : tensor<?x10x40xf32> to tensor<5x?x20xf32>
+  return %1 : tensor<5x?x20xf32>
+}
+// CHECK: func @fold_init_tensor_with_subtensor
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[T0:.+]] = linalg.init_tensor [5, %[[ARG1]], 20]
+// CHECK: return %[[T0]]
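
For completeness, a fully static variant of the same fold, sketched here as an
illustration rather than taken from the committed test (the SSA names are made
up; the expected result assumes the pattern behaves as described above, e.g.
when run through `mlir-opt -canonicalize`):

  %0 = linalg.init_tensor [%arg0, 10, 40] : tensor<?x10x40xf32>
  %1 = subtensor %0[0, 0, 0] [5, 10, 20] [1, 1, 1]
      : tensor<?x10x40xf32> to tensor<5x10x20xf32>

  // Expected to fold to an init_tensor of the static result shape:
  %1 = linalg.init_tensor [5, 10, 20] : tensor<5x10x20xf32>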