aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorMatthias Springer <me@m-sp.org>2023-12-07 12:05:20 +0900
committerGitHub <noreply@github.com>2023-12-07 12:05:20 +0900
commit986287e7f38321165c0c654f3af06e34af7b161f (patch)
tree5f75b37fd1fdb7c71aa7634b7bab9edc26c56b12 /mlir
parentcdd81e3be3df65a966879abef590e36f73e7dea6 (diff)
downloadllvm-986287e7f38321165c0c654f3af06e34af7b161f.zip
llvm-986287e7f38321165c0c654f3af06e34af7b161f.tar.gz
llvm-986287e7f38321165c0c654f3af06e34af7b161f.tar.bz2
[mlir][SparseTensor] Fix invalid API usage in patterns (#74690)
Rewrite patterns must return `success` if the IR was modified. This commit fixes sparse tensor tests such as `SparseTensor/sparse_fusion.mlir`, `SparseTensor/CPU/sparse_reduce_custom.mlir`, `SparseTensor/CPU/sparse_semiring_select.mlir` when running with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS`.
Diffstat (limited to 'mlir')
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp11
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp20
2 files changed, 20 insertions, 11 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index f0393e4..dc5ea28 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -422,11 +422,6 @@ public:
if (!controlFn(&opOperand))
continue;
- // Find the producer of the operand.
- FailureOr<ElementwiseOpFusionResult> fusionResult =
- fuseElementwiseOps(rewriter, &opOperand);
- if (failed(fusionResult))
- return rewriter.notifyMatchFailure(genericOp, "fusion failed");
Operation *producer = opOperand.get().getDefiningOp();
// Do not fuse a sparse-in/dense-out operation, as the
@@ -435,6 +430,12 @@ public:
!sparse_tensor::hasAnySparseResult(producer))
return failure();
+ // Find the producer of the operand.
+ FailureOr<ElementwiseOpFusionResult> fusionResult =
+ fuseElementwiseOps(rewriter, &opOperand);
+ if (failed(fusionResult))
+ return rewriter.notifyMatchFailure(genericOp, "fusion failed");
+
// Perform the fusion.
for (auto [origVal, replacement] : fusionResult->replacements) {
rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index c94ef8b..488079c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -38,16 +38,22 @@ struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
LogicalResult matchAndRewrite(SourceOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
+
// Demaps non-trivial inputs.
+ bool changed = false;
SmallVector<Value> deMappedIns(op->getOperands());
- for (Value &in : deMappedIns)
- if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity())
+ for (Value &in : deMappedIns) {
+ if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity()) {
in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);
+ changed = true;
+ }
+ }
// CRTP call.
OpAdaptor adaptor(deMappedIns, op);
- return static_cast<const SubClass *>(this)->rewriteOp(op, adaptor,
- rewriter);
+ LogicalResult status =
+ static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
+ return changed ? success() : status;
}
};
@@ -452,11 +458,13 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
}
// Marks the GenericOp to avoid recursive matching.
- linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
+ rewriter.updateRootInPlace(linalgOp, [&]() {
+ linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
+ });
// Already sorted.
if (order.isIdentity())
- return failure();
+ return success();
assert(order.isPermutation());
// `order` is orignial loop -> sorted loop map