diff options
author | Matthias Springer <me@m-sp.org> | 2023-12-07 12:05:20 +0900 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-07 12:05:20 +0900 |
commit | 986287e7f38321165c0c654f3af06e34af7b161f (patch) | |
tree | 5f75b37fd1fdb7c71aa7634b7bab9edc26c56b12 /mlir | |
parent | cdd81e3be3df65a966879abef590e36f73e7dea6 (diff) | |
download | llvm-986287e7f38321165c0c654f3af06e34af7b161f.zip llvm-986287e7f38321165c0c654f3af06e34af7b161f.tar.gz llvm-986287e7f38321165c0c654f3af06e34af7b161f.tar.bz2 |
[mlir][SparseTensor] Fix invalid API usage in patterns (#74690)
Rewrite patterns must return `success` if the IR was modified. This
commit fixes sparse tensor tests such as
`SparseTensor/sparse_fusion.mlir`,
`SparseTensor/CPU/sparse_reduce_custom.mlir`,
`SparseTensor/CPU/sparse_semiring_select.mlir` when running with
`MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS`.
Diffstat (limited to 'mlir')
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp | 11 | ||||
-rw-r--r-- | mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp | 20 |
2 files changed, 20 insertions, 11 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index f0393e4..dc5ea28 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -422,11 +422,6 @@ public: if (!controlFn(&opOperand)) continue; - // Find the producer of the operand. - FailureOr<ElementwiseOpFusionResult> fusionResult = - fuseElementwiseOps(rewriter, &opOperand); - if (failed(fusionResult)) - return rewriter.notifyMatchFailure(genericOp, "fusion failed"); Operation *producer = opOperand.get().getDefiningOp(); // Do not fuse a sparse-in/dense-out operation, as the @@ -435,6 +430,12 @@ public: !sparse_tensor::hasAnySparseResult(producer)) return failure(); + // Find the producer of the operand. + FailureOr<ElementwiseOpFusionResult> fusionResult = + fuseElementwiseOps(rewriter, &opOperand); + if (failed(fusionResult)) + return rewriter.notifyMatchFailure(genericOp, "fusion failed"); + // Perform the fusion. for (auto [origVal, replacement] : fusionResult->replacements) { rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp index c94ef8b..488079c 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp @@ -38,16 +38,22 @@ struct DemapInsRewriter : public OpRewritePattern<SourceOp> { LogicalResult matchAndRewrite(SourceOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); + // Demaps non-trivial inputs. + bool changed = false; SmallVector<Value> deMappedIns(op->getOperands()); - for (Value &in : deMappedIns) - if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity()) + for (Value &in : deMappedIns) { + if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity()) { in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in); + changed = true; + } + } // CRTP call. OpAdaptor adaptor(deMappedIns, op); - return static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, - rewriter); + LogicalResult status = + static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter); + return changed ? success() : status; } }; @@ -452,11 +458,13 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> { } // Marks the GenericOp to avoid recursive matching. - linalgOp->setAttr(sorted, rewriter.getBoolAttr(true)); + rewriter.updateRootInPlace(linalgOp, [&]() { + linalgOp->setAttr(sorted, rewriter.getBoolAttr(true)); + }); // Already sorted. if (order.isIdentity()) - return failure(); + return success(); assert(order.isPermutation()); // `order` is orignial loop -> sorted loop map |