aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorHugo Trachino <hugo.trachino@huawei.com>2024-06-21 13:34:37 +0100
committerGitHub <noreply@github.com>2024-06-21 13:34:37 +0100
commit9f0aa05bfb40c077a5b1c2ea8cac88fdd51f0c5c (patch)
tree8f03454721d4410b2d68da78711bf7721f5c66c8 /mlir
parent138ea7d1fb82c2525da8dcc2f8ea73eae7b25f25 (diff)
downloadllvm-9f0aa05bfb40c077a5b1c2ea8cac88fdd51f0c5c.zip
llvm-9f0aa05bfb40c077a5b1c2ea8cac88fdd51f0c5c.tar.gz
llvm-9f0aa05bfb40c077a5b1c2ea8cac88fdd51f0c5c.tar.bz2
[mlir][vector] Add ElementwiseToOuterproduct (#93664)
1D multi-reduction are lowered to arith which can prevent some optimisations. I propose `ElementwiseToOuterproduct` matching a series of ops to generate `vector.outerproduct`. As part of some `ElementwiseToVectorOpsPatterns`, it could allow to fuse other elementwiseOps to vector dialect. Originally discussed https://discourse.llvm.org/t/on-improving-arm-sme-lowering-resilience-in-mlir/78543/24. quote @MacDue ``` %lhsBcast = vector.broadcast %lhsCast : vector<[4]xf32> to vector<[4]x[4]xf32> %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32> %rhsBcast = vector.broadcast %rhs : vector<[4]xf32> to vector<[4]x[4]xf32> %mul = arith.mulf %lhsT, %rhsBcast : vector<[4]x[4]xf32> ``` Can be rewritten as: ``` %mul = vector.outerproduct $lhs, $rhs : vector<[4]xf32>, vector<[4]xf32> ``` --------- Co-authored-by: Han-Chung Wang <hanhan0912@gmail.com>
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/Dialect/Vector/IR/VectorOps.h4
-rw-r--r--mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td11
-rw-r--r--mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp5
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp85
-rw-r--r--mlir/test/Dialect/Vector/transform-vector.mlir38
5 files changed, 143 insertions, 0 deletions
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 4603953..ac55433 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -80,6 +80,10 @@ void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns,
/// into vector contract for the backends with native support.
void populateFoldArithExtensionPatterns(RewritePatternSet &patterns);
+/// Collect a set of patterns that fold elementwise op on vectors to the vector
+/// dialect.
+void populateElementwiseToVectorOpsPatterns(RewritePatternSet &patterns);
+
/// Returns the integer type required for subscripts in the vector dialect.
IntegerType getVectorSubscriptType(Builder &builder);
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index c91e8fb..820a187 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -406,6 +406,17 @@ def ApplyFoldArithExtensionPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+def ApplyFoldElementwiseToVectorPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.vector.elementwise_to_vector",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Collect a set of patterns that fold elementwise op on vectors to the vector
+ dialect.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
def ApplyVectorReductionToContractPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.reduction_to_contract",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 2396026..2e9aa88 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -59,6 +59,11 @@ void transform::ApplyFoldArithExtensionPatternsOp::populatePatterns(
vector::populateFoldArithExtensionPatterns(patterns);
}
+void transform::ApplyFoldElementwiseToVectorPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ vector::populateElementwiseToVectorOpsPatterns(patterns);
+}
+
void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::populateVectorReductionToContractPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index b824508..eac6db5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1813,6 +1813,84 @@ private:
unsigned maxNumElementsToExtract = 0;
};
+/// Fold `mulf(tr(broadcast(A)), broadcast(B))` into `vector.outerproduct(A,
+/// B)`.
+/// Example:
+/// %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32>
+/// %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x4xi32> to
+/// vector<4x4xi32> %rhsBcast = vector.broadcast %rhs : vector<4xi32> to
+/// vector<4x4xi32> %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32>
+///
+/// Becomes :
+///
+/// %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32>
+///
+/// Supports only 1D-to-2D broadcasts. The following cases are not supported.
+/// %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32>
+/// %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32>
+/// %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32>
+template <typename MulOpType>
+struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
+ using OpRewritePattern<MulOpType>::OpRewritePattern;
+ // Returns whether a vector.broadcast matches requirements for an outerproduct
+ // pattern. aka a 1D-to-2D broadcastOp without broadcasted unit dimension.
+ bool isValidBroadcastSource(vector::BroadcastOp broadcastOp) const {
+ // Fail if it is not a 1-to-2 dimension to broadcast to avoid generating
+ // shape_casts/broadcasts which does not belong in this pattern.
+ if (!broadcastOp.computeBroadcastedUnitDims().empty())
+ return false;
+ // Avoid broadcast like f32 or vector<f32> -> ResType
+ auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
+ return srcType && srcType.getRank() != 2;
+ }
+
+ LogicalResult matchAndRewrite(MulOpType mulOp,
+ PatternRewriter &rewriter) const override {
+ auto resType = llvm::cast<VectorType>(mulOp.getResult().getType());
+ if (!resType)
+ return failure();
+ if (resType.getRank() != 2)
+ return failure();
+ /// If operandA can be written as tr(broadcast(A)) and operandB as
+ /// broadcast(B) where broadcasts are 1D-to-2D, create and return
+ /// vector.outerproduct(A, B). Returns failure() otherwise.
+ auto matchOuterProduct =
+ [&](Value operandA,
+ Value operandB) -> FailureOr<vector::OuterProductOp> {
+ auto transposedLhs = operandA.getDefiningOp<vector::TransposeOp>();
+ if (!transposedLhs)
+ return failure();
+ // Fail unless this is a true 2-D matrix transpose.
+ ArrayRef<int64_t> permutation = transposedLhs.getPermutation();
+ if (permutation.size() != 2 || permutation[0] != 1 || permutation[1] != 0)
+ return failure();
+
+ auto broadcastedLhs =
+ transposedLhs.getVector().getDefiningOp<vector::BroadcastOp>();
+ if (!broadcastedLhs || !isValidBroadcastSource(broadcastedLhs))
+ return failure();
+
+ auto broadcastedRhs = operandB.getDefiningOp<vector::BroadcastOp>();
+ if (!broadcastedRhs || !isValidBroadcastSource(broadcastedRhs))
+ return failure();
+
+ return rewriter.create<vector::OuterProductOp>(
+ mulOp->getLoc(), resType, broadcastedLhs.getSource(),
+ broadcastedRhs.getSource(), Value(), vector::CombiningKind::ADD);
+ };
+
+ Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1);
+ auto maybeOuterP = matchOuterProduct(lhs, rhs);
+ // Handle commutativity, the transposed op is the outerproduct LHS.
+ if (failed(maybeOuterP))
+ maybeOuterP = matchOuterProduct(rhs, lhs);
+ if (failed(maybeOuterP))
+ return failure();
+ rewriter.replaceOp(mulOp, maybeOuterP->getResult());
+ return success();
+ }
+};
+
} // namespace
void mlir::vector::populateFoldArithExtensionPatterns(
@@ -1900,6 +1978,13 @@ void mlir::vector::populateBreakDownVectorReductionPatterns(
maxNumElementsToExtract, benefit);
}
+void mlir::vector::populateElementwiseToVectorOpsPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<FoldArithToVectorOuterProduct<arith::MulFOp>,
+ FoldArithToVectorOuterProduct<arith::MulIOp>>(
+ patterns.getContext());
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd enum attribute definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir
index 75b29e2..4b38db7 100644
--- a/mlir/test/Dialect/Vector/transform-vector.mlir
+++ b/mlir/test/Dialect/Vector/transform-vector.mlir
@@ -92,3 +92,41 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// CHECK-LABEL: func.func @arith_to_outerproduct_scalable_i32
+// CHECK-SAME: %[[LHS:.*]]: vector<[4]xi32>,
+// CHECK-SAME: %[[RHS:.*]]: vector<[4]xi32>) -> vector<[4]x[4]xi32> {
+// CHECK: %[[RES:.*]] = vector.outerproduct %[[LHS]], %[[RHS]] : vector<[4]xi32>, vector<[4]xi32>
+// CHECK: return %[[RES]] : vector<[4]x[4]xi32>
+func.func @arith_to_outerproduct_scalable_i32(%lhs: vector<[4]xi32>, %rhs: vector<[4]xi32>) -> vector<[4]x[4]xi32> {
+ %lhsBcast = vector.broadcast %lhs : vector<[4]xi32> to vector<[4]x[4]xi32>
+ %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
+ %rhsBcast = vector.broadcast %rhs : vector<[4]xi32> to vector<[4]x[4]xi32>
+ %mul = arith.muli %lhsT, %rhsBcast : vector<[4]x[4]xi32>
+ return %mul: vector<[4]x[4]xi32>
+}
+
+// CHECK-LABEL: func.func @arith_to_outerproduct_trans_rhs_f32
+// CHECK-SAME: %[[LHS:.*]]: vector<16xf32>,
+// CHECK-SAME: %[[RHS:.*]]: vector<8xf32>) -> vector<8x16xf32> {
+// CHECK: %[[RES:.*]] = vector.outerproduct %[[RHS]], %[[LHS]] : vector<8xf32>, vector<16xf32>
+// CHECK: return %[[RES]] : vector<8x16xf32>
+func.func @arith_to_outerproduct_trans_rhs_f32(%lhs: vector<16xf32>, %rhs: vector<8xf32>) -> vector<8x16xf32> {
+ %rhsBcast = vector.broadcast %rhs : vector<8xf32> to vector<16x8xf32>
+ %rhsT = vector.transpose %rhsBcast, [1, 0] : vector<16x8xf32> to vector<8x16xf32>
+ %lhsBcast = vector.broadcast %lhs : vector<16xf32> to vector<8x16xf32>
+ %mul = arith.mulf %lhsBcast, %rhsT : vector<8x16xf32>
+ return %mul: vector<8x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.elementwise_to_vector
+ } : !transform.any_op
+ transform.yield
+ }
+}