aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorWilliam Moses <gh@wsmoses.com>2025-05-22 14:13:32 -0500
committerGitHub <noreply@github.com>2025-05-22 14:13:32 -0500
commit6a8dde04a07287f837bbabeb93e23e47af366d3d (patch)
treef5b31b078f67eca8308668397b30355dbd2b4b47
parent4a6b1fb9dac5bb3a065b1e6e5192c4c7964af7e0 (diff)
downloadllvm-6a8dde04a07287f837bbabeb93e23e47af366d3d.zip
llvm-6a8dde04a07287f837bbabeb93e23e47af366d3d.tar.gz
llvm-6a8dde04a07287f837bbabeb93e23e47af366d3d.tar.bz2
[MLIR] Change getBackwardSlice to return a logicalresult rather than crash (#140961)
The current implementation of getBackwardSlice will crash if an operation in the dependency chain is defined by an operation with multiple regions or blocks. Crashing is bad (and forbids many analyses from using getBackwardSlice, as well as causing existing users of getBackwardSlice to fail for IR with this property). This PR instead causes the analysis to return a failure, rather than crash in the cases it cannot compute the full slice --------- Co-authored-by: Oleksandr "Alex" Zinenko <git@ozinenko.com>
-rw-r--r--mlir/include/mlir/Analysis/SliceAnalysis.h12
-rw-r--r--mlir/include/mlir/Query/Matcher/SliceMatchers.h3
-rw-r--r--mlir/lib/Analysis/SliceAnalysis.cpp58
-rw-r--r--mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp4
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp7
-rw-r--r--mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp6
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp3
-rw-r--r--mlir/lib/Transforms/Utils/RegionUtils.cpp6
-rw-r--r--mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp4
-rw-r--r--mlir/test/lib/IR/TestSlicing.cpp3
10 files changed, 67 insertions, 39 deletions
diff --git a/mlir/include/mlir/Analysis/SliceAnalysis.h b/mlir/include/mlir/Analysis/SliceAnalysis.h
index 3b731e8..d082d2d 100644
--- a/mlir/include/mlir/Analysis/SliceAnalysis.h
+++ b/mlir/include/mlir/Analysis/SliceAnalysis.h
@@ -138,13 +138,17 @@ void getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
/// Assuming all local orders match the numbering order:
/// {1, 2, 5, 3, 4, 6}
///
-void getBackwardSlice(Operation *op, SetVector<Operation *> *backwardSlice,
- const BackwardSliceOptions &options = {});
+/// This function returns whether the backwards slice was able to be
+/// successfully computed, and failure if it was unable to determine the slice.
+LogicalResult getBackwardSlice(Operation *op,
+ SetVector<Operation *> *backwardSlice,
+ const BackwardSliceOptions &options = {});
/// Value-rooted version of `getBackwardSlice`. Return the union of all backward
/// slices for the op defining or owning the value `root`.
-void getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice,
- const BackwardSliceOptions &options = {});
+LogicalResult getBackwardSlice(Value root,
+ SetVector<Operation *> *backwardSlice,
+ const BackwardSliceOptions &options = {});
/// Iteratively computes backward slices and forward slices until
/// a fixed point is reached. Returns an `SetVector<Operation *>` which
diff --git a/mlir/include/mlir/Query/Matcher/SliceMatchers.h b/mlir/include/mlir/Query/Matcher/SliceMatchers.h
index 1b0e4c3..40a39d2 100644
--- a/mlir/include/mlir/Query/Matcher/SliceMatchers.h
+++ b/mlir/include/mlir/Query/Matcher/SliceMatchers.h
@@ -112,7 +112,8 @@ bool BackwardSliceMatcher<Matcher>::matches(
}
return true;
};
- getBackwardSlice(rootOp, &backwardSlice, options);
+ LogicalResult result = getBackwardSlice(rootOp, &backwardSlice, options);
+ assert(result.succeeded() && "expected backward slice to succeed");
return options.inclusive ? backwardSlice.size() > 1
: backwardSlice.size() >= 1;
}
diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index 5aebb19..12b9d3a 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -80,25 +80,25 @@ void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
forwardSlice->insert(v.rbegin(), v.rend());
}
-static void getBackwardSliceImpl(Operation *op,
- SetVector<Operation *> *backwardSlice,
- const BackwardSliceOptions &options) {
+static LogicalResult getBackwardSliceImpl(Operation *op,
+ SetVector<Operation *> *backwardSlice,
+ const BackwardSliceOptions &options) {
if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
- return;
+ return success();
// Evaluate whether we should keep this def.
// This is useful in particular to implement scoping; i.e. return the
// transitive backwardSlice in the current scope.
if (options.filter && !options.filter(op))
- return;
+ return success();
auto processValue = [&](Value value) {
if (auto *definingOp = value.getDefiningOp()) {
if (backwardSlice->count(definingOp) == 0)
- getBackwardSliceImpl(definingOp, backwardSlice, options);
+ return getBackwardSliceImpl(definingOp, backwardSlice, options);
} else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
if (options.omitBlockArguments)
- return;
+ return success();
Block *block = blockArg.getOwner();
Operation *parentOp = block->getParentOp();
@@ -106,15 +106,17 @@ static void getBackwardSliceImpl(Operation *op,
// blocks of parentOp, which are not technically backward unless they flow
// into us. For now, just bail.
if (parentOp && backwardSlice->count(parentOp) == 0) {
- assert(parentOp->getNumRegions() == 1 &&
- llvm::hasSingleElement(parentOp->getRegion(0).getBlocks()));
- getBackwardSliceImpl(parentOp, backwardSlice, options);
+ if (parentOp->getNumRegions() == 1 &&
+ llvm::hasSingleElement(parentOp->getRegion(0).getBlocks())) {
+ return getBackwardSliceImpl(parentOp, backwardSlice, options);
+ }
}
- } else {
- llvm_unreachable("No definingOp and not a block argument.");
}
+ return failure();
};
+ bool succeeded = true;
+
if (!options.omitUsesFromAbove) {
llvm::for_each(op->getRegions(), [&](Region &region) {
// Walk this region recursively to collect the regions that descend from
@@ -125,36 +127,41 @@ static void getBackwardSliceImpl(Operation *op,
region.walk([&](Operation *op) {
for (OpOperand &operand : op->getOpOperands()) {
if (!descendents.contains(operand.get().getParentRegion()))
- processValue(operand.get());
+ if (!processValue(operand.get()).succeeded()) {
+ return WalkResult::interrupt();
+ }
}
+ return WalkResult::advance();
});
});
}
llvm::for_each(op->getOperands(), processValue);
backwardSlice->insert(op);
+ return success(succeeded);
}
-void mlir::getBackwardSlice(Operation *op,
- SetVector<Operation *> *backwardSlice,
- const BackwardSliceOptions &options) {
- getBackwardSliceImpl(op, backwardSlice, options);
+LogicalResult mlir::getBackwardSlice(Operation *op,
+ SetVector<Operation *> *backwardSlice,
+ const BackwardSliceOptions &options) {
+ LogicalResult result = getBackwardSliceImpl(op, backwardSlice, options);
if (!options.inclusive) {
// Don't insert the top level operation, we just queried on it and don't
// want it in the results.
backwardSlice->remove(op);
}
+ return result;
}
-void mlir::getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice,
- const BackwardSliceOptions &options) {
+LogicalResult mlir::getBackwardSlice(Value root,
+ SetVector<Operation *> *backwardSlice,
+ const BackwardSliceOptions &options) {
if (Operation *definingOp = root.getDefiningOp()) {
- getBackwardSlice(definingOp, backwardSlice, options);
- return;
+ return getBackwardSlice(definingOp, backwardSlice, options);
}
Operation *bbAargOwner = cast<BlockArgument>(root).getOwner()->getParentOp();
- getBackwardSlice(bbAargOwner, backwardSlice, options);
+ return getBackwardSlice(bbAargOwner, backwardSlice, options);
}
SetVector<Operation *>
@@ -170,7 +177,9 @@ mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions,
auto *currentOp = (slice)[currentIndex];
// Compute and insert the backwardSlice starting from currentOp.
backwardSlice.clear();
- getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
+ LogicalResult result =
+ getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
+ assert(result.succeeded());
slice.insert_range(backwardSlice);
// Compute and insert the forwardSlice starting from currentOp.
@@ -193,7 +202,8 @@ static bool dependsOnCarriedVals(Value value,
sliceOptions.filter = [&](Operation *op) {
return !ancestorOp->isAncestor(op);
};
- getBackwardSlice(value, &slice, sliceOptions);
+ LogicalResult result = getBackwardSlice(value, &slice, sliceOptions);
+ assert(result.succeeded());
// Check that none of the operands of the operations in the backward slice are
// loop iteration arguments, and neither is the value itself.
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 8b16da3..0ec9ddc 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -317,7 +317,9 @@ getSliceContract(Operation *op,
auto *currentOp = (slice)[currentIndex];
// Compute and insert the backwardSlice starting from currentOp.
backwardSlice.clear();
- getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
+ LogicalResult result =
+ getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
+ assert(result.succeeded() && "expected a backward slice");
slice.insert_range(backwardSlice);
// Compute and insert the forwardSlice starting from currentOp.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index d33a17a..2c98bd3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -124,10 +124,13 @@ static void computeBackwardSlice(tensor::PadOp padOp,
getUsedValuesDefinedAbove(padOp.getRegion(), padOp.getRegion(),
valuesDefinedAbove);
for (Value v : valuesDefinedAbove) {
- getBackwardSlice(v, &backwardSlice, sliceOptions);
+ LogicalResult result = getBackwardSlice(v, &backwardSlice, sliceOptions);
+ assert(result.succeeded() && "expected a backward slice");
}
// Then, add the backward slice from padOp itself.
- getBackwardSlice(padOp.getOperation(), &backwardSlice, sliceOptions);
+ LogicalResult result =
+ getBackwardSlice(padOp.getOperation(), &backwardSlice, sliceOptions);
+ assert(result.succeeded() && "expected a backward slice");
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
index 75dbe0b..1046f57 100644
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -290,8 +290,10 @@ static void getPipelineStages(
});
options.inclusive = true;
for (Operation &op : forOp.getBody()->getOperations()) {
- if (stage0Ops.contains(&op))
- getBackwardSlice(&op, &dependencies, options);
+ if (stage0Ops.contains(&op)) {
+ LogicalResult result = getBackwardSlice(&op, &dependencies, options);
+ assert(result.succeeded() && "expected a backward slice");
+ }
}
for (Operation &op : forOp.getBody()->getOperations()) {
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 719e2c6..9e3d3f8 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1772,7 +1772,8 @@ checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp,
};
llvm::SetVector<Operation *> slice;
for (auto operand : consumerOp->getOperands()) {
- getBackwardSlice(operand, &slice, options);
+ LogicalResult result = getBackwardSlice(operand, &slice, options);
+ assert(result.succeeded() && "expected a backward slice");
}
if (!slice.empty()) {
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 4985d71..c136ff9 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -1094,7 +1094,8 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
return !dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
};
llvm::SetVector<Operation *> slice;
- getBackwardSlice(op, &slice, options);
+ LogicalResult result = getBackwardSlice(op, &slice, options);
+ assert(result.succeeded() && "expected a backward slice");
// If the slice contains `insertionPoint` cannot move the dependencies.
if (slice.contains(insertionPoint)) {
@@ -1159,7 +1160,8 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
};
llvm::SetVector<Operation *> slice;
for (auto value : prunedValues) {
- getBackwardSlice(value, &slice, options);
+ LogicalResult result = getBackwardSlice(value, &slice, options);
+ assert(result.succeeded() && "expected a backward slice");
}
// If the slice contains `insertionPoint` cannot move the dependencies.
diff --git a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
index f26058f..145acd9 100644
--- a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
@@ -154,7 +154,9 @@ void VectorizerTestPass::testBackwardSlicing(llvm::raw_ostream &outs) {
patternTestSlicingOps().match(f, &matches);
for (auto m : matches) {
SetVector<Operation *> backwardSlice;
- getBackwardSlice(m.getMatchedOperation(), &backwardSlice);
+ LogicalResult result =
+ getBackwardSlice(m.getMatchedOperation(), &backwardSlice);
+ assert(result.succeeded() && "expected a backward slice");
outs << "\nmatched: " << *m.getMatchedOperation()
<< " backward static slice: ";
for (auto *op : backwardSlice)
diff --git a/mlir/test/lib/IR/TestSlicing.cpp b/mlir/test/lib/IR/TestSlicing.cpp
index e99d597..ad99be2 100644
--- a/mlir/test/lib/IR/TestSlicing.cpp
+++ b/mlir/test/lib/IR/TestSlicing.cpp
@@ -41,7 +41,8 @@ static LogicalResult createBackwardSliceFunction(Operation *op,
options.omitBlockArguments = omitBlockArguments;
// TODO: Make this default.
options.omitUsesFromAbove = false;
- getBackwardSlice(op, &slice, options);
+ LogicalResult result = getBackwardSlice(op, &slice, options);
+ assert(result.succeeded() && "expected a backward slice");
for (Operation *slicedOp : slice)
builder.clone(*slicedOp, mapper);
builder.create<func::ReturnOp>(loc);