aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorAdam Paszke <apaszke@google.com>2023-12-06 11:15:47 +0100
committerGitHub <noreply@github.com>2023-12-06 11:15:47 +0100
commit34df53739af2ce0ffb2625075ee2e613b278969c (patch)
tree810985ceda5ecd5a1a91140c7a010139f477373d /mlir
parent6b1aa319754e76366edd88e10034e0539710d946 (diff)
downloadllvm-34df53739af2ce0ffb2625075ee2e613b278969c.zip
llvm-34df53739af2ce0ffb2625075ee2e613b278969c.tar.gz
llvm-34df53739af2ce0ffb2625075ee2e613b278969c.tar.bz2
Revert "[mlir][Vector] Add fold transpose(shape_cast) -> shape_cast (#73951)" (#74579)
This reverts commit f42b7615b862bb5f77981f619f92877eb20adf54. The fold pattern is incorrect, because it does not even look at the permutation of non-unit dims and is happy to replace a pattern such as ``` %22 = vector.shape_cast %21 : vector<1x256x256xf32> to vector<256x256xf32> %23 = vector.transpose %22, [1, 0] : vector<256x256xf32> to vector<256x256xf32> ``` with ``` %22 = vector.shape_cast %21 : vector<1x256x256xf32> to vector<256x256xf32> ``` which is obviously incorrect.
Diffstat (limited to 'mlir')
-rw-r--r--mlir/lib/Dialect/Vector/IR/VectorOps.cpp47
-rw-r--r--mlir/test/Dialect/Vector/canonicalize.mlir12
2 files changed, 1 insertions, 58 deletions
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index caffd34..c462b23 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5548,57 +5548,12 @@ public:
}
};
-/// Folds transpose(shape_cast) into a new shape_cast, when the transpose just
-/// permutes a unit dim from the result of the shape_cast.
-class FoldTransposeShapeCast : public OpRewritePattern<TransposeOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(TransposeOp transpOp,
- PatternRewriter &rewriter) const override {
- Value transposeSrc = transpOp.getVector();
- auto shapeCastOp = transposeSrc.getDefiningOp<vector::ShapeCastOp>();
- if (!shapeCastOp)
- return rewriter.notifyMatchFailure(
- transpOp, "TransposeOp source is not ShapeCastOp");
-
- auto sourceType = transpOp.getSourceVectorType();
- auto resultType = transpOp.getResultVectorType();
-
- auto filterUnitDims = [](VectorType type) {
- return llvm::make_filter_range(
- llvm::zip_equal(type.getShape(), type.getScalableDims()),
- [&](auto dim) {
- auto [size, isScalable] = dim;
- return size != 1 || isScalable;
- });
- };
-
- auto sourceWithoutUnitDims = filterUnitDims(sourceType);
- auto resultWithoutUnitDims = filterUnitDims(resultType);
-
- // If this transpose just permutes a unit dim, then we can fold it into the
- // shape_cast.
- for (auto [srcDim, resDim] :
- llvm::zip_equal(sourceWithoutUnitDims, resultWithoutUnitDims)) {
- if (srcDim != resDim)
- return rewriter.notifyMatchFailure(transpOp,
- "TransposeOp permutes non-unit dim");
- }
-
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transpOp, resultType,
- shapeCastOp.getSource());
-
- return success();
- };
-};
-
} // namespace
void vector::TransposeOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
- TransposeFolder, FoldTransposeSplat, FoldTransposeShapeCast>(
- context);
+ TransposeFolder, FoldTransposeSplat>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 6bfb477e..1021c73 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -67,18 +67,6 @@ func.func @create_mask_transpose_to_transposed_create_mask(
// -----
-// CHECK-LABEL: transposed_unit_dim_shape_cast_to_shape_cast
-// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
-func.func @transposed_unit_dim_shape_cast_to_shape_cast(%vec: vector<[4]xf32>) -> vector<1x[4]xf32> {
- // CHECK: vector.shape_cast %[[VEC]] : vector<[4]xf32> to vector<1x[4]xf32>
- // CHECK-NOT: vector.transpose
- %0 = vector.shape_cast %vec : vector<[4]xf32> to vector<[4]x1xf32>
- %1 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
- return %1 : vector<1x[4]xf32>
-}
-
-// -----
-
// CHECK-LABEL: extract_from_create_mask
// CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
func.func @extract_from_create_mask(%dim0: index, %dim1: index) -> vector<[4]x[4]xi1> {