diff options
Diffstat (limited to 'mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp')
-rw-r--r-- | mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 14 |
1 files changed, 8 insertions, 6 deletions
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index ecf9983..88819cd9 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3269,22 +3269,24 @@ DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne( transform::ApplyToEachResultList &results, transform::TransformState &state) { rewriter.setInsertionPoint(target); - if (!isElementwise(target)) { - failed(rewriter.notifyMatchFailure( - target, "only elementwise flattening is supported")); - return emitDefaultSilenceableFailure(target); - } + if (!isElementwise(target)) + return mlir::emitSilenceableFailure(target->getLoc()) + << "only elementwise flattening is supported"; + // If rank <= 1, do nothing if (target.getNumLoops() <= 1) { results.push_back(target); return DiagnosedSilenceableFailure::success(); } + + // Attempt to flatten all dims to one. ReassociationIndices reassociation(target.getNumLoops()); std::iota(reassociation.begin(), reassociation.end(), 0); auto maybeFlattened = collapseOpIterationDims(target, reassociation, rewriter); if (failed(maybeFlattened)) - return emitDefaultSilenceableFailure(target); + return mlir::emitSilenceableFailure(target->getLoc()) + << "attempted to flatten, but failed"; results.push_back(maybeFlattened->collapsedOp); rewriter.replaceOp(target, maybeFlattened->results); return DiagnosedSilenceableFailure::success(); |