aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthias Springer <me@m-sp.org>2024-06-24 21:35:58 +0200
committerGitHub <noreply@github.com>2024-06-24 21:35:58 +0200
commitf2d3d829b97a221c9ce3a3467a20ea51bb29ecbd (patch)
tree78b299ac332f7453707fabaca58568bcb1e8c716
parent09c0337a581dfd8f39df131786cfc7f675adf483 (diff)
downloadllvm-f2d3d829b97a221c9ce3a3467a20ea51bb29ecbd.zip
llvm-f2d3d829b97a221c9ce3a3467a20ea51bb29ecbd.tar.gz
llvm-f2d3d829b97a221c9ce3a3467a20ea51bb29ecbd.tar.bz2
[mlir][linalg][Transform] Fix use-after-free in `SplitOp::apply` (#96390)
Detected with ASAN. `Operation::getLoc()` was called after erasing the operation. Reverts 48cf6b6bbe7a22bfcd98f82dc7afd21c9decd22f, which attempted to fix the use-after-free. (But the use-after-free is still there when the `hasFailed` branch is taken.)
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp24
1 files changed, 13 insertions, 11 deletions
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 4ef27b1..4eb334f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2314,7 +2314,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
}
} else {
chunkSizes.resize(payload.size(),
- rewriter.getIndexAttr(getStaticChunkSizes()));
+ rewriter.getIndexAttr(getStaticChunkSizes()));
}
auto checkStructuredOpAndDimensions =
@@ -2327,7 +2327,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
if (getDimension() >= linalgOp.getNumLoops()) {
auto diag = emitSilenceableError() << "dimension " << getDimension()
- << " does not exist in target op";
+ << " does not exist in target op";
diag.attachNote(loc) << "target op";
return diag;
}
@@ -2335,10 +2335,10 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
};
auto checkFailureInSplitting =
- [&](bool hasFailed, Operation *op) -> DiagnosedSilenceableFailure {
+ [&](bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {
if (hasFailed) {
auto diag = emitDefiniteFailure() << "internal failure in splitting";
- diag.attachNote(op->getLoc()) << "target op";
+ diag.attachNote(loc) << "target op";
return diag;
}
return DiagnosedSilenceableFailure::success();
@@ -2368,6 +2368,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
break;
linalgOp = cast<LinalgOp>(target);
+ Location loc = target->getLoc();
rewriter.setInsertionPoint(linalgOp);
std::tie(head, tail) = linalg::splitOp(
@@ -2376,7 +2377,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
// Propagate errors.
DiagnosedSilenceableFailure diag =
- checkFailureInSplitting(!head && !tail, target);
+ checkFailureInSplitting(!head && !tail, loc);
if (diag.isDefiniteFailure())
return diag;
@@ -2395,6 +2396,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
Operation *noSecondPart = nullptr;
for (const auto &pair : llvm::zip(payload, chunkSizes)) {
Operation *target = std::get<0>(pair);
+ Location loc = target->getLoc();
LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
DiagnosedSilenceableFailure diag =
checkStructuredOpAndDimensions(linalgOp, target->getLoc());
@@ -2409,7 +2411,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
// Propagate errors.
DiagnosedSilenceableFailure diagSplit =
- checkFailureInSplitting(!first.back() && !second.back(), target);
+ checkFailureInSplitting(!first.back() && !second.back(), loc);
if (diagSplit.isDefiniteFailure())
return diag;
@@ -2718,8 +2720,8 @@ transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
- return builder.getI64IntegerAttr(value);
- });
+ return builder.getI64IntegerAttr(value);
+ });
};
transformResults.setParams(cast<OpResult>(getTileSizes()),
getI64AttrsFromI64(spec->tileSizes));
@@ -2756,9 +2758,9 @@ transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
}
auto getDefiningOps = [&](ArrayRef<Value> values) {
- return llvm::map_to_vector(values, [&](Value value) -> Operation * {
- return value.getDefiningOp();
- });
+ return llvm::map_to_vector(values, [&](Value value) -> Operation * {
+ return value.getDefiningOp();
+ });
};
transformResults.set(cast<OpResult>(getTileSizes()),