diff options
author | MaheshRavishankar <1663364+MaheshRavishankar@users.noreply.github.com> | 2024-06-18 09:07:29 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-18 09:07:29 -0700 |
commit | b99d0b34400176cb9183113b96b245400caaf8d8 (patch) | |
tree | 931c34cf060f42549b7b0bffc5ce019d8eda14b8 | |
parent | fcee0333bab6747ca34188f3a781f4fef900b7fe (diff) | |
download | llvm-b99d0b34400176cb9183113b96b245400caaf8d8.zip llvm-b99d0b34400176cb9183113b96b245400caaf8d8.tar.gz llvm-b99d0b34400176cb9183113b96b245400caaf8d8.tar.bz2 |
[mlir][TilingInterface] Update `PartialReductionOpInterface` to get it more in line with `TilingInterface`. (#95460)
The `TilingInterface` methods have return values that allow the
interface implementation to return multiple operations, and also return
tiled values explicitly. This is to avoid the assumption that the
interface needs to return a single operation and that this operation's
results are the expected tiled values. Make the
`PartialReductionOpInterface::tileToPartialReduction` return
`TilingResult` as well for the same reason.
Similarly, make `PartialReductionOpInterface::mergeReductions` also
return a list of generated operations and values to use as replacements.
This is just a refactoring to allow deprecating
`linalg::tileReductionUsingForall` in favor of the
`scf::tileReductionUsingScf` method.
8 files changed, 84 insertions, 47 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 308ce92..05e97be 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -873,9 +873,9 @@ tileToForallOpUsingTileSizes(RewriterBase &builder, TilingInterface op, /// Transformation information returned after reduction tiling. struct ForallReductionTilingResult { /// The partial reduction tiled op generated. - Operation *parallelTiledOp; + SmallVector<Operation *> parallelTiledOps; /// The final reduction operation merging all the partial reductions. - Operation *mergeOp; + SmallVector<Operation *> mergeOps; /// Initial values used for partial reductions. SmallVector<Value> initialValues; /// The `scf.forall` operation that iterate over the tiles. diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h index dac7911..6316f1d 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -261,13 +261,15 @@ lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op); /// Transformation information returned after reduction tiling. struct SCFReductionTilingResult { /// The partial reduction tiled op generated. - Operation *parallelTiledOp; + SmallVector<Operation *> parallelTiledOps; /// The final reduction operation merging all the partial reductions. - Operation *mergeOp; + SmallVector<Operation *> mergeOps; /// Initial values used for reduction. SmallVector<Value> initialValues; /// The loop operations that iterate over the tiles. SmallVector<LoopLikeOpInterface> loops; + /// The replacements to use for the results of the tiled operation. + SmallVector<Value> replacements; }; /// Method to tile a reduction and generate a parallel op within a serial loop. 
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.h b/mlir/include/mlir/Interfaces/TilingInterface.h index ca57049..2f51496 100644 --- a/mlir/include/mlir/Interfaces/TilingInterface.h +++ b/mlir/include/mlir/Interfaces/TilingInterface.h @@ -33,6 +33,15 @@ struct TilingResult { SmallVector<Value> tiledValues; }; +/// Container for the result of merge operation of tiling. +/// - `mergeOps` contains operations created during the merge. +/// - `replacements` contains the values that represents the result of the +/// merge. These are used as replacements for the original tiled operation. +struct MergeResult { + SmallVector<Operation *> mergeOps; + SmallVector<Value> replacements; +}; + } // namespace mlir /// Include the ODS generated interface header files. diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td index 8865aba..3f92786 100644 --- a/mlir/include/mlir/Interfaces/TilingInterface.td +++ b/mlir/include/mlir/Interfaces/TilingInterface.td @@ -360,7 +360,7 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> { less or equal to the tile size. This is meant to be used with `mergeReductions` method which will combine the partial reductions. }], - /*retType=*/"Operation*", + /*retType=*/"FailureOr<TilingResult>", /*methodName=*/"tileToPartialReduction", /*args=*/(ins "OpBuilder &":$b, @@ -371,7 +371,7 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> { "ArrayRef<int>":$reductionDims), /*methodBody=*/"", /*defaultImplementation=*/[{ - return nullptr; + return failure(); }] >, InterfaceMethod< @@ -380,7 +380,7 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> { tiled along the reduction dimensions. This will only apply the reduction the operation. 
}], - /*retType=*/"Operation*", + /*retType=*/"FailureOr<MergeResult>", /*methodName=*/"mergeReductions", /*args=*/(ins "OpBuilder &":$b, @@ -389,7 +389,7 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> { "ArrayRef<int>":$reductionDim), /*methodBody=*/"", /*defaultImplementation=*/[{ - return nullptr; + return failure(); }] > ]; diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 9b31217..2807b3c 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -2525,8 +2525,10 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne( return emitDefaultSilenceableFailure(target); for (Value initValue : result->initialValues) results.push_back(initValue.getDefiningOp()); - results.push_back(result->parallelTiledOp); - results.push_back(result->mergeOp); + for (auto parallelTiledOp : result->parallelTiledOps) + results.push_back(parallelTiledOp); + for (auto mergeOp : result->mergeOps) + results.push_back(mergeOp); results.push_back(result->loops.front()); return DiagnosedSilenceableFailure::success(); } @@ -2577,8 +2579,10 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne( } for (Value initValue : result->initialValues) results.push_back(initValue.getDefiningOp()); - results.push_back(result->parallelTiledOp); - results.push_back(result->mergeOp); + for (auto parallelTiledOp : result->parallelTiledOps) + results.push_back(parallelTiledOp); + for (auto mergeOp : result->mergeOps) + results.push_back(mergeOp); results.push_back(result->loops); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index a0a0e11..d8dee82 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ 
b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -833,16 +833,19 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall( // 7. Merge the partial reductions. b.setInsertionPointAfter(forallOp); - Operation *mergeOp = + FailureOr<MergeResult> mergeResult = op.mergeReductions(b, loc, forallOp->getResults(), reductionDim); - b.replaceOp(op, mergeOp->getResults()); + if (failed(mergeResult)) { + return failure(); + } + b.replaceOp(op, mergeResult->replacements); // 8. Return. ForallReductionTilingResult results; results.initialValues = initTensors; results.loops = forallOp; - results.parallelTiledOp = tiledOp; - results.mergeOp = mergeOp; + results.parallelTiledOps.push_back(tiledOp); + results.mergeOps.append(mergeResult->mergeOps); return results; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index c3ab3ce..b2a1e7c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -368,11 +368,11 @@ struct LinalgOpPartialReductionInterface return inits; } - Operation *tileToPartialReduction(Operation *op, OpBuilder &b, Location loc, - ValueRange init, - ArrayRef<OpFoldResult> offsets, - ArrayRef<OpFoldResult> sizes, - ArrayRef<int> reductionDims) const { + FailureOr<TilingResult> + tileToPartialReduction(Operation *op, OpBuilder &b, Location loc, + ValueRange init, ArrayRef<OpFoldResult> offsets, + ArrayRef<OpFoldResult> sizes, + ArrayRef<int> reductionDims) const { OpBuilder::InsertionGuard guard(b); auto linalgOp = cast<LinalgOp>(op); @@ -437,12 +437,15 @@ struct LinalgOpPartialReductionInterface IRMapping mapping; op->getRegion(0).cloneInto(&genericOp.getRegion(), genericOp.getRegion().begin(), mapping); - return genericOp.getOperation(); + return TilingResult{ + {genericOp.getOperation()}, + llvm::map_to_vector(genericOp->getResults(), + [](OpResult r) -> Value { return r; 
})}; } - Operation *mergeReductions(Operation *op, OpBuilder &b, Location loc, - ValueRange partialReduce, - ArrayRef<int> reductionDims) const { + FailureOr<MergeResult> mergeReductions(Operation *op, OpBuilder &b, + Location loc, ValueRange partialReduce, + ArrayRef<int> reductionDims) const { auto linalgOp = cast<LinalgOp>(op); // Step 1. Recover the dims that actually need to be merged from the @@ -493,7 +496,10 @@ struct LinalgOpPartialReductionInterface } b.create<linalg::YieldOp>(loc, yieldedValues); }); - return reduction.getOperation(); + return MergeResult{ + {reduction.getOperation()}, + llvm::map_to_vector(reduction->getResults(), + [](OpResult r) -> Value { return r; })}; } }; diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index f3d6b7a..35edd49 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -718,7 +718,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b, SmallVector<Value> &initTensors = maybeInitTensors.value(); // 3. Define the callback to use for generating the inner most tile loop body. - Operation *parallelOp = nullptr; + SmallVector<Operation *> parallelTiledOps; auto innerYieldTiledValuesFn = [&](RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange regionIterArgs, SmallVector<Value> &tiledResult, @@ -743,26 +743,33 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b, } // 4a. Clone the operation. - auto clonedOp = cast<PartialReductionOpInterface>( - cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs)); + { + auto clonedOp = cast<PartialReductionOpInterface>( + cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs)); + + // 4b. Tile the cloned operation. 
+ FailureOr<TilingResult> partialTilingResult = + clonedOp.tileToPartialReduction(b, loc, regionIterArgs, offsets, + sizes, reductionDims); + if (failed(partialTilingResult)) { + return failure(); + } + std::swap(parallelTiledOps, partialTilingResult->tiledOps); + std::swap(tiledResult, partialTilingResult->tiledValues); - // 4b. Tile the cloned operation. - parallelOp = clonedOp.tileToPartialReduction(b, loc, regionIterArgs, - offsets, sizes, reductionDims); - // 4c. Delete the cloned operation. - b.eraseOp(clonedOp); + // 4c. Delete the cloned operation. + b.eraseOp(clonedOp); + } - tiledResult.append(parallelOp->result_begin(), parallelOp->result_end()); // 4d. Compute the offsets and sizes needed to insert the result of the // tiled value back into destination before yielding the destination. - for (int resultIdx : llvm::seq<int>(0, parallelOp->getNumResults())) { + for (auto result : tiledResult) { SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0)); resultOffsets.emplace_back(std::move(outOffsets)); SmallVector<OpFoldResult> outSizes; for (size_t i = 0; i < offsets.size(); i++) { - outSizes.push_back( - tensor::getMixedSize(b, loc, parallelOp->getResult(resultIdx), i)); + outSizes.push_back(tensor::getMixedSize(b, loc, result, i)); } resultSizes.emplace_back(std::move(outSizes)); } @@ -782,15 +789,21 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b, // 5. Apply the merge reduction to combine all the partial values. 
b.setInsertionPointAfter(*loops.begin()); - Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDims); - b.replaceOp(op, mergeOp->getResults()); - - SCFReductionTilingResult results; - results.initialValues = initTensors; - results.loops = loops; - results.parallelTiledOp = parallelOp; - results.mergeOp = mergeOp; - return results; + FailureOr<MergeResult> mergeResult = + op.mergeReductions(b, loc, replacements, reductionDims); + if (failed(mergeResult)) { + return failure(); + } + b.replaceOp(op, mergeResult->replacements); + + SCFReductionTilingResult reductionTilingResult; + std::swap(reductionTilingResult.parallelTiledOps, parallelTiledOps); + std::swap(reductionTilingResult.mergeOps, mergeResult->mergeOps); + std::swap(reductionTilingResult.initialValues, initTensors); + std::swap(reductionTilingResult.loops, loops); + std::swap(reductionTilingResult.replacements, mergeResult->replacements); + + return reductionTilingResult; } //===----------------------------------------------------------------------===// |