aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMin-Yih Hsu <min.hsu@sifive.com>2025-08-08 09:25:32 -0700
committerGitHub <noreply@github.com>2025-08-08 09:25:32 -0700
commitb4e8b8ee914d2de2c6c33c656fe8c06f5c11a01b (patch)
tree0a03e34309d4259e796561e3d45d05209d10c2e7
parent0419b459be770a428ee4b98c7613bff2be032961 (diff)
downloadllvm-b4e8b8ee914d2de2c6c33c656fe8c06f5c11a01b.zip
llvm-b4e8b8ee914d2de2c6c33c656fe8c06f5c11a01b.tar.gz
llvm-b4e8b8ee914d2de2c6c33c656fe8c06f5c11a01b.tar.bz2
[mlir][vector] Canonicalize broadcast of shape_cast (#150523)
Fold `broadcast(shape_cast(x))` into `broadcast(x)` if the type of x is compatible with broadcast's result type and the shape_cast only adds or removes ones in the leading dimensions. --------- Co-authored-by: Andrzej WarzyƄski <andrzej.warzynski@gmail.com> Co-authored-by: James Newling <james.newling@gmail.com>
-rw-r--r--mlir/lib/Dialect/Vector/IR/VectorOps.cpp38
-rw-r--r--mlir/test/Dialect/Vector/canonicalize.mlir100
2 files changed, 138 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a450056..cb4783d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2841,9 +2841,47 @@ LogicalResult BroadcastOp::verify() {
llvm_unreachable("unexpected vector.broadcast op error");
}
+// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
+// with broadcast's result type and shape_cast only adds or removes ones in the
+// leading dimensions.
+static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
+ auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
+ if (!srcShapeCast)
+ return failure();
+
+ VectorType srcType = srcShapeCast.getSourceVectorType();
+ VectorType destType = broadcastOp.getResultVectorType();
+ // Check type compatibility.
+ if (vector::isBroadcastableTo(srcType, destType) !=
+ BroadcastableToResult::Success)
+ return failure();
+
+ ArrayRef<int64_t> srcShape = srcType.getShape();
+ ArrayRef<int64_t> shapecastShape =
+ srcShapeCast.getResultVectorType().getShape();
+ // Trailing dimensions should be the same if shape_cast only alters the
+ // leading dimensions.
+ unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size());
+ if (!llvm::equal(srcShape.take_back(numTrailingDims),
+ shapecastShape.take_back(numTrailingDims)))
+ return failure();
+
+ assert(all_of(srcShape.drop_back(numTrailingDims),
+ [](int64_t E) { return E == 1; }) &&
+ all_of(shapecastShape.drop_back(numTrailingDims),
+ [](int64_t E) { return E == 1; }) &&
+ "ill-formed shape_cast");
+
+ broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
+ return success();
+}
+
OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
if (getSourceType() == getResultVectorType())
return getSource();
+ if (succeeded(foldBroadcastOfShapeCast(*this)))
+ return getResult();
+
if (!adaptor.getSource())
return {};
auto vectorType = getResultVectorType();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index f86fb38..4a7176e 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1168,6 +1168,106 @@ func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>)
// -----
+// CHECK-LABEL: func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim
+// CHECK-NOT: vector.shape_cast
+// CHECK: vector.broadcast {{.+}} : vector<2xf32> to vector<32x2xf32>
+func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim(%arg0 : vector<2xf32>) -> vector<32x2xf32> {
+ %0 = vector.shape_cast %arg0 : vector<2xf32> to vector<1x2xf32>
+ %1 = vector.broadcast %0 : vector<1x2xf32> to vector<32x2xf32>
+ return %1 : vector<32x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim2(
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x1xf32>) -> vector<32x2x1xf32> {
+// CHECK: %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<2x1xf32> to vector<32x2x1xf32>
+// CHECK: return %[[VAL_0]] : vector<32x2x1xf32>
+// CHECK: }
+func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim2(%arg0 : vector<2x1xf32>) -> vector<32x2x1xf32> {
+ %0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<1x2x1xf32>
+ %1 = vector.broadcast %0 : vector<1x2x1xf32> to vector<32x2x1xf32>
+ return %1 : vector<32x2x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim3(
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x1xf32>) -> vector<32x2x4xf32> {
+// CHECK: %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<2x1xf32> to vector<32x2x4xf32>
+// CHECK: return %[[VAL_0]] : vector<32x2x4xf32>
+// CHECK: }
+func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim3(%arg0 : vector<2x1xf32>) -> vector<32x2x4xf32> {
+ %0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<1x2x1xf32>
+ %1 = vector.broadcast %0 : vector<1x2x1xf32> to vector<32x2x4xf32>
+ return %1 : vector<32x2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @canonicalize_shapecast_broadcast_to_broadcast_remove_leading_dim(
+// CHECK-SAME: %[[ARG0:.*]]: vector<1x2xf32>) -> vector<32x2xf32> {
+// CHECK: %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<1x2xf32> to vector<32x2xf32>
+// CHECK: return %[[VAL_0]] : vector<32x2xf32>
+// CHECK: }
+func.func @canonicalize_shapecast_broadcast_to_broadcast_remove_leading_dim(%arg0 : vector<1x2xf32>) -> vector<32x2xf32> {
+ %0 = vector.shape_cast %arg0 : vector<1x2xf32> to vector<2xf32>
+ %1 = vector.broadcast %0 : vector<2xf32> to vector<32x2xf32>
+ return %1 : vector<32x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_canonicalize_shapecast_broadcast_invalid_shape
+// CHECK: vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32>
+// CHECK: vector.broadcast {{.+}} : vector<4x16xf32> to vector<2x4x16xf32>
+func.func @negative_canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vector<64xf32>) -> vector<2x4x16xf32> {
+ %0 = vector.shape_cast %arg0 : vector<64xf32> to vector<4x16xf32>
+ %1 = vector.broadcast %0 : vector<4x16xf32> to vector<2x4x16xf32>
+ return %1 : vector<2x4x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_canonicalize_shapecast_broadcast_invalid_broadcasted_dims
+// CHECK: vector.shape_cast {{.+}} : vector<2x1xf32> to vector<1x2xf32>
+// CHECK: vector.broadcast {{.+}} : vector<1x2xf32> to vector<2x2xf32>
+func.func @negative_canonicalize_shapecast_broadcast_invalid_broadcasted_dims(%arg0 : vector<2x1xf32>) -> vector<2x2xf32> {
+ %0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<1x2xf32>
+ %1 = vector.broadcast %0 : vector<1x2xf32> to vector<2x2xf32>
+ return %1 : vector<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_append_dim(
+// CHECK-SAME: %[[ARG0:.*]]: vector<2xf32>) -> vector<2x4xf32> {
+// CHECK: %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<2xf32> to vector<2x1xf32>
+// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2x1xf32> to vector<2x4xf32>
+// CHECK: return %[[VAL_1]] : vector<2x4xf32>
+// CHECK: }
+func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_append_dim(%arg0 : vector<2xf32>) -> vector<2x4xf32> {
+ %0 = vector.shape_cast %arg0 : vector<2xf32> to vector<2x1xf32>
+ %1 = vector.broadcast %0 : vector<2x1xf32> to vector<2x4xf32>
+ return %1 : vector<2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_remove_trailing_dim(
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x1xf32>) -> vector<32x2xf32> {
+// CHECK: %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<2x1xf32> to vector<2xf32>
+// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2xf32> to vector<32x2xf32>
+// CHECK: return %[[VAL_1]] : vector<32x2xf32>
+// CHECK: }
+func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_remove_trailing_dim(%arg0 : vector<2x1xf32>) -> vector<32x2xf32> {
+ %0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<2xf32>
+ %1 = vector.broadcast %0 : vector<2xf32> to vector<32x2xf32>
+ return %1 : vector<32x2xf32>
+}
+
+// -----
+
// CHECK-LABEL: fold_vector_transfer_masks
func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index