From b99d0b34400176cb9183113b96b245400caaf8d8 Mon Sep 17 00:00:00 2001 From: MaheshRavishankar <1663364+MaheshRavishankar@users.noreply.github.com> Date: Tue, 18 Jun 2024 09:07:29 -0700 Subject: [mlir][TilingInterface] Update `PartialReductionOpInterface` to get it more in line with `TilingInterface`. (#95460) The `TilingInterface` methods have return values that allow the interface implementation to return multiple operations, and also return tiled values explicitly. This is to avoid the assumption that the interface needs to return a single operation and this operations result are the expected tiled values. Make the `PartialReductionOpInterface::tileToPartialReduction` return `TilingResult` as well for the same reason. Similarly make the `PartialReductionOpInterface::mergeReductions` also return a list of generated operations and values to use as replacements. This is just a refactoring to allow for deprecation of `linalg::tileReductionUsingForall` with `scf::tileReductionUsingSCF` method. --- .../mlir/Dialect/Linalg/Transforms/Transforms.h | 4 +- .../Dialect/SCF/Transforms/TileUsingInterface.h | 6 ++- mlir/include/mlir/Interfaces/TilingInterface.h | 9 ++++ mlir/include/mlir/Interfaces/TilingInterface.td | 8 ++-- .../Linalg/TransformOps/LinalgTransformOps.cpp | 12 +++-- mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 11 +++-- .../Linalg/Transforms/TilingInterfaceImpl.cpp | 26 ++++++---- .../Dialect/SCF/Transforms/TileUsingInterface.cpp | 55 +++++++++++++--------- 8 files changed, 84 insertions(+), 47 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 308ce92..05e97be 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -873,9 +873,9 @@ tileToForallOpUsingTileSizes(RewriterBase &builder, TilingInterface op, /// Transformation information returned after reduction tiling. struct ForallReductionTilingResult { /// The partial reduction tiled op generated. - Operation *parallelTiledOp; + SmallVector parallelTiledOps; /// The final reduction operation merging all the partial reductions. - Operation *mergeOp; + SmallVector mergeOps; /// Initial values used for partial reductions. SmallVector initialValues; /// The `scf.forall` operation that iterate over the tiles. diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h index dac7911..6316f1d 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -261,13 +261,15 @@ lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op); /// Transformation information returned after reduction tiling. struct SCFReductionTilingResult { /// The partial reduction tiled op generated. - Operation *parallelTiledOp; + SmallVector parallelTiledOps; /// The final reduction operation merging all the partial reductions. - Operation *mergeOp; + SmallVector mergeOps; /// Initial values used for reduction. SmallVector initialValues; /// The loop operations that iterate over the tiles. SmallVector loops; + /// The replacements to use for the results of the tiled operation. + SmallVector replacements; }; /// Method to tile a reduction and generate a parallel op within a serial loop. diff --git a/mlir/include/mlir/Interfaces/TilingInterface.h b/mlir/include/mlir/Interfaces/TilingInterface.h index ca57049..2f51496 100644 --- a/mlir/include/mlir/Interfaces/TilingInterface.h +++ b/mlir/include/mlir/Interfaces/TilingInterface.h @@ -33,6 +33,15 @@ struct TilingResult { SmallVector tiledValues; }; +/// Container for the result of merge operation of tiling. +/// - `mergeOps` contains operations created during the merge. +/// - `replacements` contains the values that represents the result of the +/// merge. These are used as replacements for the original tiled operation. +struct MergeResult { + SmallVector mergeOps; + SmallVector replacements; +}; + } // namespace mlir /// Include the ODS generated interface header files. diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td index 8865aba..3f92786 100644 --- a/mlir/include/mlir/Interfaces/TilingInterface.td +++ b/mlir/include/mlir/Interfaces/TilingInterface.td @@ -360,7 +360,7 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> { less or equal to the tile size. This is meant to be used with `mergeReductions` method which will combine the partial reductions. }], - /*retType=*/"Operation*", + /*retType=*/"FailureOr", /*methodName=*/"tileToPartialReduction", /*args=*/(ins "OpBuilder &":$b, @@ -371,7 +371,7 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> { "ArrayRef":$reductionDims), /*methodBody=*/"", /*defaultImplementation=*/[{ - return nullptr; + return failure(); }] >, InterfaceMethod< @@ -380,7 +380,7 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> { tiled along the reduction dimensions. This will only apply the reduction the operation. }], - /*retType=*/"Operation*", + /*retType=*/"FailureOr", /*methodName=*/"mergeReductions", /*args=*/(ins "OpBuilder &":$b, @@ -389,7 +389,7 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> { "ArrayRef":$reductionDim), /*methodBody=*/"", /*defaultImplementation=*/[{ - return nullptr; + return failure(); }] > ]; diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 9b31217..2807b3c 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -2525,8 +2525,10 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne( return emitDefaultSilenceableFailure(target); for (Value initValue : result->initialValues) results.push_back(initValue.getDefiningOp()); - results.push_back(result->parallelTiledOp); - results.push_back(result->mergeOp); + for (auto parallelTiledOp : result->parallelTiledOps) + results.push_back(parallelTiledOp); + for (auto mergeOp : result->mergeOps) + results.push_back(mergeOp); results.push_back(result->loops.front()); return DiagnosedSilenceableFailure::success(); } @@ -2577,8 +2579,10 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne( } for (Value initValue : result->initialValues) results.push_back(initValue.getDefiningOp()); - results.push_back(result->parallelTiledOp); - results.push_back(result->mergeOp); + for (auto parallelTiledOp : result->parallelTiledOps) + results.push_back(parallelTiledOp); + for (auto mergeOp : result->mergeOps) + results.push_back(mergeOp); results.push_back(result->loops); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index a0a0e11..d8dee82 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -833,16 +833,19 @@ FailureOr linalg::tileReductionUsingForall( // 7. Merge the partial reductions. b.setInsertionPointAfter(forallOp); - Operation *mergeOp = + FailureOr mergeResult = op.mergeReductions(b, loc, forallOp->getResults(), reductionDim); - b.replaceOp(op, mergeOp->getResults()); + if (failed(mergeResult)) { + return failure(); + } + b.replaceOp(op, mergeResult->replacements); // 8. Return. ForallReductionTilingResult results; results.initialValues = initTensors; results.loops = forallOp; - results.parallelTiledOp = tiledOp; - results.mergeOp = mergeOp; + results.parallelTiledOps.push_back(tiledOp); + results.mergeOps.append(mergeResult->mergeOps); return results; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index c3ab3ce..b2a1e7c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -368,11 +368,11 @@ struct LinalgOpPartialReductionInterface return inits; } - Operation *tileToPartialReduction(Operation *op, OpBuilder &b, Location loc, - ValueRange init, - ArrayRef offsets, - ArrayRef sizes, - ArrayRef reductionDims) const { + FailureOr + tileToPartialReduction(Operation *op, OpBuilder &b, Location loc, + ValueRange init, ArrayRef offsets, + ArrayRef sizes, + ArrayRef reductionDims) const { OpBuilder::InsertionGuard guard(b); auto linalgOp = cast(op); @@ -437,12 +437,15 @@ struct LinalgOpPartialReductionInterface IRMapping mapping; op->getRegion(0).cloneInto(&genericOp.getRegion(), genericOp.getRegion().begin(), mapping); - return genericOp.getOperation(); + return TilingResult{ + {genericOp.getOperation()}, + llvm::map_to_vector(genericOp->getResults(), + [](OpResult r) -> Value { return r; })}; } - Operation *mergeReductions(Operation *op, OpBuilder &b, Location loc, - ValueRange partialReduce, - ArrayRef reductionDims) const { + FailureOr mergeReductions(Operation *op, OpBuilder &b, + Location loc, ValueRange partialReduce, + ArrayRef reductionDims) const { auto linalgOp = cast(op); // Step 1. Recover the dims that actually need to be merged from the @@ -493,7 +496,10 @@ struct LinalgOpPartialReductionInterface } b.create(loc, yieldedValues); }); - return reduction.getOperation(); + return MergeResult{ + {reduction.getOperation()}, + llvm::map_to_vector(reduction->getResults(), + [](OpResult r) -> Value { return r; })}; } }; diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index f3d6b7a..35edd49 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -718,7 +718,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b, SmallVector &initTensors = maybeInitTensors.value(); // 3. Define the callback to use for generating the inner most tile loop body. - Operation *parallelOp = nullptr; + SmallVector parallelTiledOps; auto innerYieldTiledValuesFn = [&](RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange regionIterArgs, SmallVector &tiledResult, @@ -743,26 +743,33 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b, } // 4a. Clone the operation. - auto clonedOp = cast( - cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs)); + { + auto clonedOp = cast( + cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs)); + + // 4b. Tile the cloned operation. + FailureOr partialTilingResult = + clonedOp.tileToPartialReduction(b, loc, regionIterArgs, offsets, + sizes, reductionDims); + if (failed(partialTilingResult)) { + return failure(); + } + std::swap(parallelTiledOps, partialTilingResult->tiledOps); + std::swap(tiledResult, partialTilingResult->tiledValues); - // 4b. Tile the cloned operation. - parallelOp = clonedOp.tileToPartialReduction(b, loc, regionIterArgs, - offsets, sizes, reductionDims); - // 4c. Delete the cloned operation. - b.eraseOp(clonedOp); + // 4c. Delete the cloned operation. + b.eraseOp(clonedOp); + } - tiledResult.append(parallelOp->result_begin(), parallelOp->result_end()); // 4d. Compute the offsets and sizes needed to insert the result of the // tiled value back into destination before yielding the destination. - for (int resultIdx : llvm::seq(0, parallelOp->getNumResults())) { + for (auto result : tiledResult) { SmallVector outOffsets(offsets.size(), b.getIndexAttr(0)); resultOffsets.emplace_back(std::move(outOffsets)); SmallVector outSizes; for (size_t i = 0; i < offsets.size(); i++) { - outSizes.push_back( - tensor::getMixedSize(b, loc, parallelOp->getResult(resultIdx), i)); + outSizes.push_back(tensor::getMixedSize(b, loc, result, i)); } resultSizes.emplace_back(std::move(outSizes)); } @@ -782,15 +789,21 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b, // 5. Apply the merge reduction to combine all the partial values. b.setInsertionPointAfter(*loops.begin()); - Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDims); - b.replaceOp(op, mergeOp->getResults()); - - SCFReductionTilingResult results; - results.initialValues = initTensors; - results.loops = loops; - results.parallelTiledOp = parallelOp; - results.mergeOp = mergeOp; - return results; + FailureOr mergeResult = + op.mergeReductions(b, loc, replacements, reductionDims); + if (failed(mergeResult)) { + return failure(); + } + b.replaceOp(op, mergeResult->replacements); + + SCFReductionTilingResult reductionTilingResult; + std::swap(reductionTilingResult.parallelTiledOps, parallelTiledOps); + std::swap(reductionTilingResult.mergeOps, mergeResult->mergeOps); + std::swap(reductionTilingResult.initialValues, initTensors); + std::swap(reductionTilingResult.loops, loops); + std::swap(reductionTilingResult.replacements, mergeResult->replacements); + + return reductionTilingResult; } //===----------------------------------------------------------------------===// -- cgit v1.1