diff options
author | Boian Petkantchin <boian.petkantchin@amd.com> | 2024-01-09 13:42:56 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-09 13:42:56 -0800 |
commit | ab590377a371d8099829f77ab4e67c24f8740bd9 (patch) | |
tree | 83c9691be7757c5b3de40ecd55392082449bab00 | |
parent | cd101ab76bdee8d2583ae7b0dfbae9a745373731 (diff) | |
download | llvm-ab590377a371d8099829f77ab4e67c24f8740bd9.zip llvm-ab590377a371d8099829f77ab4e67c24f8740bd9.tar.gz llvm-ab590377a371d8099829f77ab4e67c24f8740bd9.tar.bz2 |
[mlir][mesh] Add folding of ClusterShapeOp (#77033)
If the mesh has a static size on some of the requested axes, the corresponding
results are replaced with constants; only the dynamic axes are still queried.
-rw-r--r-- | mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h | 10 | ||||
-rw-r--r-- | mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp | 93 | ||||
-rw-r--r-- | mlir/test/Dialect/Mesh/folding.mlir | 22 | ||||
-rw-r--r-- | mlir/test/lib/Dialect/Mesh/CMakeLists.txt | 2 | ||||
-rw-r--r-- | mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp | 8 | ||||
-rw-r--r-- | mlir/tools/mlir-opt/CMakeLists.txt | 2 |
6 files changed, 131 insertions, 6 deletions
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h index f70bdaa..f438465 100644 --- a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h +++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h @@ -19,6 +19,9 @@ #include <utility> namespace mlir { + +class SymbolTableCollection; + namespace mesh { // If we have an algebraic op like "+" and a summing all-reduce, @@ -102,7 +105,12 @@ void populateAllReduceEndomorphismSimplificationPatterns( AlgebraicOp::getOperationName(), 1, patterns.getContext())); } -void populateSimplificationPatterns(RewritePatternSet &patterns); +// It is invalid to change ops that declare symbols during the application of +// these patterns, because symbolTableCollection is used to cache them. +void populateSimplificationPatterns( + RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection); +void populateFoldingPatterns(RewritePatternSet &patterns, + SymbolTableCollection &symbolTableCollection); } // namespace mesh } // namespace mlir diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp index 643bd7b..6262d3a 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp @@ -8,11 +8,23 @@ #include "mlir/Dialect/Mesh/Transforms/Simplifications.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Mesh/IR/MeshOps.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include <iterator> +#include <numeric> +#include <utility> namespace mlir { namespace mesh { -void populateSimplificationPatterns(RewritePatternSet &patterns) { +void populateSimplificationPatterns( + 
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) { populateAllReduceEndomorphismSimplificationPatterns<arith::AddFOp>( patterns, Partial::Sum); populateAllReduceEndomorphismSimplificationPatterns<arith::AddIOp>( @@ -33,6 +45,85 @@ void populateSimplificationPatterns(RewritePatternSet &patterns) { patterns, Partial::Max); // TODO: add simplifications for all-gather and other collectives. + + populateFoldingPatterns(patterns, symbolTableCollection); +} + +namespace { + +// This folding can not be done with an operation's fold method or +// DialectFoldInterface, because it needs a SymbolTableCollection to cache the +// symbol tables. +// We can't use DialectFoldInterface since the cache may be invalidated by some +// pass changing the referenced ClusterOp ops. +struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> { + template <typename... OpRewritePatternArgs> + ClusterShapeFolder(SymbolTableCollection &symbolTableCollection, + OpRewritePatternArgs &&...opRewritePatternArgs) + : OpRewritePattern( + std::forward<OpRewritePatternArgs...>(opRewritePatternArgs)...), + symbolTableCollection(symbolTableCollection) {} + LogicalResult matchAndRewrite(ClusterShapeOp op, + PatternRewriter &rewriter) const override { + ImplicitLocOpBuilder builder(op->getLoc(), rewriter); + ClusterOp mesh = + symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>( + op.getOperation(), op.getMeshAttr()); + if (!mesh) { + return failure(); + } + ArrayRef<MeshAxis> opMeshAxes = op.getAxes(); + SmallVector<MeshAxis> opAxesIota; + if (opMeshAxes.empty()) { + opAxesIota.resize(mesh.getRank()); + std::iota(opAxesIota.begin(), opAxesIota.end(), 0); + opMeshAxes = opAxesIota; + } + if (llvm::all_of(opMeshAxes, [&mesh](MeshAxis axis) { + return ShapedType::isDynamic(mesh.getDimSizes()[axis]); + })) { + // All mesh dimensions are dynamic. Nothing to fold. 
+ return failure(); + } + + SmallVector<Value> newResults(op->getResults().size()); + SmallVector<MeshAxis> newShapeOpMeshAxes; + SmallVector<size_t> newToOldResultsIndexMap; + + for (size_t i = 0; i < opMeshAxes.size(); ++i) { + auto meshAxisSize = mesh.getDimSizes()[opMeshAxes[i]]; + if (ShapedType::isDynamic(meshAxisSize)) { + newToOldResultsIndexMap.push_back(i); + newShapeOpMeshAxes.push_back(opMeshAxes[i]); + } else { + // Fold static mesh axes. + newResults[i] = builder.create<arith::ConstantOp>( + builder.getIndexAttr(meshAxisSize)); + } + } + + // Leave only the dynamic mesh axes to be queried. + ClusterShapeOp newShapeOp = + builder.create<ClusterShapeOp>(mesh.getSymName(), newShapeOpMeshAxes); + for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) { + newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i]; + } + + rewriter.replaceAllUsesWith(op.getResults(), newResults); + + return success(); + } + +private: + SymbolTableCollection &symbolTableCollection; +}; + +} // namespace + +void populateFoldingPatterns(RewritePatternSet &patterns, + SymbolTableCollection &symbolTableCollection) { + patterns.add<ClusterShapeFolder>(symbolTableCollection, + patterns.getContext()); } } // namespace mesh diff --git a/mlir/test/Dialect/Mesh/folding.mlir b/mlir/test/Dialect/Mesh/folding.mlir new file mode 100644 index 0000000..dd64d74 --- /dev/null +++ b/mlir/test/Dialect/Mesh/folding.mlir @@ -0,0 +1,22 @@ +// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s + +mesh.cluster @mesh0(rank = 3, dim_sizes = 4x?x2) +mesh.cluster @mesh1(rank = 2, dim_sizes = 2x3) + +// CHECK-LABEL: func.func @cluster_shape_op_folding +func.func @cluster_shape_op_folding() -> (index, index) { + // CHECK: %[[AXIS_2_SIZE:.*]] = arith.constant 2 : index + // CHECK: %[[AXIS_1_SIZE:.*]] = mesh.cluster_shape @mesh0 axes = [1] : index + %0:2 = mesh.cluster_shape @mesh0 axes = [2, 1] : index, index + // CHECK: return %[[AXIS_2_SIZE]], %[[AXIS_1_SIZE]] + return %0#0, %0#1 
: index, index +} + +// CHECK-LABEL: func.func @cluster_shape_op_folding_all_axes_static_mesh +func.func @cluster_shape_op_folding_all_axes_static_mesh() -> (index, index) { + // CHECK: %[[AXIS_0_SIZE:.*]] = arith.constant 2 : index + // CHECK: %[[AXIS_1_SIZE:.*]] = arith.constant 3 : index + %0:2 = mesh.cluster_shape @mesh1 : index, index + // CHECK: return %[[AXIS_0_SIZE]], %[[AXIS_1_SIZE]] + return %0#0, %0#1 : index, index +} diff --git a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt index f14d282..daff882 100644 --- a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt @@ -1,5 +1,5 @@ # Exclude tests from libMLIR.so -add_mlir_library(MLIRMeshTestSimplifications +add_mlir_library(MLIRMeshTest TestReshardingSpmdization.cpp TestSimplifications.cpp diff --git a/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp b/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp index 93b1da5..12a5fd5 100644 --- a/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp +++ b/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Mesh/IR/MeshOps.h" #include "mlir/Dialect/Mesh/Transforms/Simplifications.h" +#include "mlir/IR/SymbolTable.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -30,8 +31,11 @@ struct TestMeshSimplificationsPass void TestMeshSimplificationsPass::runOnOperation() { RewritePatternSet patterns(&getContext()); - mesh::populateSimplificationPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + SymbolTableCollection symbolTableCollection; + mesh::populateSimplificationPatterns(patterns, symbolTableCollection); + LogicalResult status = + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + assert(succeeded(status) && "Rewrite patters application did not converge."); } namespace mlir { diff --git 
a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt index ce2f5bf..9ad5b32 100644 --- a/mlir/tools/mlir-opt/CMakeLists.txt +++ b/mlir/tools/mlir-opt/CMakeLists.txt @@ -26,7 +26,7 @@ if(MLIR_INCLUDE_TESTS) MLIRLoopLikeInterfaceTestPasses MLIRMathTestPasses MLIRMemRefTestPasses - MLIRMeshTestSimplifications + MLIRMeshTest MLIRNVGPUTestPasses MLIRSCFTestPasses MLIRShapeTestPasses |