diff options
author | Andrzej WarzyĆski <andrzej.warzynski@arm.com> | 2024-03-20 21:04:06 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-20 21:04:06 +0000 |
commit | b7324b6a9c6bd43786ea853bf1a9730486b4bc88 (patch) | |
tree | a0d26e47d014d87cd9c788062659da49c93ffa2a /mlir/lib | |
parent | f6f474c4ef9694a4ca8f08d59fd112c250fb9c73 (diff) | |
download | llvm-b7324b6a9c6bd43786ea853bf1a9730486b4bc88.zip llvm-b7324b6a9c6bd43786ea853bf1a9730486b4bc88.tar.gz llvm-b7324b6a9c6bd43786ea853bf1a9730486b4bc88.tar.bz2 |
[mlir][vector] Adds pattern rewrite for maskable Ops (#83827)
Adds a generic pattern rewrite for maskable Ops, `MaskableOpRewritePattern`,
that will work for both masked and un-masked cases, e.g. for both:
* `vector.mask {vector.contract}` (masked), and
* `vector.contract` (not masked).
This helps to reduce code-duplication and standardise how we implement such
patterns.
Fixes #78787
Diffstat (limited to 'mlir/lib')
-rw-r--r-- | mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp | 179 |
1 files changed, 85 insertions, 94 deletions
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp index 0eaf9f7..ba1c968 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp @@ -41,7 +41,6 @@ using namespace mlir::vector; //===----------------------------------------------------------------------===// // Helper functions //===----------------------------------------------------------------------===// - // Helper to find an index in an affine map. static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) { for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { @@ -226,9 +225,9 @@ namespace { /// This only kicks in when VectorTransformsOptions is set to OuterProduct and /// the vector.contract op is a row-major matrix multiply. class ContractionOpToMatmulOpLowering - : public OpRewritePattern<vector::ContractionOp> { + : public vector::MaskableOpRewritePattern<vector::ContractionOp> { public: - using OpRewritePattern::OpRewritePattern; + using MaskableOpRewritePattern::MaskableOpRewritePattern; using FilterConstraintType = std::function<LogicalResult(vector::ContractionOp op)>; @@ -241,12 +240,13 @@ public: vector::VectorTransformsOptions vectorTransformOptions, MLIRContext *context, PatternBenefit benefit = 1, FilterConstraintType constraint = defaultFilter) - : OpRewritePattern<vector::ContractionOp>(context, benefit), + : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit), vectorTransformOptions(vectorTransformOptions), filter(std::move(constraint)) {} - LogicalResult matchAndRewrite(vector::ContractionOp op, - PatternRewriter &rewriter) const override; + FailureOr<Value> + matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp, + PatternRewriter &rewriter) const override; private: /// Options to control the vector patterns. @@ -270,9 +270,9 @@ private: /// This only kicks in when VectorTransformsOptions is set to OuterProduct and /// the vector.contract op is a row-major matrix multiply. class ContractionOpToOuterProductOpLowering - : public OpRewritePattern<vector::ContractionOp> { + : public MaskableOpRewritePattern<vector::ContractionOp> { public: - using OpRewritePattern::OpRewritePattern; + using MaskableOpRewritePattern::MaskableOpRewritePattern; using FilterConstraintType = std::function<LogicalResult(vector::ContractionOp op)>; @@ -285,12 +285,13 @@ public: vector::VectorTransformsOptions vectorTransformOptions, MLIRContext *context, PatternBenefit benefit = 1, FilterConstraintType constraint = defaultFilter) - : OpRewritePattern<vector::ContractionOp>(context, benefit), + : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit), vectorTransformOptions(vectorTransformOptions), filter(std::move(constraint)) {} - LogicalResult matchAndRewrite(vector::ContractionOp op, - PatternRewriter &rewriter) const override; + FailureOr<Value> + matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp, + PatternRewriter &rewriter) const override; private: /// Options to control the vector patterns. @@ -317,9 +318,9 @@ private: /// This only kicks in when VectorTransformsOptions is set to Dot and /// the vector.contract op is a row-major matmul or matvec. class ContractionOpToDotLowering - : public OpRewritePattern<vector::ContractionOp> { + : public MaskableOpRewritePattern<vector::ContractionOp> { public: - using OpRewritePattern::OpRewritePattern; + using MaskableOpRewritePattern::MaskableOpRewritePattern; using FilterConstraintType = std::function<LogicalResult(vector::ContractionOp op)>; @@ -332,11 +333,12 @@ public: vector::VectorTransformsOptions vectorTransformOptions, MLIRContext *context, PatternBenefit benefit = 1, const FilterConstraintType &constraint = defaultFilter) - : OpRewritePattern<vector::ContractionOp>(context, benefit), + : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit), vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {} - LogicalResult matchAndRewrite(vector::ContractionOp op, - PatternRewriter &rewriter) const override; + FailureOr<Value> + matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp, + PatternRewriter &rewriter) const override; private: /// Options to control the vector patterns. @@ -358,9 +360,10 @@ private: /// /// This only kicks in when either VectorTransformsOptions is set /// to Dot or when other contraction patterns fail. -class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> { +class ContractionOpLowering + : public MaskableOpRewritePattern<vector::ContractionOp> { public: - using OpRewritePattern::OpRewritePattern; + using MaskableOpRewritePattern::MaskableOpRewritePattern; using FilterConstraintType = std::function<LogicalResult(vector::ContractionOp op)>; @@ -371,12 +374,13 @@ public: ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions, MLIRContext *context, PatternBenefit benefit = 1, FilterConstraintType constraint = defaultFilter) - : OpRewritePattern<vector::ContractionOp>(context, benefit), + : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit), vectorTransformOptions(vectorTransformOptions), filter(std::move(constraint)) {} - LogicalResult matchAndRewrite(vector::ContractionOp op, - PatternRewriter &rewriter) const override; + FailureOr<Value> + matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp, + PatternRewriter &rewriter) const override; private: /// Options to control the vector patterns. @@ -634,8 +638,10 @@ private: /// /// This only kicks in when VectorTransformsOptions is set to OuterProduct but /// otherwise supports any layout permutation of the matrix-multiply. -LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite( - vector::ContractionOp op, PatternRewriter &rewriter) const { +FailureOr<Value> +ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp( + vector::ContractionOp op, MaskingOpInterface maskOp, + PatternRewriter &rewriter) const { if (vectorTransformOptions.vectorContractLowering != vector::VectorContractLowering::OuterProduct) return failure(); @@ -643,43 +649,25 @@ LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite( if (failed(filter(op))) return failure(); - // Vector mask setup. - OpBuilder::InsertionGuard guard(rewriter); - auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation()); - Operation *rootOp; - if (maskableOp.isMasked()) { - rewriter.setInsertionPoint(maskableOp.getMaskingOp()); - rootOp = maskableOp.getMaskingOp(); - } else { - rootOp = op; - } - UnrolledOuterProductGenerator e(rewriter, op); FailureOr<Value> matmatRes = e.matmat(); if (succeeded(matmatRes)) { - rewriter.replaceOp(rootOp, *matmatRes); - return success(); + return matmatRes; } FailureOr<Value> matvecRes = e.matvec(); if (succeeded(matvecRes)) { - rewriter.replaceOp(rootOp, *matvecRes); - return success(); - } - FailureOr<Value> tmatvecRes = e.tmatvec(); - if (succeeded(tmatvecRes)) { - rewriter.replaceOp(rootOp, *tmatvecRes); - return success(); + return matvecRes; } - return failure(); + FailureOr<Value> tmatvecRes = e.tmatvec(); + return tmatvecRes; } -LogicalResult -ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op, - PatternRewriter &rewriter) const { +FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp( + vector::ContractionOp op, MaskingOpInterface maskOp, + PatternRewriter &rewriter) const { // TODO: Support vector.mask. - auto maskableOp = cast<MaskableOpInterface>(op.getOperation()); - if (maskableOp.isMasked()) + if (maskOp) return failure(); if (failed(filter(op))) @@ -788,15 +776,14 @@ ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op, } if (auto acc = op.getAcc()) res = createAdd(op.getLoc(), res, acc, isInt, rewriter); - rewriter.replaceOp(op, res); - return success(); + return res; } /// Lower vector.contract with all size one reduction dimensions to /// elementwise ops when possible. struct ContractOpToElementwise - : public OpRewritePattern<vector::ContractionOp> { - using OpRewritePattern::OpRewritePattern; + : public MaskableOpRewritePattern<vector::ContractionOp> { + using MaskableOpRewritePattern::MaskableOpRewritePattern; using FilterConstraintType = std::function<LogicalResult(vector::ContractionOp op)>; static LogicalResult defaultFilter(vector::ContractionOp op) { @@ -806,14 +793,15 @@ struct ContractOpToElementwise vector::VectorTransformsOptions vectorTransformOptions, MLIRContext *context, PatternBenefit benefit = 1, const FilterConstraintType &constraint = defaultFilter) - : OpRewritePattern<vector::ContractionOp>(context, benefit), + : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit), vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {} - LogicalResult matchAndRewrite(vector::ContractionOp contractOp, - PatternRewriter &rewriter) const override { + FailureOr<Value> + matchAndRewriteMaskableOp(vector::ContractionOp contractOp, + MaskingOpInterface maskOp, + PatternRewriter &rewriter) const override { // TODO: Support vector.mask. - auto maskableOp = cast<MaskableOpInterface>(contractOp.getOperation()); - if (maskableOp.isMasked()) + if (maskOp) return failure(); if (failed(filter(contractOp))) @@ -903,8 +891,10 @@ struct ContractOpToElementwise std::optional<Value> result = createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(), contractOp.getKind(), rewriter, isInt); - rewriter.replaceOp(contractOp, {*result}); - return success(); + if (result) + return *result; + + return failure(); } private: @@ -930,9 +920,9 @@ private: // TODO: break down into transpose/reshape/cast ops // when they become available to avoid code dup // TODO: investigate lowering order impact on performance -LogicalResult -ContractionOpLowering::matchAndRewrite(vector::ContractionOp op, - PatternRewriter &rewriter) const { +FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp( + vector::ContractionOp op, MaskingOpInterface maskOp, + PatternRewriter &rewriter) const { if (failed(filter(op))) return failure(); @@ -951,29 +941,36 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op, // TODO: implement benefits, cost models. MLIRContext *ctx = op.getContext(); + ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx); - if (succeeded(pat1.matchAndRewrite(op, rewriter))) - return success(); + FailureOr<Value> newVal1 = + pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter); + if (!failed(newVal1)) + return newVal1; + ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx); - if (succeeded(pat2.matchAndRewrite(op, rewriter))) - return success(); + FailureOr<Value> newVal2 = + pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter); + if (!failed(newVal2)) + return newVal2; + ContractionOpToDotLowering pat3(vectorTransformOptions, ctx); - if (succeeded(pat3.matchAndRewrite(op, rewriter))) - return success(); + FailureOr<Value> newVal3 = + pat3.matchAndRewriteMaskableOp(op, maskOp, rewriter); + if (!failed(newVal3)) + return newVal3; + ContractOpToElementwise pat4(vectorTransformOptions, ctx); - if (succeeded(pat4.matchAndRewrite(op, rewriter))) - return success(); + FailureOr<Value> newVal4 = + pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter); + if (!failed(newVal4)) + return newVal4; // Vector mask setup. - OpBuilder::InsertionGuard guard(rewriter); - Operation *rootOp = op; - Value mask; - if (op.isMasked()) { - rewriter.setInsertionPoint(op.getMaskingOp()); - rootOp = op.getMaskingOp(); - mask = op.getMaskingOp().getMask(); - } + Value mask; + if (maskOp) + mask = maskOp.getMask(); // Find first batch dimension in LHS/RHS, and lower when found. std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap(); if (!batchDimMap.empty()) { @@ -982,8 +979,7 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op, auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask); if (failed(newOp)) return failure(); - rewriter.replaceOp(rootOp, *newOp); - return success(); + return newOp; } // Collect contracting dimensions. @@ -1003,8 +999,7 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op, auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask); if (failed(newOp)) return failure(); - rewriter.replaceOp(rootOp, *newOp); - return success(); + return newOp; } } @@ -1015,8 +1010,7 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op, auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask); if (failed(newOp)) return failure(); - rewriter.replaceOp(rootOp, *newOp); - return success(); + return newOp; } } @@ -1025,8 +1019,7 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op, auto newOp = lowerReduction(rewriter, op, mask); if (failed(newOp)) return failure(); - rewriter.replaceOp(rootOp, *newOp); - return success(); + return newOp; } return failure(); @@ -1291,12 +1284,11 @@ public: /// This only kicks in when VectorTransformsOptions is set to `Matmul`. /// vector.transpose operations are inserted if the vector.contract op is not a /// row-major matrix multiply. -LogicalResult -ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op, - PatternRewriter &rew) const { +FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp( + vector::ContractionOp op, MaskingOpInterface maskOp, + PatternRewriter &rew) const { // TODO: Support vector.mask. - auto maskableOp = cast<MaskableOpInterface>(op.getOperation()); - if (maskableOp.isMasked()) + if (maskOp) return failure(); if (vectorTransformOptions.vectorContractLowering != @@ -1379,8 +1371,7 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op, : static_cast<Value>( rew.create<arith::AddFOp>(loc, op.getAcc(), mul)); - rew.replaceOp(op, res); - return success(); + return res; } } // namespace |