aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLei Zhang <antiagainst@google.com>2021-03-24 17:51:14 -0400
committerLei Zhang <antiagainst@google.com>2021-03-24 18:17:57 -0400
commit7f28d27cb614c47e6cf68f5deae729270d13cb08 (patch)
tree05a92b10020715bb1f7a188f0fef3e48e7adc78b
parentf66120a3575a19d2b9b47b584698d5d950f63589 (diff)
downloadllvm-7f28d27cb614c47e6cf68f5deae729270d13cb08.zip
llvm-7f28d27cb614c47e6cf68f5deae729270d13cb08.tar.gz
llvm-7f28d27cb614c47e6cf68f5deae729270d13cb08.tar.bz2
[mlir][linalg] Allow controlling folding unit dim reshapes
This commit exposes an option on the pattern FoldWithProducerReshapeOpByExpansion to allow folding unit-dimension reshapes. This gives callers more fine-grained control. Differential Revision: https://reviews.llvm.org/D99114
-rw-r--r--mlir/include/mlir/Dialect/Linalg/Passes.h6
-rw-r--r--mlir/include/mlir/Dialect/Linalg/Passes.td6
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp34
-rw-r--r--mlir/test/Dialect/Linalg/reshape_fusion.mlir13
4 files changed, 44 insertions, 15 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index ecec2a3..18820d4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -65,7 +65,8 @@ std::unique_ptr<Pass> createLinalgDetensorizePass();
/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
/// producer (consumer) generic operation by expanding the dimensionality of the
/// loop in the generic op.
-void populateFoldReshapeOpsByExpansionPatterns(RewritePatternSet &patterns);
+void populateFoldReshapeOpsByExpansionPatterns(
+ RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
/// producer (consumer) generic/indexed_generic operation by linearizing the
@@ -83,7 +84,8 @@ void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
RewritePatternSet &patterns);
/// Patterns for fusing linalg operation on tensors.
-void populateLinalgTensorOpsFusionPatterns(RewritePatternSet &patterns);
+void populateLinalgTensorOpsFusionPatterns(
+ RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
/// tensors.
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index aad1117..786b9ec 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -37,6 +37,12 @@ def LinalgFoldUnitExtentDims : FunctionPass<"linalg-fold-unit-extent-dims"> {
def LinalgFusionOfTensorOps : Pass<"linalg-fusion-for-tensor-ops"> {
let summary = "Fuse operations on RankedTensorType in linalg dialect";
let constructor = "mlir::createLinalgFusionOfTensorOpsPass()";
+ let options = [
+ Option<"allowFoldingUnitDimReshapes", "allow-folding-unit-dim-reshapes",
+ "bool", /*default=*/"false",
+ "Allow fusing linalg.tensor_reshape ops that performs unit "
+ "dimension collapsing">
+ ];
let dependentDialects = ["linalg::LinalgDialect", "AffineDialect"];
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 4b0951e..7e89a08 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -897,9 +897,14 @@ struct FoldProducerReshapeOpByLinearization
/// generic/indexed_generic op, when the reshape op is collapsing
/// dimensions. The dimensionality of the loop in the consumer is expanded.
template <typename GenericOpTy>
-struct FoldWithProducerReshapeOpByExpansion
+class FoldWithProducerReshapeOpByExpansion
: public OpRewritePattern<GenericOpTy> {
- using OpRewritePattern<GenericOpTy>::OpRewritePattern;
+public:
+ FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
+ bool foldUnitDimReshapes,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<GenericOpTy>(context, benefit),
+ allowFoldingUnitDimReshapes(foldUnitDimReshapes) {}
LogicalResult matchAndRewrite(GenericOpTy genericOp,
PatternRewriter &rewriter) const override {
@@ -916,8 +921,9 @@ struct FoldWithProducerReshapeOpByExpansion
if (reshapeOp.getSrcType().getRank() <
reshapeOp.getResultType().getRank() ||
!isFusableWithReshapeByDimExpansion(linalgOp, operand.index()) ||
- isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
- reshapeOp.getReassociationMaps()))
+ (!allowFoldingUnitDimReshapes &&
+ isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
+ reshapeOp.getReassociationMaps())))
continue;
Optional<SmallVector<Value, 1>> replacementValues =
@@ -930,6 +936,9 @@ struct FoldWithProducerReshapeOpByExpansion
}
return failure();
}
+
+private:
+ bool allowFoldingUnitDimReshapes;
};
/// Pattern to fold tensor_reshape op with its producer. The corresponding index
@@ -1134,7 +1143,8 @@ struct FusionOfTensorOpsPass
void runOnOperation() override {
Operation *op = getOperation();
RewritePatternSet patterns(op->getContext());
- populateLinalgTensorOpsFusionPatterns(patterns);
+ populateLinalgTensorOpsFusionPatterns(patterns,
+ allowFoldingUnitDimReshapes);
(void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
}
};
@@ -1171,20 +1181,22 @@ void mlir::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
}
void mlir::populateFoldReshapeOpsByExpansionPatterns(
- RewritePatternSet &patterns) {
- patterns.add<FoldReshapeWithGenericOpByExpansion,
- FoldWithProducerReshapeOpByExpansion<GenericOp>,
+ RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes) {
+ patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext());
+ patterns.add<FoldWithProducerReshapeOpByExpansion<GenericOp>,
FoldWithProducerReshapeOpByExpansion<IndexedGenericOp>>(
- patterns.getContext());
+ patterns.getContext(), allowFoldingUnitDimReshapes);
}
-void mlir::populateLinalgTensorOpsFusionPatterns(RewritePatternSet &patterns) {
+void mlir::populateLinalgTensorOpsFusionPatterns(
+ RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes) {
auto *context = patterns.getContext();
patterns
.add<FuseTensorOps<GenericOp>, FuseTensorOps<IndexedGenericOp>,
FoldSplatConstants<GenericOp>, FoldSplatConstants<IndexedGenericOp>>(
context);
- populateFoldReshapeOpsByExpansionPatterns(patterns);
+ populateFoldReshapeOpsByExpansionPatterns(patterns,
+ allowFoldingUnitDimReshapes);
GenericOp::getCanonicalizationPatterns(patterns, context);
IndexedGenericOp::getCanonicalizationPatterns(patterns, context);
TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index fbaf47c..d5dc176 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -linalg-fusion-for-tensor-ops -split-input-file -verify-each=0 | FileCheck %s
+// RUN: mlir-opt %s -linalg-fusion-for-tensor-ops="allow-folding-unit-dim-reshapes=false" -split-input-file -verify-each=0 | FileCheck %s
+// RUN: mlir-opt %s -linalg-fusion-for-tensor-ops="allow-folding-unit-dim-reshapes=true" -split-input-file -verify-each=0 | FileCheck %s --check-prefix=FOLDUNITDIM
#map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
@@ -300,7 +301,7 @@ func @reshape_as_consumer_permutation
%5 = addi %3, %4 : i32
%6 = index_cast %arg2 : index to i32
%7 = addi %5, %6 : i32
- linalg.yield %7 : i32
+ linalg.yield %7 : i32
} -> tensor<6x4x210xi32>
%d = linalg.tensor_reshape %c
[affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
@@ -531,3 +532,11 @@ func @unit_dim_reshape_expansion_full
// CHECK-DAG: linalg.tensor_reshape
// CHECK-DAG: linalg.init_tensor
// CHECK: linalg.generic
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<?x2x4xf32>, tensor<?x2x4xf32>)
+
+// FOLDUNITDIM: func @unit_dim_reshape_expansion_full
+// FOLDUNITDIM: linalg.init_tensor
+// FOLDUNITDIM-COUNT-2: linalg.tensor_reshape
+// FOLDUNITDIM: linalg.generic
+// FOLDUNITDIM-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x?x1x2x1x4xf32>, tensor<1x?x1x2x1x4xf32>)
+