aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp')
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp55
1 files changed, 29 insertions, 26 deletions
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index ea4a02f..2005179 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1612,7 +1612,27 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
}
};
-/// For vectors with either leading or trailing unit dim, replaces:
+// Scalable unit dimensions are not supported. Folding such dimensions would
+// require "shifting" the scalable flag onto some other fixed-width dim (e.g.
+// vector<[1]x4xf32> -> vector<[4]xf32>). This could be implemented in the
+// future.
+static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
+ auto inVecShape = inVecTy.getShape();
+ SmallVector<int64_t> newShape;
+ SmallVector<bool> newScalableDims;
+ for (auto [dim, isScalable] :
+ llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
+ if (dim == 1 && !isScalable)
+ continue;
+
+ newShape.push_back(dim);
+ newScalableDims.push_back(isScalable);
+ }
+
+ return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
+}
+
+/// For vectors with at least an unit dim, replaces:
/// elementwise(a, b)
/// with:
/// sc_a = shape_cast(a)
@@ -1624,20 +1644,16 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
/// required to be rank > 1.
///
/// Ex:
-/// ```
/// %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
/// %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
-/// ```
///
/// gets converted to:
///
-/// ```
/// %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
/// %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
/// %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
/// %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
/// %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32>
-/// ```
///
/// Patterns for folding shape_casts should instantly eliminate `%cast_new` and
/// `%cast`.
@@ -1657,42 +1673,29 @@ struct DropUnitDimFromElementwiseOps final
// guaranteed to have identical shapes (with some exceptions such as
// `arith.select`) and it suffices to only check one of them.
auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
- if (!sourceVectorType)
- return failure();
- if (sourceVectorType.getRank() < 2)
- return failure();
-
- bool hasTrailingDimUnitFixed =
- ((sourceVectorType.getShape().back() == 1) &&
- (!sourceVectorType.getScalableDims().back()));
- bool hasLeadingDimUnitFixed =
- ((sourceVectorType.getShape().front() == 1) &&
- (!sourceVectorType.getScalableDims().front()));
- if (!hasLeadingDimUnitFixed && !hasTrailingDimUnitFixed)
+ if (!sourceVectorType || sourceVectorType.getRank() < 2)
return failure();
- // Drop leading/trailing unit dim by applying vector.shape_cast to all
- // operands
- int64_t dim = hasLeadingDimUnitFixed ? 0 : sourceVectorType.getRank() - 1;
-
SmallVector<Value> newOperands;
auto loc = op->getLoc();
for (auto operand : op->getOperands()) {
auto opVectorType = cast<VectorType>(operand.getType());
- VectorType newVType = VectorType::Builder(opVectorType).dropDim(dim);
+ auto newVType = dropNonScalableUnitDimFromType(opVectorType);
+ if (newVType == opVectorType)
+ return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");
+
auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
newOperands.push_back(opSC);
}
VectorType newResultVectorType =
- VectorType::Builder(resultVectorType).dropDim(dim);
- // Create an updated elementwise Op without leading/trailing unit dim
+ dropNonScalableUnitDimFromType(resultVectorType);
+ // Create an updated elementwise Op without unit dim.
Operation *elementwiseOp =
rewriter.create(loc, op->getName().getIdentifier(), newOperands,
newResultVectorType, op->getAttrs());
- // Restore the leading/trailing unit dim by applying vector.shape_cast
- // to the result
+ // Restore the unit dim by applying vector.shape_cast to the result.
rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
elementwiseOp->getResult(0));