diff options
author | Hugo Trachino <hugo.trachino@huawei.com> | 2024-06-21 13:34:37 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-21 13:34:37 +0100 |
commit | 9f0aa05bfb40c077a5b1c2ea8cac88fdd51f0c5c (patch) | |
tree | 8f03454721d4410b2d68da78711bf7721f5c66c8 /mlir | |
parent | 138ea7d1fb82c2525da8dcc2f8ea73eae7b25f25 (diff) | |
download | llvm-9f0aa05bfb40c077a5b1c2ea8cac88fdd51f0c5c.zip llvm-9f0aa05bfb40c077a5b1c2ea8cac88fdd51f0c5c.tar.gz llvm-9f0aa05bfb40c077a5b1c2ea8cac88fdd51f0c5c.tar.bz2 |
[mlir][vector] Add ElementwiseToOuterproduct (#93664)
1D multi-reduction are lowered to arith which can prevent some
optimisations. I propose `ElementwiseToOuterproduct` matching a series of
ops to generate `vector.outerproduct`.
As part of some `ElementwiseToVectorOpsPatterns`, it could allow to fuse
other elementwiseOps to vector dialect.
Originally discussed
https://discourse.llvm.org/t/on-improving-arm-sme-lowering-resilience-in-mlir/78543/24.
quote @MacDue
```
%lhsBcast = vector.broadcast %lhsCast : vector<[4]xf32> to vector<[4]x[4]xf32>
%lhsT = vector.transpose %lhsBcast, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>
%rhsBcast = vector.broadcast %rhs : vector<[4]xf32> to vector<[4]x[4]xf32>
%mul = arith.mulf %lhsT, %rhsBcast : vector<[4]x[4]xf32>
```
Can be rewritten as:
```
%mul = vector.outerproduct $lhs, $rhs : vector<[4]xf32>, vector<[4]xf32>
```
---------
Co-authored-by: Han-Chung Wang <hanhan0912@gmail.com>
Diffstat (limited to 'mlir')
5 files changed, 143 insertions, 0 deletions
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h index 4603953..ac55433 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h @@ -80,6 +80,10 @@ void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, /// into vector contract for the backends with native support. void populateFoldArithExtensionPatterns(RewritePatternSet &patterns); +/// Collect a set of patterns that fold elementwise op on vectors to the vector +/// dialect. +void populateElementwiseToVectorOpsPatterns(RewritePatternSet &patterns); + /// Returns the integer type required for subscripts in the vector dialect. IntegerType getVectorSubscriptType(Builder &builder); diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td index c91e8fb..820a187 100644 --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td @@ -406,6 +406,17 @@ def ApplyFoldArithExtensionPatternsOp : Op<Transform_Dialect, let assemblyFormat = "attr-dict"; } +def ApplyFoldElementwiseToVectorPatternsOp : Op<Transform_Dialect, + "apply_patterns.vector.elementwise_to_vector", + [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> { + let description = [{ + Collect a set of patterns that fold elementwise op on vectors to the vector + dialect. + }]; + + let assemblyFormat = "attr-dict"; +} + def ApplyVectorReductionToContractPatternsOp : Op<Transform_Dialect, "apply_patterns.vector.reduction_to_contract", [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> { diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp index 2396026..2e9aa88 100644 --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -59,6 +59,11 @@ void transform::ApplyFoldArithExtensionPatternsOp::populatePatterns( vector::populateFoldArithExtensionPatterns(patterns); } +void transform::ApplyFoldElementwiseToVectorPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + vector::populateElementwiseToVectorOpsPatterns(patterns); +} + void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns( RewritePatternSet &patterns) { vector::populateVectorReductionToContractPatterns(patterns); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index b824508..eac6db5 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1813,6 +1813,84 @@ private: unsigned maxNumElementsToExtract = 0; }; +/// Fold `mulf(tr(broadcast(A)), broadcast(B))` into `vector.outerproduct(A, +/// B)`. +/// Example: +/// %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32> +/// %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x4xi32> to +/// vector<4x4xi32> %rhsBcast = vector.broadcast %rhs : vector<4xi32> to +/// vector<4x4xi32> %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32> +/// +/// Becomes : +/// +/// %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32> +/// +/// Supports only 1D-to-2D broadcasts. The following cases are not supported. +/// %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32> +/// %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32> +/// %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32> +template <typename MulOpType> +struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> { + using OpRewritePattern<MulOpType>::OpRewritePattern; + // Returns whether a vector.broadcast matches requirements for an outerproduct + // pattern. aka a 1D-to-2D broadcastOp without broadcasted unit dimension. + bool isValidBroadcastSource(vector::BroadcastOp broadcastOp) const { + // Fail if it is not a 1-to-2 dimension to broadcast to avoid generating + // shape_casts/broadcasts which does not belong in this pattern. + if (!broadcastOp.computeBroadcastedUnitDims().empty()) + return false; + // Avoid broadcast like f32 or vector<f32> -> ResType + auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()); + return srcType && srcType.getRank() != 2; + } + + LogicalResult matchAndRewrite(MulOpType mulOp, + PatternRewriter &rewriter) const override { + auto resType = llvm::cast<VectorType>(mulOp.getResult().getType()); + if (!resType) + return failure(); + if (resType.getRank() != 2) + return failure(); + /// If operandA can be written as tr(broadcast(A)) and operandB as + /// broadcast(B) where broadcasts are 1D-to-2D, create and return + /// vector.outerproduct(A, B). Returns failure() otherwise. + auto matchOuterProduct = + [&](Value operandA, + Value operandB) -> FailureOr<vector::OuterProductOp> { + auto transposedLhs = operandA.getDefiningOp<vector::TransposeOp>(); + if (!transposedLhs) + return failure(); + // Fail unless this is a true 2-D matrix transpose. + ArrayRef<int64_t> permutation = transposedLhs.getPermutation(); + if (permutation.size() != 2 || permutation[0] != 1 || permutation[1] != 0) + return failure(); + + auto broadcastedLhs = + transposedLhs.getVector().getDefiningOp<vector::BroadcastOp>(); + if (!broadcastedLhs || !isValidBroadcastSource(broadcastedLhs)) + return failure(); + + auto broadcastedRhs = operandB.getDefiningOp<vector::BroadcastOp>(); + if (!broadcastedRhs || !isValidBroadcastSource(broadcastedRhs)) + return failure(); + + return rewriter.create<vector::OuterProductOp>( + mulOp->getLoc(), resType, broadcastedLhs.getSource(), + broadcastedRhs.getSource(), Value(), vector::CombiningKind::ADD); + }; + + Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1); + auto maybeOuterP = matchOuterProduct(lhs, rhs); + // Handle commutativity, the transposed op is the outerproduct LHS. + if (failed(maybeOuterP)) + maybeOuterP = matchOuterProduct(rhs, lhs); + if (failed(maybeOuterP)) + return failure(); + rewriter.replaceOp(mulOp, maybeOuterP->getResult()); + return success(); + } +}; + } // namespace void mlir::vector::populateFoldArithExtensionPatterns( @@ -1900,6 +1978,13 @@ void mlir::vector::populateBreakDownVectorReductionPatterns( maxNumElementsToExtract, benefit); } +void mlir::vector::populateElementwiseToVectorOpsPatterns( + RewritePatternSet &patterns) { + patterns.add<FoldArithToVectorOuterProduct<arith::MulFOp>, + FoldArithToVectorOuterProduct<arith::MulIOp>>( + patterns.getContext()); +} + //===----------------------------------------------------------------------===// // TableGen'd enum attribute definitions //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir index 75b29e2..4b38db7 100644 --- a/mlir/test/Dialect/Vector/transform-vector.mlir +++ b/mlir/test/Dialect/Vector/transform-vector.mlir @@ -92,3 +92,41 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +// CHECK-LABEL: func.func @arith_to_outerproduct_scalable_i32 +// CHECK-SAME: %[[LHS:.*]]: vector<[4]xi32>, +// CHECK-SAME: %[[RHS:.*]]: vector<[4]xi32>) -> vector<[4]x[4]xi32> { +// CHECK: %[[RES:.*]] = vector.outerproduct %[[LHS]], %[[RHS]] : vector<[4]xi32>, vector<[4]xi32> +// CHECK: return %[[RES]] : vector<[4]x[4]xi32> +func.func @arith_to_outerproduct_scalable_i32(%lhs: vector<[4]xi32>, %rhs: vector<[4]xi32>) -> vector<[4]x[4]xi32> { + %lhsBcast = vector.broadcast %lhs : vector<[4]xi32> to vector<[4]x[4]xi32> + %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32> + %rhsBcast = vector.broadcast %rhs : vector<[4]xi32> to vector<[4]x[4]xi32> + %mul = arith.muli %lhsT, %rhsBcast : vector<[4]x[4]xi32> + return %mul: vector<[4]x[4]xi32> +} + +// CHECK-LABEL: func.func @arith_to_outerproduct_trans_rhs_f32 +// CHECK-SAME: %[[LHS:.*]]: vector<16xf32>, +// CHECK-SAME: %[[RHS:.*]]: vector<8xf32>) -> vector<8x16xf32> { +// CHECK: %[[RES:.*]] = vector.outerproduct %[[RHS]], %[[LHS]] : vector<8xf32>, vector<16xf32> +// CHECK: return %[[RES]] : vector<8x16xf32> +func.func @arith_to_outerproduct_trans_rhs_f32(%lhs: vector<16xf32>, %rhs: vector<8xf32>) -> vector<8x16xf32> { + %rhsBcast = vector.broadcast %rhs : vector<8xf32> to vector<16x8xf32> + %rhsT = vector.transpose %rhsBcast, [1, 0] : vector<16x8xf32> to vector<8x16xf32> + %lhsBcast = vector.broadcast %lhs : vector<16xf32> to vector<8x16xf32> + %mul = arith.mulf %lhsBcast, %rhsT : vector<8x16xf32> + return %mul: vector<8x16xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.vector.elementwise_to_vector + } : !transform.any_op + transform.yield + } +} |