diff options
author | Adam Paszke <apaszke@google.com> | 2023-12-06 11:15:47 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-06 11:15:47 +0100 |
commit | 34df53739af2ce0ffb2625075ee2e613b278969c (patch) | |
tree | 810985ceda5ecd5a1a91140c7a010139f477373d /mlir | |
parent | 6b1aa319754e76366edd88e10034e0539710d946 (diff) | |
download | llvm-34df53739af2ce0ffb2625075ee2e613b278969c.zip llvm-34df53739af2ce0ffb2625075ee2e613b278969c.tar.gz llvm-34df53739af2ce0ffb2625075ee2e613b278969c.tar.bz2 |
Revert "[mlir][Vector] Add fold transpose(shape_cast) -> shape_cast (#73951)" (#74579)
This reverts commit f42b7615b862bb5f77981f619f92877eb20adf54.
The fold pattern is incorrect, because it does not even look at the
permutation of non-unit dims and is happy to replace a pattern such as
```
%22 = vector.shape_cast %21 : vector<1x256x256xf32> to vector<256x256xf32>
%23 = vector.transpose %22, [1, 0] : vector<256x256xf32> to vector<256x256xf32>
```
with
```
%22 = vector.shape_cast %21 : vector<1x256x256xf32> to vector<256x256xf32>
```
which is obviously incorrect.
Diffstat (limited to 'mlir')
-rw-r--r-- | mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 47 | ||||
-rw-r--r-- | mlir/test/Dialect/Vector/canonicalize.mlir | 12 |
2 files changed, 1 insertions, 58 deletions
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index caffd34..c462b23 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5548,57 +5548,12 @@ public: } }; -/// Folds transpose(shape_cast) into a new shape_cast, when the transpose just -/// permutes a unit dim from the result of the shape_cast. -class FoldTransposeShapeCast : public OpRewritePattern<TransposeOp> { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TransposeOp transpOp, - PatternRewriter &rewriter) const override { - Value transposeSrc = transpOp.getVector(); - auto shapeCastOp = transposeSrc.getDefiningOp<vector::ShapeCastOp>(); - if (!shapeCastOp) - return rewriter.notifyMatchFailure( - transpOp, "TransposeOp source is not ShapeCastOp"); - - auto sourceType = transpOp.getSourceVectorType(); - auto resultType = transpOp.getResultVectorType(); - - auto filterUnitDims = [](VectorType type) { - return llvm::make_filter_range( - llvm::zip_equal(type.getShape(), type.getScalableDims()), - [&](auto dim) { - auto [size, isScalable] = dim; - return size != 1 || isScalable; - }); - }; - - auto sourceWithoutUnitDims = filterUnitDims(sourceType); - auto resultWithoutUnitDims = filterUnitDims(resultType); - - // If this transpose just permutes a unit dim, then we can fold it into the - // shape_cast. - for (auto [srcDim, resDim] : - llvm::zip_equal(sourceWithoutUnitDims, resultWithoutUnitDims)) { - if (srcDim != resDim) - return rewriter.notifyMatchFailure(transpOp, - "TransposeOp permutes non-unit dim"); - } - - rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transpOp, resultType, - shapeCastOp.getSource()); - - return success(); - }; -}; - } // namespace void vector::TransposeOp::getCanonicalizationPatterns( RewritePatternSet &results, MLIRContext *context) { results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast, - TransposeFolder, FoldTransposeSplat, FoldTransposeShapeCast>( - context); + TransposeFolder, FoldTransposeSplat>(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 6bfb477e..1021c73 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -67,18 +67,6 @@ func.func @create_mask_transpose_to_transposed_create_mask( // ----- -// CHECK-LABEL: transposed_unit_dim_shape_cast_to_shape_cast -// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32> -func.func @transposed_unit_dim_shape_cast_to_shape_cast(%vec: vector<[4]xf32>) -> vector<1x[4]xf32> { - // CHECK: vector.shape_cast %[[VEC]] : vector<[4]xf32> to vector<1x[4]xf32> - // CHECK-NOT: vector.transpose - %0 = vector.shape_cast %vec : vector<[4]xf32> to vector<[4]x1xf32> - %1 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32> - return %1 : vector<1x[4]xf32> -} - -// ----- - // CHECK-LABEL: extract_from_create_mask // CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index func.func @extract_from_create_mask(%dim0: index, %dim1: index) -> vector<[4]x[4]xi1> { |