diff options
author | MaheshRavishankar <ravishankarm@google.com> | 2021-01-26 23:21:33 -0800 |
---|---|---|
committer | MaheshRavishankar <ravishankarm@google.com> | 2021-01-26 23:22:28 -0800 |
commit | 7c15e0f64ccc79a53ed2db258f1cb58ec452a957 (patch) | |
tree | 3b1d5a0ede8c2c33ccf5bdf47dd41c6161e6925e | |
parent | 48bdd676a1d1338c10541460bf5beb69ac17e451 (diff) | |
download | llvm-7c15e0f64ccc79a53ed2db258f1cb58ec452a957.zip llvm-7c15e0f64ccc79a53ed2db258f1cb58ec452a957.tar.gz llvm-7c15e0f64ccc79a53ed2db258f1cb58ec452a957.tar.bz2 |
[mlir][Linalg] Add canonicalization for init_tensor -> subtensor op.
Differential Revision: https://reviews.llvm.org/D95305
-rw-r--r-- | mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 29 | ||||
-rw-r--r-- | mlir/test/Dialect/Linalg/canonicalize.mlir | 16 |
2 files changed, 42 insertions, 3 deletions
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index a6f3576..2982132 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -896,7 +896,29 @@ static Value getExpandedInitTensor(OpBuilder &builder, } namespace { -struct FoldWithTensorReshapeOp : public OpRewritePattern<TensorReshapeOp> { +/// Since `init_tensor` operation creates a tensor needed only for its shape, a +/// subtensor of this is also needed only for its shape. The result can be +/// replaced by a new init_tensor operation of the same size as the subtensor +/// op. +struct FoldInitTensorWithSubTensorOp : public OpRewritePattern<SubTensorOp> { + using OpRewritePattern<SubTensorOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(SubTensorOp subtensorOp, + PatternRewriter &rewriter) const override { + if (!subtensorOp.source().getDefiningOp<linalg::InitTensorOp>()) + return failure(); + rewriter.replaceOpWithNewOp<linalg::InitTensorOp>( + subtensorOp, subtensorOp.sizes(), + llvm::to_vector<4>(llvm::map_range( + subtensorOp.static_sizes(), + [](Attribute attr) { return attr.cast<IntegerAttr>().getInt(); })), + subtensorOp.getSourceType().getElementType()); + return success(); + } +}; + +struct FoldInitTensorWithTensorReshapeOp + : public OpRewritePattern<TensorReshapeOp> { using OpRewritePattern<TensorReshapeOp>::OpRewritePattern; LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, @@ -921,8 +943,9 @@ struct FoldWithTensorReshapeOp : public OpRewritePattern<TensorReshapeOp> { void InitTensorOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { - results.insert<FoldWithTensorReshapeOp, ReplaceDimOfInitTensorOp, - ReplaceStaticShapeDims>(context); + results + .insert<FoldInitTensorWithSubTensorOp, FoldInitTensorWithTensorReshapeOp, + ReplaceDimOfInitTensorOp, ReplaceStaticShapeDims>(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index cc00b98..418d9d2 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -668,3 +668,19 @@ func @keep_not_noop(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) // CHECK-LABEL: func @keep_not_noop // CHECK: %[[RESULT:.+]]:2 = linalg.generic // CHECK: return %[[RESULT]]#0, %[[RESULT]]#1 + +// ----- + +func @fold_init_tensor_with_subtensor + (%arg0 : index, %arg1 : index) -> tensor<5x?x20xf32> +{ + %0 = linalg.init_tensor[%arg0, 10, 40] : tensor<?x10x40xf32> + %1 = subtensor %0[0, 0, 0] [5, %arg1, 20] [1, 1, 1] + : tensor<?x10x40xf32> to tensor<5x?x20xf32> + return %1 : tensor<5x?x20xf32> +} +// CHECK: func @fold_init_tensor_with_subtensor +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index +// CHECK: %[[T0:.+]] = linalg.init_tensor [5, %[[ARG1]], 20] +// CHECK: return %[[T0]] |