diff options
author | William Moses <gh@wsmoses.com> | 2025-05-22 14:13:32 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-05-22 14:13:32 -0500 |
commit | 6a8dde04a07287f837bbabeb93e23e47af366d3d (patch) | |
tree | f5b31b078f67eca8308668397b30355dbd2b4b47 | |
parent | 4a6b1fb9dac5bb3a065b1e6e5192c4c7964af7e0 (diff) | |
download | llvm-6a8dde04a07287f837bbabeb93e23e47af366d3d.zip llvm-6a8dde04a07287f837bbabeb93e23e47af366d3d.tar.gz llvm-6a8dde04a07287f837bbabeb93e23e47af366d3d.tar.bz2 |
[MLIR] Change getBackwardSlice to return a logicalresult rather than crash (#140961)
The current implementation of getBackwardSlice will crash if an
operation in the dependency chain is defined by an operation with
multiple regions or blocks. Crashing is bad (and forbids many analyses
from using getBackwardSlice, as well as causing existing users of
getBackwardSlice to fail for IR with this property).
This PR instead causes the analysis to return a failure, rather than
crash in the cases it cannot compute the full slice
---------
Co-authored-by: Oleksandr "Alex" Zinenko <git@ozinenko.com>
-rw-r--r-- | mlir/include/mlir/Analysis/SliceAnalysis.h | 12 | ||||
-rw-r--r-- | mlir/include/mlir/Query/Matcher/SliceMatchers.h | 3 | ||||
-rw-r--r-- | mlir/lib/Analysis/SliceAnalysis.cpp | 58 | ||||
-rw-r--r-- | mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp | 4 | ||||
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp | 7 | ||||
-rw-r--r-- | mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp | 6 | ||||
-rw-r--r-- | mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp | 3 | ||||
-rw-r--r-- | mlir/lib/Transforms/Utils/RegionUtils.cpp | 6 | ||||
-rw-r--r-- | mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp | 4 | ||||
-rw-r--r-- | mlir/test/lib/IR/TestSlicing.cpp | 3 |
10 files changed, 67 insertions, 39 deletions
diff --git a/mlir/include/mlir/Analysis/SliceAnalysis.h b/mlir/include/mlir/Analysis/SliceAnalysis.h index 3b731e8..d082d2d 100644 --- a/mlir/include/mlir/Analysis/SliceAnalysis.h +++ b/mlir/include/mlir/Analysis/SliceAnalysis.h @@ -138,13 +138,17 @@ void getForwardSlice(Value root, SetVector<Operation *> *forwardSlice, /// Assuming all local orders match the numbering order: /// {1, 2, 5, 3, 4, 6} /// -void getBackwardSlice(Operation *op, SetVector<Operation *> *backwardSlice, - const BackwardSliceOptions &options = {}); +/// This function returns whether the backwards slice was able to be +/// successfully computed, and failure if it was unable to determine the slice. +LogicalResult getBackwardSlice(Operation *op, + SetVector<Operation *> *backwardSlice, + const BackwardSliceOptions &options = {}); /// Value-rooted version of `getBackwardSlice`. Return the union of all backward /// slices for the op defining or owning the value `root`. -void getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice, - const BackwardSliceOptions &options = {}); +LogicalResult getBackwardSlice(Value root, + SetVector<Operation *> *backwardSlice, + const BackwardSliceOptions &options = {}); /// Iteratively computes backward slices and forward slices until /// a fixed point is reached. Returns an `SetVector<Operation *>` which diff --git a/mlir/include/mlir/Query/Matcher/SliceMatchers.h b/mlir/include/mlir/Query/Matcher/SliceMatchers.h index 1b0e4c3..40a39d2 100644 --- a/mlir/include/mlir/Query/Matcher/SliceMatchers.h +++ b/mlir/include/mlir/Query/Matcher/SliceMatchers.h @@ -112,7 +112,8 @@ bool BackwardSliceMatcher<Matcher>::matches( } return true; }; - getBackwardSlice(rootOp, &backwardSlice, options); + LogicalResult result = getBackwardSlice(rootOp, &backwardSlice, options); + assert(result.succeeded() && "expected backward slice to succeed"); return options.inclusive ? backwardSlice.size() > 1 : backwardSlice.size() >= 1; } diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp index 5aebb19..12b9d3a 100644 --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -80,25 +80,25 @@ void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice, forwardSlice->insert(v.rbegin(), v.rend()); } -static void getBackwardSliceImpl(Operation *op, - SetVector<Operation *> *backwardSlice, - const BackwardSliceOptions &options) { +static LogicalResult getBackwardSliceImpl(Operation *op, + SetVector<Operation *> *backwardSlice, + const BackwardSliceOptions &options) { if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>()) - return; + return success(); // Evaluate whether we should keep this def. // This is useful in particular to implement scoping; i.e. return the // transitive backwardSlice in the current scope. if (options.filter && !options.filter(op)) - return; + return success(); auto processValue = [&](Value value) { if (auto *definingOp = value.getDefiningOp()) { if (backwardSlice->count(definingOp) == 0) - getBackwardSliceImpl(definingOp, backwardSlice, options); + return getBackwardSliceImpl(definingOp, backwardSlice, options); } else if (auto blockArg = dyn_cast<BlockArgument>(value)) { if (options.omitBlockArguments) - return; + return success(); Block *block = blockArg.getOwner(); Operation *parentOp = block->getParentOp(); @@ -106,15 +106,17 @@ static void getBackwardSliceImpl(Operation *op, // blocks of parentOp, which are not technically backward unless they flow // into us. For now, just bail. if (parentOp && backwardSlice->count(parentOp) == 0) { - assert(parentOp->getNumRegions() == 1 && - llvm::hasSingleElement(parentOp->getRegion(0).getBlocks())); - getBackwardSliceImpl(parentOp, backwardSlice, options); + if (parentOp->getNumRegions() == 1 && + llvm::hasSingleElement(parentOp->getRegion(0).getBlocks())) { + return getBackwardSliceImpl(parentOp, backwardSlice, options); + } } - } else { - llvm_unreachable("No definingOp and not a block argument."); } + return failure(); }; + bool succeeded = true; + if (!options.omitUsesFromAbove) { llvm::for_each(op->getRegions(), [&](Region ®ion) { // Walk this region recursively to collect the regions that descend from @@ -125,36 +127,41 @@ static void getBackwardSliceImpl(Operation *op, region.walk([&](Operation *op) { for (OpOperand &operand : op->getOpOperands()) { if (!descendents.contains(operand.get().getParentRegion())) - processValue(operand.get()); + if (!processValue(operand.get()).succeeded()) { + return WalkResult::interrupt(); + } } + return WalkResult::advance(); }); }); } llvm::for_each(op->getOperands(), processValue); backwardSlice->insert(op); + return success(succeeded); } -void mlir::getBackwardSlice(Operation *op, - SetVector<Operation *> *backwardSlice, - const BackwardSliceOptions &options) { - getBackwardSliceImpl(op, backwardSlice, options); +LogicalResult mlir::getBackwardSlice(Operation *op, + SetVector<Operation *> *backwardSlice, + const BackwardSliceOptions &options) { + LogicalResult result = getBackwardSliceImpl(op, backwardSlice, options); if (!options.inclusive) { // Don't insert the top level operation, we just queried on it and don't // want it in the results. backwardSlice->remove(op); } + return result; } -void mlir::getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice, - const BackwardSliceOptions &options) { +LogicalResult mlir::getBackwardSlice(Value root, + SetVector<Operation *> *backwardSlice, + const BackwardSliceOptions &options) { if (Operation *definingOp = root.getDefiningOp()) { - getBackwardSlice(definingOp, backwardSlice, options); - return; + return getBackwardSlice(definingOp, backwardSlice, options); } Operation *bbAargOwner = cast<BlockArgument>(root).getOwner()->getParentOp(); - getBackwardSlice(bbAargOwner, backwardSlice, options); + return getBackwardSlice(bbAargOwner, backwardSlice, options); } SetVector<Operation *> @@ -170,7 +177,9 @@ mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions, auto *currentOp = (slice)[currentIndex]; // Compute and insert the backwardSlice starting from currentOp. backwardSlice.clear(); - getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions); + LogicalResult result = + getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions); + assert(result.succeeded()); slice.insert_range(backwardSlice); // Compute and insert the forwardSlice starting from currentOp. @@ -193,7 +202,8 @@ static bool dependsOnCarriedVals(Value value, sliceOptions.filter = [&](Operation *op) { return !ancestorOp->isAncestor(op); }; - getBackwardSlice(value, &slice, sliceOptions); + LogicalResult result = getBackwardSlice(value, &slice, sliceOptions); + assert(result.succeeded()); // Check that none of the operands of the operations in the backward slice are // loop iteration arguments, and neither is the value itself. diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp index 8b16da3..0ec9ddc 100644 --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -317,7 +317,9 @@ getSliceContract(Operation *op, auto *currentOp = (slice)[currentIndex]; // Compute and insert the backwardSlice starting from currentOp. backwardSlice.clear(); - getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions); + LogicalResult result = + getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions); + assert(result.succeeded() && "expected a backward slice"); slice.insert_range(backwardSlice); // Compute and insert the forwardSlice starting from currentOp. diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp index d33a17a..2c98bd3 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp @@ -124,10 +124,13 @@ static void computeBackwardSlice(tensor::PadOp padOp, getUsedValuesDefinedAbove(padOp.getRegion(), padOp.getRegion(), valuesDefinedAbove); for (Value v : valuesDefinedAbove) { - getBackwardSlice(v, &backwardSlice, sliceOptions); + LogicalResult result = getBackwardSlice(v, &backwardSlice, sliceOptions); + assert(result.succeeded() && "expected a backward slice"); } // Then, add the backward slice from padOp itself. - getBackwardSlice(padOp.getOperation(), &backwardSlice, sliceOptions); + LogicalResult result = + getBackwardSlice(padOp.getOperation(), &backwardSlice, sliceOptions); + assert(result.succeeded() && "expected a backward slice"); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp index 75dbe0b..1046f57 100644 --- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp +++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp @@ -290,8 +290,10 @@ static void getPipelineStages( }); options.inclusive = true; for (Operation &op : forOp.getBody()->getOperations()) { - if (stage0Ops.contains(&op)) - getBackwardSlice(&op, &dependencies, options); + if (stage0Ops.contains(&op)) { + LogicalResult result = getBackwardSlice(&op, &dependencies, options); + assert(result.succeeded() && "expected a backward slice"); + } } for (Operation &op : forOp.getBody()->getOperations()) { diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 719e2c6..9e3d3f8 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -1772,7 +1772,8 @@ checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp, }; llvm::SetVector<Operation *> slice; for (auto operand : consumerOp->getOperands()) { - getBackwardSlice(operand, &slice, options); + LogicalResult result = getBackwardSlice(operand, &slice, options); + assert(result.succeeded() && "expected a backward slice"); } if (!slice.empty()) { diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp index 4985d71..c136ff9 100644 --- a/mlir/lib/Transforms/Utils/RegionUtils.cpp +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -1094,7 +1094,8 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter, return !dominance.properlyDominates(sliceBoundaryOp, insertionPoint); }; llvm::SetVector<Operation *> slice; - getBackwardSlice(op, &slice, options); + LogicalResult result = getBackwardSlice(op, &slice, options); + assert(result.succeeded() && "expected a backward slice"); // If the slice contains `insertionPoint` cannot move the dependencies. if (slice.contains(insertionPoint)) { @@ -1159,7 +1160,8 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter, }; llvm::SetVector<Operation *> slice; for (auto value : prunedValues) { - getBackwardSlice(value, &slice, options); + LogicalResult result = getBackwardSlice(value, &slice, options); + assert(result.succeeded() && "expected a backward slice"); } // If the slice contains `insertionPoint` cannot move the dependencies. diff --git a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp index f26058f..145acd9 100644 --- a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp +++ b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp @@ -154,7 +154,9 @@ void VectorizerTestPass::testBackwardSlicing(llvm::raw_ostream &outs) { patternTestSlicingOps().match(f, &matches); for (auto m : matches) { SetVector<Operation *> backwardSlice; - getBackwardSlice(m.getMatchedOperation(), &backwardSlice); + LogicalResult result = + getBackwardSlice(m.getMatchedOperation(), &backwardSlice); + assert(result.succeeded() && "expected a backward slice"); outs << "\nmatched: " << *m.getMatchedOperation() << " backward static slice: "; for (auto *op : backwardSlice) diff --git a/mlir/test/lib/IR/TestSlicing.cpp b/mlir/test/lib/IR/TestSlicing.cpp index e99d597..ad99be2 100644 --- a/mlir/test/lib/IR/TestSlicing.cpp +++ b/mlir/test/lib/IR/TestSlicing.cpp @@ -41,7 +41,8 @@ static LogicalResult createBackwardSliceFunction(Operation *op, options.omitBlockArguments = omitBlockArguments; // TODO: Make this default. options.omitUsesFromAbove = false; - getBackwardSlice(op, &slice, options); + LogicalResult result = getBackwardSlice(op, &slice, options); + assert(result.succeeded() && "expected a backward slice"); for (Operation *slicedOp : slice) builder.clone(*slicedOp, mapper); builder.create<func::ReturnOp>(loc); |