author    Hugo Trachino <hugo.trachino@huawei.com>    2024-06-20 10:43:23 +0100
committer GitHub <noreply@github.com>    2024-06-20 10:43:23 +0100
commit    2c06fb899966b49ff0fe4adf55fceb7d1941fbca (patch)
tree      3eaebcd131b0263992383472a7d566d28c576916
parent    1002c08c646d8c85fb63a54140a00c642f317b28 (diff)
[MLIR][Vector] Generalize DropUnitDimFromElementwiseOps to non leading / trailing dimensions. (#92934)
Generalizes `DropUnitDimFromElementwiseOps` to support inner unit dimensions. This change stems from improving the lowering of contractionOps for Arm SME, where we end up with inner unit dimensions on MulOp, BroadcastOp and TransposeOp, preventing the generation of outerproducts. Discussed [here](https://discourse.llvm.org/t/on-improving-arm-sme-lowering-resilience-in-mlir/78543/17?u=nujaa). --------- Co-authored-by: Benjamin Maxwell <macdue@dueutil.tech>
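As a rough illustration (a sketch, not taken verbatim from the patch; the f32 element type and the value names %a/%b are made up here), the pattern now rewrites an elementwise op whose operands carry an inner unit dim by shape_casting the operands, applying the op on the reduced type, and shape_casting the result back:

```mlir
// Before: an elementwise op whose operands carry an inner unit dim.
//   %mul = arith.mulf %a, %b : vector<8x1x3xf32>
//
// After DropUnitDimFromElementwiseOps (the final shape_cast pair is expected
// to be cleaned up by the existing shape_cast folding patterns):
%a_sc   = vector.shape_cast %a : vector<8x1x3xf32> to vector<8x3xf32>
%b_sc   = vector.shape_cast %b : vector<8x1x3xf32> to vector<8x3xf32>
%mul_sc = arith.mulf %a_sc, %b_sc : vector<8x3xf32>
%mul    = vector.shape_cast %mul_sc : vector<8x3xf32> to vector<8x1x3xf32>
```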
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp  55
-rw-r--r--  mlir/test/Dialect/Vector/vector-transfer-flatten.mlir  36
2 files changed, 65 insertions, 26 deletions
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index ea4a02f..2005179 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1612,7 +1612,27 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
}
};
-/// For vectors with either leading or trailing unit dim, replaces:
+// Scalable unit dimensions are not supported. Folding such dimensions would
+// require "shifting" the scalable flag onto some other fixed-width dim (e.g.
+// vector<[1]x4xf32> -> vector<[4]xf32>). This could be implemented in the
+// future.
+static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
+ auto inVecShape = inVecTy.getShape();
+ SmallVector<int64_t> newShape;
+ SmallVector<bool> newScalableDims;
+ for (auto [dim, isScalable] :
+ llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
+ if (dim == 1 && !isScalable)
+ continue;
+
+ newShape.push_back(dim);
+ newScalableDims.push_back(isScalable);
+ }
+
+ return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
+}
+
+/// For vectors with at least one unit dim, replaces:
/// elementwise(a, b)
/// with:
/// sc_a = shape_cast(a)
@@ -1624,20 +1644,16 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
/// required to be rank > 1.
///
/// Ex:
-/// ```
/// %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
/// %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
-/// ```
///
/// gets converted to:
///
-/// ```
/// %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
/// %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
/// %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
/// %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
/// %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32>
-/// ```
///
/// Patterns for folding shape_casts should instantly eliminate `%cast_new` and
/// `%cast`.
@@ -1657,42 +1673,29 @@ struct DropUnitDimFromElementwiseOps final
// guaranteed to have identical shapes (with some exceptions such as
// `arith.select`) and it suffices to only check one of them.
auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
- if (!sourceVectorType)
- return failure();
- if (sourceVectorType.getRank() < 2)
- return failure();
-
- bool hasTrailingDimUnitFixed =
- ((sourceVectorType.getShape().back() == 1) &&
- (!sourceVectorType.getScalableDims().back()));
- bool hasLeadingDimUnitFixed =
- ((sourceVectorType.getShape().front() == 1) &&
- (!sourceVectorType.getScalableDims().front()));
- if (!hasLeadingDimUnitFixed && !hasTrailingDimUnitFixed)
+ if (!sourceVectorType || sourceVectorType.getRank() < 2)
return failure();
- // Drop leading/trailing unit dim by applying vector.shape_cast to all
- // operands
- int64_t dim = hasLeadingDimUnitFixed ? 0 : sourceVectorType.getRank() - 1;
-
SmallVector<Value> newOperands;
auto loc = op->getLoc();
for (auto operand : op->getOperands()) {
auto opVectorType = cast<VectorType>(operand.getType());
- VectorType newVType = VectorType::Builder(opVectorType).dropDim(dim);
+ auto newVType = dropNonScalableUnitDimFromType(opVectorType);
+ if (newVType == opVectorType)
+ return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");
+
auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
newOperands.push_back(opSC);
}
VectorType newResultVectorType =
- VectorType::Builder(resultVectorType).dropDim(dim);
- // Create an updated elementwise Op without leading/trailing unit dim
+ dropNonScalableUnitDimFromType(resultVectorType);
+ // Create an updated elementwise Op without unit dim.
Operation *elementwiseOp =
rewriter.create(loc, op->getName().getIdentifier(), newOperands,
newResultVectorType, op->getAttrs());
- // Restore the leading/trailing unit dim by applying vector.shape_cast
- // to the result
+ // Restore the unit dim by applying vector.shape_cast to the result.
rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
elementwiseOp->getResult(0));
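For the scalable case, `dropNonScalableUnitDimFromType` only drops fixed-size unit dims; a scalable `[1]` dim is left in place, as exercised by the new test below. A minimal sketch, assuming an f32 element type and illustrative value names not taken from the patch:

```mlir
// Fixed unit dim dropped, scalable [1] dim preserved:
//   %add = arith.addf %x, %y : vector<1x[1]x4xf32>
// becomes, roughly:
%x_sc   = vector.shape_cast %x : vector<1x[1]x4xf32> to vector<[1]x4xf32>
%y_sc   = vector.shape_cast %y : vector<1x[1]x4xf32> to vector<[1]x4xf32>
%add_sc = arith.addf %x_sc, %y_sc : vector<[1]x4xf32>
%add    = vector.shape_cast %add_sc : vector<[1]x4xf32> to vector<1x[1]x4xf32>
```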
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index d7365d2..42bf720 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -499,6 +499,42 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
// -----
+func.func @fold_inner_unit_dim(%arg0 : vector<8x1x3xf128>,
+ %arg1 : vector<1x8x3xf128>) -> vector<8x3xf128> {
+ %sc_arg1 = vector.shape_cast %arg1 : vector<1x8x3xf128> to vector<8x1x3xf128>
+ %mul = arith.mulf %arg0, %sc_arg1 : vector<8x1x3xf128>
+ %res = vector.shape_cast %mul : vector<8x1x3xf128> to vector<8x3xf128>
+ return %res : vector<8x3xf128>
+}
+
+// CHECK-LABEL: func.func @fold_inner_unit_dim(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1x3xf128>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x3xf128>) -> vector<8x3xf128> {
+// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x3xf128> to vector<8x3xf128>
+// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x3xf128> to vector<8x3xf128>
+// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x3xf128>
+// CHECK: return %[[VAL_4]] : vector<8x3xf128>
+
+// -----
+
+func.func @fold_inner_unit_dim_scalable(%arg0 : vector<8x1x[1]x3xf128>,
+ %arg1 : vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> {
+ %sc_arg1 = vector.shape_cast %arg1 : vector<1x8x[1]x3xf128> to vector<8x1x[1]x3xf128>
+ %mul = arith.mulf %arg0, %sc_arg1 : vector<8x1x[1]x3xf128>
+ %res = vector.shape_cast %mul : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128>
+ return %res : vector<8x[1]x3xf128>
+}
+
+// CHECK-LABEL: func.func @fold_inner_unit_dim_scalable(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1x[1]x3xf128>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> {
+// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128>
+// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x[1]x3xf128> to vector<8x[1]x3xf128>
+// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[1]x3xf128>
+// CHECK: return %[[VAL_4]] : vector<8x[1]x3xf128>
+
+// -----
+
func.func @negative_out_of_bound_transfer_read(
%arg : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
%c0 = arith.constant 0 : index