aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib
diff options
context:
space:
mode:
authorAndrzej WarzyƄski <andrzej.warzynski@arm.com>2024-03-20 21:04:06 +0000
committerGitHub <noreply@github.com>2024-03-20 21:04:06 +0000
commitb7324b6a9c6bd43786ea853bf1a9730486b4bc88 (patch)
treea0d26e47d014d87cd9c788062659da49c93ffa2a /mlir/lib
parentf6f474c4ef9694a4ca8f08d59fd112c250fb9c73 (diff)
downloadllvm-b7324b6a9c6bd43786ea853bf1a9730486b4bc88.zip
llvm-b7324b6a9c6bd43786ea853bf1a9730486b4bc88.tar.gz
llvm-b7324b6a9c6bd43786ea853bf1a9730486b4bc88.tar.bz2
[mlir][vector] Adds pattern rewrite for maskable Ops (#83827)
Adds a generic pattern rewrite for maskable Ops, `MaskableOpRewritePattern`, that will work for both masked and un-masked cases, e.g. for both: * `vector.mask {vector.contract}` (masked), and * `vector.contract` (not masked). This helps to reduce code-duplication and standardise how we implement such patterns. Fixes #78787
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp179
1 files changed, 85 insertions, 94 deletions
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 0eaf9f7..ba1c968 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -41,7 +41,6 @@ using namespace mlir::vector;
//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//
-
// Helper to find an index in an affine map.
static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
@@ -226,9 +225,9 @@ namespace {
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToMatmulOpLowering
- : public OpRewritePattern<vector::ContractionOp> {
+ : public vector::MaskableOpRewritePattern<vector::ContractionOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
@@ -241,12 +240,13 @@ public:
vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, PatternBenefit benefit = 1,
FilterConstraintType constraint = defaultFilter)
- : OpRewritePattern<vector::ContractionOp>(context, benefit),
+ : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions),
filter(std::move(constraint)) {}
- LogicalResult matchAndRewrite(vector::ContractionOp op,
- PatternRewriter &rewriter) const override;
+ FailureOr<Value>
+ matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const override;
private:
/// Options to control the vector patterns.
@@ -270,9 +270,9 @@ private:
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToOuterProductOpLowering
- : public OpRewritePattern<vector::ContractionOp> {
+ : public MaskableOpRewritePattern<vector::ContractionOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
@@ -285,12 +285,13 @@ public:
vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, PatternBenefit benefit = 1,
FilterConstraintType constraint = defaultFilter)
- : OpRewritePattern<vector::ContractionOp>(context, benefit),
+ : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions),
filter(std::move(constraint)) {}
- LogicalResult matchAndRewrite(vector::ContractionOp op,
- PatternRewriter &rewriter) const override;
+ FailureOr<Value>
+ matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const override;
private:
/// Options to control the vector patterns.
@@ -317,9 +318,9 @@ private:
/// This only kicks in when VectorTransformsOptions is set to Dot and
/// the vector.contract op is a row-major matmul or matvec.
class ContractionOpToDotLowering
- : public OpRewritePattern<vector::ContractionOp> {
+ : public MaskableOpRewritePattern<vector::ContractionOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
@@ -332,11 +333,12 @@ public:
vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, PatternBenefit benefit = 1,
const FilterConstraintType &constraint = defaultFilter)
- : OpRewritePattern<vector::ContractionOp>(context, benefit),
+ : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
- LogicalResult matchAndRewrite(vector::ContractionOp op,
- PatternRewriter &rewriter) const override;
+ FailureOr<Value>
+ matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const override;
private:
/// Options to control the vector patterns.
@@ -358,9 +360,10 @@ private:
///
/// This only kicks in when either VectorTransformsOptions is set
/// to Dot or when other contraction patterns fail.
-class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
+class ContractionOpLowering
+ : public MaskableOpRewritePattern<vector::ContractionOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
@@ -371,12 +374,13 @@ public:
ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, PatternBenefit benefit = 1,
FilterConstraintType constraint = defaultFilter)
- : OpRewritePattern<vector::ContractionOp>(context, benefit),
+ : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions),
filter(std::move(constraint)) {}
- LogicalResult matchAndRewrite(vector::ContractionOp op,
- PatternRewriter &rewriter) const override;
+ FailureOr<Value>
+ matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const override;
private:
/// Options to control the vector patterns.
@@ -634,8 +638,10 @@ private:
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
/// otherwise supports any layout permutation of the matrix-multiply.
-LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
- vector::ContractionOp op, PatternRewriter &rewriter) const {
+FailureOr<Value>
+ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
+ vector::ContractionOp op, MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const {
if (vectorTransformOptions.vectorContractLowering !=
vector::VectorContractLowering::OuterProduct)
return failure();
@@ -643,43 +649,25 @@ LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
if (failed(filter(op)))
return failure();
- // Vector mask setup.
- OpBuilder::InsertionGuard guard(rewriter);
- auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
- Operation *rootOp;
- if (maskableOp.isMasked()) {
- rewriter.setInsertionPoint(maskableOp.getMaskingOp());
- rootOp = maskableOp.getMaskingOp();
- } else {
- rootOp = op;
- }
-
UnrolledOuterProductGenerator e(rewriter, op);
FailureOr<Value> matmatRes = e.matmat();
if (succeeded(matmatRes)) {
- rewriter.replaceOp(rootOp, *matmatRes);
- return success();
+ return matmatRes;
}
FailureOr<Value> matvecRes = e.matvec();
if (succeeded(matvecRes)) {
- rewriter.replaceOp(rootOp, *matvecRes);
- return success();
- }
- FailureOr<Value> tmatvecRes = e.tmatvec();
- if (succeeded(tmatvecRes)) {
- rewriter.replaceOp(rootOp, *tmatvecRes);
- return success();
+ return matvecRes;
}
- return failure();
+ FailureOr<Value> tmatvecRes = e.tmatvec();
+ return tmatvecRes;
}
-LogicalResult
-ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
- PatternRewriter &rewriter) const {
+FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
+ vector::ContractionOp op, MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const {
// TODO: Support vector.mask.
- auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
- if (maskableOp.isMasked())
+ if (maskOp)
return failure();
if (failed(filter(op)))
@@ -788,15 +776,14 @@ ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
}
if (auto acc = op.getAcc())
res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
- rewriter.replaceOp(op, res);
- return success();
+ return res;
}
/// Lower vector.contract with all size one reduction dimensions to
/// elementwise ops when possible.
struct ContractOpToElementwise
- : public OpRewritePattern<vector::ContractionOp> {
- using OpRewritePattern::OpRewritePattern;
+ : public MaskableOpRewritePattern<vector::ContractionOp> {
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
static LogicalResult defaultFilter(vector::ContractionOp op) {
@@ -806,14 +793,15 @@ struct ContractOpToElementwise
vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, PatternBenefit benefit = 1,
const FilterConstraintType &constraint = defaultFilter)
- : OpRewritePattern<vector::ContractionOp>(context, benefit),
+ : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
- LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
- PatternRewriter &rewriter) const override {
+ FailureOr<Value>
+ matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
+ MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const override {
// TODO: Support vector.mask.
- auto maskableOp = cast<MaskableOpInterface>(contractOp.getOperation());
- if (maskableOp.isMasked())
+ if (maskOp)
return failure();
if (failed(filter(contractOp)))
@@ -903,8 +891,10 @@ struct ContractOpToElementwise
std::optional<Value> result =
createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
contractOp.getKind(), rewriter, isInt);
- rewriter.replaceOp(contractOp, {*result});
- return success();
+ if (result)
+ return *result;
+
+ return failure();
}
private:
@@ -930,9 +920,9 @@ private:
// TODO: break down into transpose/reshape/cast ops
// when they become available to avoid code dup
// TODO: investigate lowering order impact on performance
-LogicalResult
-ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
- PatternRewriter &rewriter) const {
+FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
+ vector::ContractionOp op, MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const {
if (failed(filter(op)))
return failure();
@@ -951,29 +941,36 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
// TODO: implement benefits, cost models.
MLIRContext *ctx = op.getContext();
+
ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
- if (succeeded(pat1.matchAndRewrite(op, rewriter)))
- return success();
+ FailureOr<Value> newVal1 =
+ pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter);
+ if (!failed(newVal1))
+ return newVal1;
+
ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
- if (succeeded(pat2.matchAndRewrite(op, rewriter)))
- return success();
+ FailureOr<Value> newVal2 =
+ pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter);
+ if (!failed(newVal2))
+ return newVal2;
+
ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
- if (succeeded(pat3.matchAndRewrite(op, rewriter)))
- return success();
+ FailureOr<Value> newVal3 =
+ pat3.matchAndRewriteMaskableOp(op, maskOp, rewriter);
+ if (!failed(newVal3))
+ return newVal3;
+
ContractOpToElementwise pat4(vectorTransformOptions, ctx);
- if (succeeded(pat4.matchAndRewrite(op, rewriter)))
- return success();
+ FailureOr<Value> newVal4 =
+ pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter);
+ if (!failed(newVal4))
+ return newVal4;
// Vector mask setup.
- OpBuilder::InsertionGuard guard(rewriter);
- Operation *rootOp = op;
- Value mask;
- if (op.isMasked()) {
- rewriter.setInsertionPoint(op.getMaskingOp());
- rootOp = op.getMaskingOp();
- mask = op.getMaskingOp().getMask();
- }
+ Value mask;
+ if (maskOp)
+ mask = maskOp.getMask();
// Find first batch dimension in LHS/RHS, and lower when found.
std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
if (!batchDimMap.empty()) {
@@ -982,8 +979,7 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
if (failed(newOp))
return failure();
- rewriter.replaceOp(rootOp, *newOp);
- return success();
+ return newOp;
}
// Collect contracting dimensions.
@@ -1003,8 +999,7 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask);
if (failed(newOp))
return failure();
- rewriter.replaceOp(rootOp, *newOp);
- return success();
+ return newOp;
}
}
@@ -1015,8 +1010,7 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask);
if (failed(newOp))
return failure();
- rewriter.replaceOp(rootOp, *newOp);
- return success();
+ return newOp;
}
}
@@ -1025,8 +1019,7 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
auto newOp = lowerReduction(rewriter, op, mask);
if (failed(newOp))
return failure();
- rewriter.replaceOp(rootOp, *newOp);
- return success();
+ return newOp;
}
return failure();
@@ -1291,12 +1284,11 @@ public:
/// This only kicks in when VectorTransformsOptions is set to `Matmul`.
/// vector.transpose operations are inserted if the vector.contract op is not a
/// row-major matrix multiply.
-LogicalResult
-ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
- PatternRewriter &rew) const {
+FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
+ vector::ContractionOp op, MaskingOpInterface maskOp,
+ PatternRewriter &rew) const {
// TODO: Support vector.mask.
- auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
- if (maskableOp.isMasked())
+ if (maskOp)
return failure();
if (vectorTransformOptions.vectorContractLowering !=
@@ -1379,8 +1371,7 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
: static_cast<Value>(
rew.create<arith::AddFOp>(loc, op.getAcc(), mul));
- rew.replaceOp(op, res);
- return success();
+ return res;
}
} // namespace