diff options
author | Min-Yih Hsu <min.hsu@sifive.com> | 2025-08-08 09:25:32 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-08-08 09:25:32 -0700 |
commit | b4e8b8ee914d2de2c6c33c656fe8c06f5c11a01b (patch) | |
tree | 0a03e34309d4259e796561e3d45d05209d10c2e7 | |
parent | 0419b459be770a428ee4b98c7613bff2be032961 (diff) | |
download | llvm-b4e8b8ee914d2de2c6c33c656fe8c06f5c11a01b.zip llvm-b4e8b8ee914d2de2c6c33c656fe8c06f5c11a01b.tar.gz llvm-b4e8b8ee914d2de2c6c33c656fe8c06f5c11a01b.tar.bz2 |
[mlir][vector] Canonicalize broadcast of shape_cast (#150523)
Fold `broadcast(shape_cast(x))` into `broadcast(x)` if the type of x is
compatible with broadcast's result type and the shape_cast only adds or removes ones in the leading dimensions.
---------
Co-authored-by: Andrzej Warzyński <andrzej.warzynski@gmail.com>
Co-authored-by: James Newling <james.newling@gmail.com>
-rw-r--r-- | mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 38 | ||||
-rw-r--r-- | mlir/test/Dialect/Vector/canonicalize.mlir | 100 |
2 files changed, 138 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index a450056..cb4783d 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2841,9 +2841,47 @@ LogicalResult BroadcastOp::verify() { llvm_unreachable("unexpected vector.broadcast op error"); } +// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible +// with broadcast's result type and shape_cast only adds or removes ones in the +// leading dimensions. +static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) { + auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>(); + if (!srcShapeCast) + return failure(); + + VectorType srcType = srcShapeCast.getSourceVectorType(); + VectorType destType = broadcastOp.getResultVectorType(); + // Check type compatibility. + if (vector::isBroadcastableTo(srcType, destType) != + BroadcastableToResult::Success) + return failure(); + + ArrayRef<int64_t> srcShape = srcType.getShape(); + ArrayRef<int64_t> shapecastShape = + srcShapeCast.getResultVectorType().getShape(); + // Trailing dimensions should be the same if shape_cast only alters the + // leading dimensions. 
+ unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size()); + if (!llvm::equal(srcShape.take_back(numTrailingDims), + shapecastShape.take_back(numTrailingDims))) + return failure(); + + assert(all_of(srcShape.drop_back(numTrailingDims), + [](int64_t E) { return E == 1; }) && + all_of(shapecastShape.drop_back(numTrailingDims), + [](int64_t E) { return E == 1; }) && + "ill-formed shape_cast"); + + broadcastOp.getSourceMutable().assign(srcShapeCast.getSource()); + return success(); +} + OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) { if (getSourceType() == getResultVectorType()) return getSource(); + if (succeeded(foldBroadcastOfShapeCast(*this))) + return getResult(); + if (!adaptor.getSource()) return {}; auto vectorType = getResultVectorType(); diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index f86fb38..4a7176e 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1168,6 +1168,106 @@ func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>) // ----- +// CHECK-LABEL: func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim +// CHECK-NOT: vector.shape_cast +// CHECK: vector.broadcast {{.+}} : vector<2xf32> to vector<32x2xf32> +func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim(%arg0 : vector<2xf32>) -> vector<32x2xf32> { + %0 = vector.shape_cast %arg0 : vector<2xf32> to vector<1x2xf32> + %1 = vector.broadcast %0 : vector<1x2xf32> to vector<32x2xf32> + return %1 : vector<32x2xf32> +} + +// ----- + +// CHECK-LABEL: func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim2( +// CHECK-SAME: %[[ARG0:.*]]: vector<2x1xf32>) -> vector<32x2x1xf32> { +// CHECK: %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<2x1xf32> to vector<32x2x1xf32> +// CHECK: return %[[VAL_0]] : vector<32x2x1xf32> +// CHECK: } +func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim2(%arg0 : 
vector<2x1xf32>) -> vector<32x2x1xf32> { + %0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<1x2x1xf32> + %1 = vector.broadcast %0 : vector<1x2x1xf32> to vector<32x2x1xf32> + return %1 : vector<32x2x1xf32> +} + +// ----- + +// CHECK-LABEL: func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim3( +// CHECK-SAME: %[[ARG0:.*]]: vector<2x1xf32>) -> vector<32x2x4xf32> { +// CHECK: %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<2x1xf32> to vector<32x2x4xf32> +// CHECK: return %[[VAL_0]] : vector<32x2x4xf32> +// CHECK: } +func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim3(%arg0 : vector<2x1xf32>) -> vector<32x2x4xf32> { + %0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<1x2x1xf32> + %1 = vector.broadcast %0 : vector<1x2x1xf32> to vector<32x2x4xf32> + return %1 : vector<32x2x4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @canonicalize_shapecast_broadcast_to_broadcast_remove_leading_dim( +// CHECK-SAME: %[[ARG0:.*]]: vector<1x2xf32>) -> vector<32x2xf32> { +// CHECK: %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<1x2xf32> to vector<32x2xf32> +// CHECK: return %[[VAL_0]] : vector<32x2xf32> +// CHECK: } +func.func @canonicalize_shapecast_broadcast_to_broadcast_remove_leading_dim(%arg0 : vector<1x2xf32>) -> vector<32x2xf32> { + %0 = vector.shape_cast %arg0 : vector<1x2xf32> to vector<2xf32> + %1 = vector.broadcast %0 : vector<2xf32> to vector<32x2xf32> + return %1 : vector<32x2xf32> +} + +// ----- + +// CHECK-LABEL: func @negative_canonicalize_shapecast_broadcast_invalid_shape +// CHECK: vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32> +// CHECK: vector.broadcast {{.+}} : vector<4x16xf32> to vector<2x4x16xf32> +func.func @negative_canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vector<64xf32>) -> vector<2x4x16xf32> { + %0 = vector.shape_cast %arg0 : vector<64xf32> to vector<4x16xf32> + %1 = vector.broadcast %0 : vector<4x16xf32> to vector<2x4x16xf32> + return %1 : vector<2x4x16xf32> +} + +// ----- 
+ +// CHECK-LABEL: func @negative_canonicalize_shapecast_broadcast_invalid_broadcasted_dims +// CHECK: vector.shape_cast {{.+}} : vector<2x1xf32> to vector<1x2xf32> +// CHECK: vector.broadcast {{.+}} : vector<1x2xf32> to vector<2x2xf32> +func.func @negative_canonicalize_shapecast_broadcast_invalid_broadcasted_dims(%arg0 : vector<2x1xf32>) -> vector<2x2xf32> { + %0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<1x2xf32> + %1 = vector.broadcast %0 : vector<1x2xf32> to vector<2x2xf32> + return %1 : vector<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_append_dim( +// CHECK-SAME: %[[ARG0:.*]]: vector<2xf32>) -> vector<2x4xf32> { +// CHECK: %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<2xf32> to vector<2x1xf32> +// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2x1xf32> to vector<2x4xf32> +// CHECK: return %[[VAL_1]] : vector<2x4xf32> +// CHECK: } +func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_append_dim(%arg0 : vector<2xf32>) -> vector<2x4xf32> { + %0 = vector.shape_cast %arg0 : vector<2xf32> to vector<2x1xf32> + %1 = vector.broadcast %0 : vector<2x1xf32> to vector<2x4xf32> + return %1 : vector<2x4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_remove_trailing_dim( +// CHECK-SAME: %[[ARG0:.*]]: vector<2x1xf32>) -> vector<32x2xf32> { +// CHECK: %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<2x1xf32> to vector<2xf32> +// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2xf32> to vector<32x2xf32> +// CHECK: return %[[VAL_1]] : vector<32x2xf32> +// CHECK: } +func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_remove_trailing_dim(%arg0 : vector<2x1xf32>) -> vector<32x2xf32> { + %0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<2xf32> + %1 = vector.broadcast %0 : vector<2xf32> to vector<32x2xf32> + return %1 : vector<32x2xf32> +} + +// ----- + // CHECK-LABEL: 
fold_vector_transfer_masks func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) { // CHECK: %[[C0:.+]] = arith.constant 0 : index |