diff options
author | James Newling <james.newling@gmail.com> | 2025-04-11 07:13:03 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-04-11 15:13:03 +0100 |
commit | cd85f5dbdf135347a9912dde148ec9fd325ba8c1 (patch) | |
tree | 6ffe2d888a1f4b5971fb131334c5ae780b17f013 | |
parent | a9225251c43e84d2bef0528ac61e1abf9fadae50 (diff) | |
download | llvm-cd85f5dbdf135347a9912dde148ec9fd325ba8c1.zip llvm-cd85f5dbdf135347a9912dde148ec9fd325ba8c1.tar.gz llvm-cd85f5dbdf135347a9912dde148ec9fd325ba8c1.tar.bz2 |
[mlir] canonicalizer: shape_cast(poison) -> poison (#133988)
Based on the ShapeCastConstantFolder, this pattern replaces
%0 = ub.poison : vector<2x3xf32>
%1 = vector.shape_cast %0 vector<2x3xf32> to vector<6xf32>
with
%1 = ub.poison : vector<6xf32>
---------
Signed-off-by: James Newling <james.newling@gmail.com>
-rw-r--r-- | mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 57 | ||||
-rw-r--r-- | mlir/test/Dialect/Vector/canonicalize.mlir | 14 |
2 files changed, 39 insertions, 32 deletions
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 98d98f0..59f3b78 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -42,6 +42,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringSet.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" #include <cassert> #include <cstdint> @@ -5611,18 +5612,20 @@ LogicalResult ShapeCastOp::verify() { } OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) { + // No-op shape cast. - if (getSource().getType() == getResult().getType()) + if (getSource().getType() == getType()) return getSource(); + VectorType resultType = getType(); + // Canceling shape casts. if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) { - if (getResult().getType() == otherOp.getSource().getType()) - return otherOp.getSource(); - // Only allows valid transitive folding. - VectorType srcType = llvm::cast<VectorType>(otherOp.getSource().getType()); - VectorType resultType = llvm::cast<VectorType>(getResult().getType()); + // Only allows valid transitive folding (expand/collapse dimensions). + VectorType srcType = otherOp.getSource().getType(); + if (resultType == srcType) + return otherOp.getSource(); if (srcType.getRank() < resultType.getRank()) { if (!isValidShapeCast(srcType.getShape(), resultType.getShape())) return {}; @@ -5632,43 +5635,32 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) { } else { return {}; } - setOperand(otherOp.getSource()); return getResult(); } // Cancelling broadcast and shape cast ops. if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) { - if (bcastOp.getSourceType() == getType()) + if (bcastOp.getSourceType() == resultType) return bcastOp.getSource(); } + // shape_cast(constant) -> constant + if (auto splatAttr = + llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) { + return DenseElementsAttr::get(resultType, + splatAttr.getSplatValue<Attribute>()); + } + + // shape_cast(poison) -> poison + if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) { + return ub::PoisonAttr::get(getContext()); + } + return {}; } namespace { -// Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp. -class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp, - PatternRewriter &rewriter) const override { - auto constantOp = - shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>(); - if (!constantOp) - return failure(); - // Only handle splat for now. - auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue()); - if (!dense) - return failure(); - auto newAttr = - DenseElementsAttr::get(llvm::cast<VectorType>(shapeCastOp.getType()), - dense.getSplatValue<Attribute>()); - rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr); - return success(); - } -}; /// Helper function that computes a new vector type based on the input vector /// type by removing the trailing one dims: @@ -5828,8 +5820,9 @@ public: void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim, - ShapeCastBroadcastFolder>(context); + results + .add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>( + context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index b7db8ec..72064fb 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1167,6 +1167,20 @@ func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) { // ----- +// CHECK-LABEL: shape_cast_poison +// CHECK-DAG: %[[CST1:.*]] = ub.poison : vector<3x4x2xi32> +// CHECK-DAG: %[[CST0:.*]] = ub.poison : vector<20x2xf32> +// CHECK: return %[[CST0]], %[[CST1]] : vector<20x2xf32>, vector<3x4x2xi32> +func.func @shape_cast_poison() -> (vector<20x2xf32>, vector<3x4x2xi32>) { + %poison = ub.poison : vector<5x4x2xf32> + %poison_1 = ub.poison : vector<12x2xi32> + %0 = vector.shape_cast %poison : vector<5x4x2xf32> to vector<20x2xf32> + %1 = vector.shape_cast %poison_1 : vector<12x2xi32> to vector<3x4x2xi32> + return %0, %1 : vector<20x2xf32>, vector<3x4x2xi32> +} + +// ----- + // CHECK-LABEL: extract_strided_constant // CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1> : vector<2x13x3xi32> // CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<12x2xf32> |