aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMaheshRavishankar <1663364+MaheshRavishankar@users.noreply.github.com>2024-06-18 09:07:29 -0700
committerGitHub <noreply@github.com>2024-06-18 09:07:29 -0700
commitb99d0b34400176cb9183113b96b245400caaf8d8 (patch)
tree931c34cf060f42549b7b0bffc5ce019d8eda14b8
parentfcee0333bab6747ca34188f3a781f4fef900b7fe (diff)
downloadllvm-b99d0b34400176cb9183113b96b245400caaf8d8.zip
llvm-b99d0b34400176cb9183113b96b245400caaf8d8.tar.gz
llvm-b99d0b34400176cb9183113b96b245400caaf8d8.tar.bz2
[mlir][TilingInterface] Update `PartialReductionOpInterface` to get it more in line with `TilingInterface`. (#95460)
The `TilingInterface` methods have return values that allow the interface implementation to return multiple operations, and also to return the tiled values explicitly. This avoids the assumption that the interface must return a single operation whose results are the expected tiled values. Make `PartialReductionOpInterface::tileToPartialReduction` return a `TilingResult` as well, for the same reason. Similarly, make `PartialReductionOpInterface::mergeReductions` return a list of generated operations and the values to use as replacements. This is just a refactoring to allow for the deprecation of `linalg::tileReductionUsingForall` in favor of the `scf::tileReductionUsingScf` method.
-rw-r--r--mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h4
-rw-r--r--mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h6
-rw-r--r--mlir/include/mlir/Interfaces/TilingInterface.h9
-rw-r--r--mlir/include/mlir/Interfaces/TilingInterface.td8
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp12
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp11
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp26
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp55
8 files changed, 84 insertions, 47 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 308ce92..05e97be 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -873,9 +873,9 @@ tileToForallOpUsingTileSizes(RewriterBase &builder, TilingInterface op,
/// Transformation information returned after reduction tiling.
struct ForallReductionTilingResult {
/// The partial reduction tiled op generated.
- Operation *parallelTiledOp;
+ SmallVector<Operation *> parallelTiledOps;
/// The final reduction operation merging all the partial reductions.
- Operation *mergeOp;
+ SmallVector<Operation *> mergeOps;
/// Initial values used for partial reductions.
SmallVector<Value> initialValues;
/// The `scf.forall` operation that iterate over the tiles.
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index dac7911..6316f1d 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -261,13 +261,15 @@ lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op);
/// Transformation information returned after reduction tiling.
struct SCFReductionTilingResult {
/// The partial reduction tiled op generated.
- Operation *parallelTiledOp;
+ SmallVector<Operation *> parallelTiledOps;
/// The final reduction operation merging all the partial reductions.
- Operation *mergeOp;
+ SmallVector<Operation *> mergeOps;
/// Initial values used for reduction.
SmallVector<Value> initialValues;
/// The loop operations that iterate over the tiles.
SmallVector<LoopLikeOpInterface> loops;
+ /// The replacements to use for the results of the tiled operation.
+ SmallVector<Value> replacements;
};
/// Method to tile a reduction and generate a parallel op within a serial loop.
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.h b/mlir/include/mlir/Interfaces/TilingInterface.h
index ca57049..2f51496 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.h
+++ b/mlir/include/mlir/Interfaces/TilingInterface.h
@@ -33,6 +33,15 @@ struct TilingResult {
SmallVector<Value> tiledValues;
};
+/// Container for the result of merge operation of tiling.
+/// - `mergeOps` contains operations created during the merge.
+/// - `replacements` contains the values that represents the result of the
+/// merge. These are used as replacements for the original tiled operation.
+struct MergeResult {
+ SmallVector<Operation *> mergeOps;
+ SmallVector<Value> replacements;
+};
+
} // namespace mlir
/// Include the ODS generated interface header files.
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 8865aba..3f92786 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -360,7 +360,7 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
less or equal to the tile size. This is meant to be used with
`mergeReductions` method which will combine the partial reductions.
}],
- /*retType=*/"Operation*",
+ /*retType=*/"FailureOr<TilingResult>",
/*methodName=*/"tileToPartialReduction",
/*args=*/(ins
"OpBuilder &":$b,
@@ -371,7 +371,7 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
"ArrayRef<int>":$reductionDims),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return nullptr;
+ return failure();
}]
>,
InterfaceMethod<
@@ -380,7 +380,7 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
tiled along the reduction dimensions. This will only apply the
reduction the operation.
}],
- /*retType=*/"Operation*",
+ /*retType=*/"FailureOr<MergeResult>",
/*methodName=*/"mergeReductions",
/*args=*/(ins
"OpBuilder &":$b,
@@ -389,7 +389,7 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
"ArrayRef<int>":$reductionDim),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return nullptr;
+ return failure();
}]
>
];
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 9b31217..2807b3c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2525,8 +2525,10 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
return emitDefaultSilenceableFailure(target);
for (Value initValue : result->initialValues)
results.push_back(initValue.getDefiningOp());
- results.push_back(result->parallelTiledOp);
- results.push_back(result->mergeOp);
+ for (auto parallelTiledOp : result->parallelTiledOps)
+ results.push_back(parallelTiledOp);
+ for (auto mergeOp : result->mergeOps)
+ results.push_back(mergeOp);
results.push_back(result->loops.front());
return DiagnosedSilenceableFailure::success();
}
@@ -2577,8 +2579,10 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
}
for (Value initValue : result->initialValues)
results.push_back(initValue.getDefiningOp());
- results.push_back(result->parallelTiledOp);
- results.push_back(result->mergeOp);
+ for (auto parallelTiledOp : result->parallelTiledOps)
+ results.push_back(parallelTiledOp);
+ for (auto mergeOp : result->mergeOps)
+ results.push_back(mergeOp);
results.push_back(result->loops);
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index a0a0e11..d8dee82 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -833,16 +833,19 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
// 7. Merge the partial reductions.
b.setInsertionPointAfter(forallOp);
- Operation *mergeOp =
+ FailureOr<MergeResult> mergeResult =
op.mergeReductions(b, loc, forallOp->getResults(), reductionDim);
- b.replaceOp(op, mergeOp->getResults());
+ if (failed(mergeResult)) {
+ return failure();
+ }
+ b.replaceOp(op, mergeResult->replacements);
// 8. Return.
ForallReductionTilingResult results;
results.initialValues = initTensors;
results.loops = forallOp;
- results.parallelTiledOp = tiledOp;
- results.mergeOp = mergeOp;
+ results.parallelTiledOps.push_back(tiledOp);
+ results.mergeOps.append(mergeResult->mergeOps);
return results;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index c3ab3ce..b2a1e7c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -368,11 +368,11 @@ struct LinalgOpPartialReductionInterface
return inits;
}
- Operation *tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
- ValueRange init,
- ArrayRef<OpFoldResult> offsets,
- ArrayRef<OpFoldResult> sizes,
- ArrayRef<int> reductionDims) const {
+ FailureOr<TilingResult>
+ tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
+ ValueRange init, ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ ArrayRef<int> reductionDims) const {
OpBuilder::InsertionGuard guard(b);
auto linalgOp = cast<LinalgOp>(op);
@@ -437,12 +437,15 @@ struct LinalgOpPartialReductionInterface
IRMapping mapping;
op->getRegion(0).cloneInto(&genericOp.getRegion(),
genericOp.getRegion().begin(), mapping);
- return genericOp.getOperation();
+ return TilingResult{
+ {genericOp.getOperation()},
+ llvm::map_to_vector(genericOp->getResults(),
+ [](OpResult r) -> Value { return r; })};
}
- Operation *mergeReductions(Operation *op, OpBuilder &b, Location loc,
- ValueRange partialReduce,
- ArrayRef<int> reductionDims) const {
+ FailureOr<MergeResult> mergeReductions(Operation *op, OpBuilder &b,
+ Location loc, ValueRange partialReduce,
+ ArrayRef<int> reductionDims) const {
auto linalgOp = cast<LinalgOp>(op);
// Step 1. Recover the dims that actually need to be merged from the
@@ -493,7 +496,10 @@ struct LinalgOpPartialReductionInterface
}
b.create<linalg::YieldOp>(loc, yieldedValues);
});
- return reduction.getOperation();
+ return MergeResult{
+ {reduction.getOperation()},
+ llvm::map_to_vector(reduction->getResults(),
+ [](OpResult r) -> Value { return r; })};
}
};
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index f3d6b7a..35edd49 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -718,7 +718,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
SmallVector<Value> &initTensors = maybeInitTensors.value();
// 3. Define the callback to use for generating the inner most tile loop body.
- Operation *parallelOp = nullptr;
+ SmallVector<Operation *> parallelTiledOps;
auto innerYieldTiledValuesFn =
[&](RewriterBase &rewriter, Location loc, ValueRange ivs,
ValueRange regionIterArgs, SmallVector<Value> &tiledResult,
@@ -743,26 +743,33 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
}
// 4a. Clone the operation.
- auto clonedOp = cast<PartialReductionOpInterface>(
- cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs));
+ {
+ auto clonedOp = cast<PartialReductionOpInterface>(
+ cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs));
+
+ // 4b. Tile the cloned operation.
+ FailureOr<TilingResult> partialTilingResult =
+ clonedOp.tileToPartialReduction(b, loc, regionIterArgs, offsets,
+ sizes, reductionDims);
+ if (failed(partialTilingResult)) {
+ return failure();
+ }
+ std::swap(parallelTiledOps, partialTilingResult->tiledOps);
+ std::swap(tiledResult, partialTilingResult->tiledValues);
- // 4b. Tile the cloned operation.
- parallelOp = clonedOp.tileToPartialReduction(b, loc, regionIterArgs,
- offsets, sizes, reductionDims);
- // 4c. Delete the cloned operation.
- b.eraseOp(clonedOp);
+ // 4c. Delete the cloned operation.
+ b.eraseOp(clonedOp);
+ }
- tiledResult.append(parallelOp->result_begin(), parallelOp->result_end());
// 4d. Compute the offsets and sizes needed to insert the result of the
// tiled value back into destination before yielding the destination.
- for (int resultIdx : llvm::seq<int>(0, parallelOp->getNumResults())) {
+ for (auto result : tiledResult) {
SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
resultOffsets.emplace_back(std::move(outOffsets));
SmallVector<OpFoldResult> outSizes;
for (size_t i = 0; i < offsets.size(); i++) {
- outSizes.push_back(
- tensor::getMixedSize(b, loc, parallelOp->getResult(resultIdx), i));
+ outSizes.push_back(tensor::getMixedSize(b, loc, result, i));
}
resultSizes.emplace_back(std::move(outSizes));
}
@@ -782,15 +789,21 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
// 5. Apply the merge reduction to combine all the partial values.
b.setInsertionPointAfter(*loops.begin());
- Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDims);
- b.replaceOp(op, mergeOp->getResults());
-
- SCFReductionTilingResult results;
- results.initialValues = initTensors;
- results.loops = loops;
- results.parallelTiledOp = parallelOp;
- results.mergeOp = mergeOp;
- return results;
+ FailureOr<MergeResult> mergeResult =
+ op.mergeReductions(b, loc, replacements, reductionDims);
+ if (failed(mergeResult)) {
+ return failure();
+ }
+ b.replaceOp(op, mergeResult->replacements);
+
+ SCFReductionTilingResult reductionTilingResult;
+ std::swap(reductionTilingResult.parallelTiledOps, parallelTiledOps);
+ std::swap(reductionTilingResult.mergeOps, mergeResult->mergeOps);
+ std::swap(reductionTilingResult.initialValues, initTensors);
+ std::swap(reductionTilingResult.loops, loops);
+ std::swap(reductionTilingResult.replacements, mergeResult->replacements);
+
+ return reductionTilingResult;
}
//===----------------------------------------------------------------------===//