diff options
| author | Lei Zhang <antiagainst@google.com> | 2021-03-24 17:51:14 -0400 |
|---|---|---|
| committer | Lei Zhang <antiagainst@google.com> | 2021-03-24 18:17:57 -0400 |
| commit | 7f28d27cb614c47e6cf68f5deae729270d13cb08 (patch) | |
| tree | 05a92b10020715bb1f7a188f0fef3e48e7adc78b | |
| parent | f66120a3575a19d2b9b47b584698d5d950f63589 (diff) | |
| download | llvm-7f28d27cb614c47e6cf68f5deae729270d13cb08.zip llvm-7f28d27cb614c47e6cf68f5deae729270d13cb08.tar.gz llvm-7f28d27cb614c47e6cf68f5deae729270d13cb08.tar.bz2 | |
[mlir][linalg] Allow controlling folding unit dim reshapes
This commit exposes an option to the pattern
FoldWithProducerReshapeOpByExpansion to allow
folding unit dim reshapes. This gives callers
more fine-grained controls.
Differential Revision: https://reviews.llvm.org/D99114
| -rw-r--r-- | mlir/include/mlir/Dialect/Linalg/Passes.h | 6 | ||||
| -rw-r--r-- | mlir/include/mlir/Dialect/Linalg/Passes.td | 6 | ||||
| -rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp | 34 | ||||
| -rw-r--r-- | mlir/test/Dialect/Linalg/reshape_fusion.mlir | 13 |
4 files changed, 44 insertions, 15 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h index ecec2a3..18820d4 100644 --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -65,7 +65,8 @@ std::unique_ptr<Pass> createLinalgDetensorizePass(); /// Patterns to fold an expanding (collapsing) tensor_reshape operation with its /// producer (consumer) generic operation by expanding the dimensionality of the /// loop in the generic op. -void populateFoldReshapeOpsByExpansionPatterns(RewritePatternSet &patterns); +void populateFoldReshapeOpsByExpansionPatterns( + RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false); /// Patterns to fold a collapsing (expanding) tensor_reshape operation with its /// producer (consumer) generic/indexed_generic operation by linearizing the @@ -83,7 +84,8 @@ void populateFoldUnitDimsReshapeOpsByLinearizationPatterns( RewritePatternSet &patterns); /// Patterns for fusing linalg operation on tensors. -void populateLinalgTensorOpsFusionPatterns(RewritePatternSet &patterns); +void populateLinalgTensorOpsFusionPatterns( + RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false); /// Patterns to fold unit-extent dimensions in operands/results of linalg ops on /// tensors. diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td index aad1117..786b9ec 100644 --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -37,6 +37,12 @@ def LinalgFoldUnitExtentDims : FunctionPass<"linalg-fold-unit-extent-dims"> { def LinalgFusionOfTensorOps : Pass<"linalg-fusion-for-tensor-ops"> { let summary = "Fuse operations on RankedTensorType in linalg dialect"; let constructor = "mlir::createLinalgFusionOfTensorOpsPass()"; + let options = [ + Option<"allowFoldingUnitDimReshapes", "allow-folding-unit-dim-reshapes", + "bool", /*default=*/"false", + "Allow fusing linalg.tensor_reshape ops that performs unit " + "dimension collapsing"> + ]; let dependentDialects = ["linalg::LinalgDialect", "AffineDialect"]; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp index 4b0951e..7e89a08 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -897,9 +897,14 @@ struct FoldProducerReshapeOpByLinearization /// generic/indexed_generic op, when the reshape op is collapsing /// dimensions. The dimensionality of the loop in the consumer is expanded. template <typename GenericOpTy> -struct FoldWithProducerReshapeOpByExpansion +class FoldWithProducerReshapeOpByExpansion : public OpRewritePattern<GenericOpTy> { - using OpRewritePattern<GenericOpTy>::OpRewritePattern; +public: + FoldWithProducerReshapeOpByExpansion(MLIRContext *context, + bool foldUnitDimReshapes, + PatternBenefit benefit = 1) + : OpRewritePattern<GenericOpTy>(context, benefit), + allowFoldingUnitDimReshapes(foldUnitDimReshapes) {} LogicalResult matchAndRewrite(GenericOpTy genericOp, PatternRewriter &rewriter) const override { @@ -916,8 +921,9 @@ struct FoldWithProducerReshapeOpByExpansion if (reshapeOp.getSrcType().getRank() < reshapeOp.getResultType().getRank() || !isFusableWithReshapeByDimExpansion(linalgOp, operand.index()) || - isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(), - reshapeOp.getReassociationMaps())) + (!allowFoldingUnitDimReshapes && + isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(), + reshapeOp.getReassociationMaps()))) continue; Optional<SmallVector<Value, 1>> replacementValues = @@ -930,6 +936,9 @@ struct FoldWithProducerReshapeOpByExpansion } return failure(); } + +private: + bool allowFoldingUnitDimReshapes; }; /// Pattern to fold tensor_reshape op with its producer. The corresponding index @@ -1134,7 +1143,8 @@ struct FusionOfTensorOpsPass void runOnOperation() override { Operation *op = getOperation(); RewritePatternSet patterns(op->getContext()); - populateLinalgTensorOpsFusionPatterns(patterns); + populateLinalgTensorOpsFusionPatterns(patterns, + allowFoldingUnitDimReshapes); (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns)); } }; @@ -1171,20 +1181,22 @@ void mlir::populateFoldUnitDimsReshapeOpsByLinearizationPatterns( } void mlir::populateFoldReshapeOpsByExpansionPatterns( - RewritePatternSet &patterns) { - patterns.add<FoldReshapeWithGenericOpByExpansion, - FoldWithProducerReshapeOpByExpansion<GenericOp>, + RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes) { + patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext()); + patterns.add<FoldWithProducerReshapeOpByExpansion<GenericOp>, FoldWithProducerReshapeOpByExpansion<IndexedGenericOp>>( - patterns.getContext()); + patterns.getContext(), allowFoldingUnitDimReshapes); } -void mlir::populateLinalgTensorOpsFusionPatterns(RewritePatternSet &patterns) { +void mlir::populateLinalgTensorOpsFusionPatterns( + RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes) { auto *context = patterns.getContext(); patterns .add<FuseTensorOps<GenericOp>, FuseTensorOps<IndexedGenericOp>, FoldSplatConstants<GenericOp>, FoldSplatConstants<IndexedGenericOp>>( context); - populateFoldReshapeOpsByExpansionPatterns(patterns); + populateFoldReshapeOpsByExpansionPatterns(patterns, + allowFoldingUnitDimReshapes); GenericOp::getCanonicalizationPatterns(patterns, context); IndexedGenericOp::getCanonicalizationPatterns(patterns, context); TensorReshapeOp::getCanonicalizationPatterns(patterns, context); diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir index fbaf47c..d5dc176 100644 --- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir +++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir @@ -1,4 +1,5 @@ -// RUN: mlir-opt %s -linalg-fusion-for-tensor-ops -split-input-file -verify-each=0 | FileCheck %s +// RUN: mlir-opt %s -linalg-fusion-for-tensor-ops="allow-folding-unit-dim-reshapes=false" -split-input-file -verify-each=0 | FileCheck %s +// RUN: mlir-opt %s -linalg-fusion-for-tensor-ops="allow-folding-unit-dim-reshapes=true" -split-input-file -verify-each=0 | FileCheck %s --check-prefix=FOLDUNITDIM #map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)> #map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)> @@ -300,7 +301,7 @@ func @reshape_as_consumer_permutation %5 = addi %3, %4 : i32 %6 = index_cast %arg2 : index to i32 %7 = addi %5, %6 : i32 - linalg.yield %7 : i32 + linalg.yield %7 : i32 } -> tensor<6x4x210xi32> %d = linalg.tensor_reshape %c [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>, @@ -531,3 +532,11 @@ func @unit_dim_reshape_expansion_full // CHECK-DAG: linalg.tensor_reshape // CHECK-DAG: linalg.init_tensor // CHECK: linalg.generic +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<?x2x4xf32>, tensor<?x2x4xf32>) + +// FOLDUNITDIM: func @unit_dim_reshape_expansion_full +// FOLDUNITDIM: linalg.init_tensor +// FOLDUNITDIM-COUNT-2: linalg.tensor_reshape +// FOLDUNITDIM: linalg.generic +// FOLDUNITDIM-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x?x1x2x1x4xf32>, tensor<1x?x1x2x1x4xf32>) + |
