aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBoian Petkantchin <boian.petkantchin@amd.com>2023-12-12 10:21:52 -0800
committerGitHub <noreply@github.com>2023-12-12 10:21:52 -0800
commit4b3446771f745bb5169354ad9027c0a1c9fca394 (patch)
treedfbeee3ab5c85d0bcca1365e5d04c785b2423541
parent8063622721d0b2b70e44f6e747eec54cdaec2e76 (diff)
downloadllvm-4b3446771f745bb5169354ad9027c0a1c9fca394.zip
llvm-4b3446771f745bb5169354ad9027c0a1c9fca394.tar.gz
llvm-4b3446771f745bb5169354ad9027c0a1c9fca394.tar.bz2
[mlir][mesh] Add endomorphism simplification for all-reduce (#73150)
Performs transformations like `all_reduce(x) + all_reduce(y) -> all_reduce(x + y)`, and `max(all_reduce(x), all_reduce(y)) -> all_reduce(max(x, y))` when the all_reduce element-wise op is max. Adds the general rewrite patterns HomomorphismSimplification and EndomorphismSimplification that encapsulate the general algorithm. Specializes them for all-reduce with respect to addf, addi, minsi, minui, maxsi, maxui, minimumf and maximumf in the Arith dialect.
-rw-r--r--mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h110
-rw-r--r--mlir/include/mlir/Transforms/EndomorphismSimplification.h93
-rw-r--r--mlir/include/mlir/Transforms/HomomorphismSimplification.h188
-rw-r--r--mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt2
-rw-r--r--mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp39
-rw-r--r--mlir/test/Dialect/Mesh/simplifications.mlir167
-rw-r--r--mlir/test/lib/Dialect/CMakeLists.txt1
-rw-r--r--mlir/test/lib/Dialect/Mesh/CMakeLists.txt13
-rw-r--r--mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp43
-rw-r--r--mlir/tools/mlir-opt/CMakeLists.txt1
-rw-r--r--mlir/tools/mlir-opt/mlir-opt.cpp2
11 files changed, 659 insertions, 0 deletions
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
new file mode 100644
index 0000000..f70bdaa
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
@@ -0,0 +1,110 @@
+//===- Simplifications.h - Mesh Simplifications -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
+#define MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
+
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/EndomorphismSimplification.h"
+#include "llvm/Support/Casting.h"
+#include <algorithm>
+#include <iterator>
+#include <memory>
+#include <utility>
+
+namespace mlir {
+namespace mesh {
+
+// If we have an algebraic op like "+" and a summing all-reduce,
+// `all_reduce_sum(x) + all_reduce_sum(y)` will be transformed to
+// `all_reduce_sum(x + y)`.
+//
+// Another example with `min`.
+// `min(all_reduce_min(x), all_reduce_min(y))` will be transformed to
+// `all_reduce_min(min(x, y))`.
+//
+// Works only with algebraic ops that have all their operands relevant
+// to the all-reduce endomorphism.
+// Will not work with some op `f(x, y, z)` where only `x` and `y` form
+// the algebraic structure.
+//
+// Adds to `patterns` a single pattern rooted at `AlgebraicOp`. The pattern
+// matches only all-reduces whose reduction kind equals `reduction`.
+template <typename AlgebraicOp>
+void populateAllReduceEndomorphismSimplificationPatterns(
+    RewritePatternSet &patterns, Partial reduction) {
+  // The all-reduce input is the only operand relevant to the endomorphism.
+  auto getEndomorphismOpOperand = [](Operation *op) {
+    auto allReduceOp = llvm::cast<AllReduceOp>(op);
+    return &allReduceOp.getInputMutable();
+  };
+  auto getEndomorphismOpResult = [](Operation *op) {
+    auto allReduceOp = llvm::cast<AllReduceOp>(op);
+    return allReduceOp->getResult(0);
+  };
+  // Every operand of the algebraic op takes part in the algebraic structure.
+  auto getAlgebraicOpOperands = [](Operation *op,
+                                   SmallVector<OpOperand *> &operands) {
+    auto algebraicOp = llvm::cast<AlgebraicOp>(op);
+    std::transform(algebraicOp->getOpOperands().begin(),
+                   algebraicOp->getOpOperands().end(),
+                   std::back_inserter(operands),
+                   [](OpOperand &operand) { return &operand; });
+  };
+  auto getAlgebraicOpResult = [](Operation *op) {
+    auto algebraicOp = llvm::cast<AlgebraicOp>(op);
+    return algebraicOp->getResult(0);
+  };
+  auto isEndomorphismOp = [reduction](Operation *op,
+                                      std::optional<Operation *> referenceOp) {
+    auto allReduceOp = llvm::dyn_cast<AllReduceOp>(op);
+    // The op must be an all-reduce that preserves the element type and uses
+    // the requested reduction kind.
+    if (!allReduceOp ||
+        allReduceOp.getInput().getType().getElementType() !=
+            allReduceOp.getResult().getType().getElementType() ||
+        allReduceOp.getReduction() != reduction) {
+      return false;
+    }
+
+    // Don't simplify if the all-reduce has uses other than the algebraic op.
+    // TODO: maybe handle this by an additional pass that later reverses the
+    // simplification if there are other uses left after other optimizations
+    // have been done.
+    if (!allReduceOp->hasOneUse()) {
+      return false;
+    }
+
+    if (!referenceOp) {
+      return true;
+    }
+
+    // The reference op is expected to be an all-reduce. Use `cast`, which
+    // asserts on a mismatch, instead of `dyn_cast`, whose null result would
+    // be dereferenced below.
+    auto refAllReduceOp = llvm::cast<AllReduceOp>(referenceOp.value());
+    // Compatible all-reduces carry identical attributes and operate on the
+    // same element type.
+    return refAllReduceOp->getAttrs() == allReduceOp->getAttrs() &&
+           allReduceOp.getInput().getType().getElementType() ==
+               refAllReduceOp.getInput().getType().getElementType();
+  };
+  auto isAlgebraicOp = [](Operation *op) {
+    return llvm::isa<AlgebraicOp>(op);
+  };
+
+  using ConcreteEndomorphismSimplification = EndomorphismSimplification<
+      std::decay_t<decltype(getEndomorphismOpOperand)>,
+      std::decay_t<decltype(getEndomorphismOpResult)>,
+      std::decay_t<decltype(getAlgebraicOpOperands)>,
+      std::decay_t<decltype(getAlgebraicOpResult)>,
+      std::decay_t<decltype(isEndomorphismOp)>,
+      std::decay_t<decltype(isAlgebraicOp)>>;
+  // The trailing arguments are forwarded to `RewritePattern`: root operation
+  // name, pattern benefit and context.
+  patterns.add(std::make_unique<ConcreteEndomorphismSimplification>(
+      std::move(getEndomorphismOpOperand), std::move(getEndomorphismOpResult),
+      std::move(getAlgebraicOpOperands), std::move(getAlgebraicOpResult),
+      std::move(isEndomorphismOp), std::move(isAlgebraicOp),
+      AlgebraicOp::getOperationName(), 1, patterns.getContext()));
+}
+
+void populateSimplificationPatterns(RewritePatternSet &patterns);
+
+} // namespace mesh
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
diff --git a/mlir/include/mlir/Transforms/EndomorphismSimplification.h b/mlir/include/mlir/Transforms/EndomorphismSimplification.h
new file mode 100644
index 0000000..1aa5da2
--- /dev/null
+++ b/mlir/include/mlir/Transforms/EndomorphismSimplification.h
@@ -0,0 +1,93 @@
+//===- EndomorphismSimplification.h -----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_SIMPLIFY_ENDOMORPHISM_H_
+#define MLIR_TRANSFORMS_SIMPLIFY_ENDOMORPHISM_H_
+
+#include "mlir/Transforms/HomomorphismSimplification.h"
+
+namespace mlir {
+
+namespace detail {
+struct CreateAlgebraicOpForEndomorphismSimplification {
+  // Builds the target algebraic op by cloning `op` through `rewriter`, with
+  // `operandsRemapping` applied during the clone.
+  Operation *operator()(Operation *op, IRMapping &operandsRemapping,
+                        PatternRewriter &rewriter) const {
+    Operation *clonedAlgebraicOp = rewriter.clone(*op, operandsRemapping);
+    return clonedAlgebraicOp;
+  }
+};
+} // namespace detail
+
+// If `f` is an endomorphism with respect to the algebraic structure induced by
+// function `g`, transforms `g(f(x1), f(x2) ..., f(xn))` into
+// `f(g(x1, x2, ..., xn))`.
+// `g` is the algebraic operation and `f` is the endomorphism.
+//
+// Functors:
+// ---------
+// `GetEndomorphismOpOperandFn`: `(Operation*) -> OpOperand*`
+// Returns the operand relevant to the endomorphism.
+// There may be other operands that are not relevant.
+//
+// `GetEndomorphismOpResultFn`: `(Operation*) -> OpResult`
+// Returns the result relevant to the endomorphism.
+//
+// `GetAlgebraicOpOperandsFn`: `(Operation*, SmallVector<OpOperand*>&) -> void`
+// Populates into the vector the operands relevant to the endomorphism.
+//
+// `GetAlgebraicOpResultFn`: `(Operation*) -> OpResult`
+// Return the result relevant to the endomorphism.
+//
+// `IsEndomorphismOpFn`: `(Operation*, std::optional<Operation*>) -> bool`
+// Check if the operation is an endomorphism of the required type.
+// Additionally if the optional is present checks if the operations are
+// compatible endomorphisms.
+//
+// `IsAlgebraicOpFn`: `(Operation*) -> bool`
+// Check if the operation is an operation of the algebraic structure.
+template <typename GetEndomorphismOpOperandFn,
+          typename GetEndomorphismOpResultFn, typename GetAlgebraicOpOperandsFn,
+          typename GetAlgebraicOpResultFn, typename IsEndomorphismOpFn,
+          typename IsAlgebraicOpFn>
+struct EndomorphismSimplification
+    : HomomorphismSimplification<
+          GetEndomorphismOpOperandFn, GetEndomorphismOpResultFn,
+          GetAlgebraicOpOperandsFn, GetAlgebraicOpResultFn,
+          GetAlgebraicOpResultFn, IsEndomorphismOpFn, IsAlgebraicOpFn,
+          detail::CreateAlgebraicOpForEndomorphismSimplification> {
+  // Forwards the functors to `HomomorphismSimplification`. For an
+  // endomorphism the source and target algebraic ops are the same kind of op,
+  // so `getAlgebraicOpResult` fills both the source-result and the
+  // target-result slots of the base class. Any remaining `args` are passed on
+  // to the `RewritePattern` constructor.
+  template <typename GetEndomorphismOpOperandFnArg,
+            typename GetEndomorphismOpResultFnArg,
+            typename GetAlgebraicOpOperandsFnArg,
+            typename GetAlgebraicOpResultFnArg, typename IsEndomorphismOpFnArg,
+            typename IsAlgebraicOpFnArg, typename... RewritePatternArgs>
+  EndomorphismSimplification(
+      GetEndomorphismOpOperandFnArg &&getEndomorphismOpOperand,
+      GetEndomorphismOpResultFnArg &&getEndomorphismOpResult,
+      GetAlgebraicOpOperandsFnArg &&getAlgebraicOpOperands,
+      GetAlgebraicOpResultFnArg &&getAlgebraicOpResult,
+      IsEndomorphismOpFnArg &&isEndomorphismOp,
+      IsAlgebraicOpFnArg &&isAlgebraicOp, RewritePatternArgs &&...args)
+      : HomomorphismSimplification<
+            GetEndomorphismOpOperandFn, GetEndomorphismOpResultFn,
+            GetAlgebraicOpOperandsFn, GetAlgebraicOpResultFn,
+            GetAlgebraicOpResultFn, IsEndomorphismOpFn, IsAlgebraicOpFn,
+            detail::CreateAlgebraicOpForEndomorphismSimplification>(
+            std::forward<GetEndomorphismOpOperandFnArg>(
+                getEndomorphismOpOperand),
+            std::forward<GetEndomorphismOpResultFnArg>(getEndomorphismOpResult),
+            std::forward<GetAlgebraicOpOperandsFnArg>(getAlgebraicOpOperands),
+            // The same functor is needed twice. Pass it as an lvalue (copy)
+            // for the source slot and forward it only for the target slot;
+            // forwarding an rvalue twice would hand a moved-from functor to
+            // the second slot.
+            getAlgebraicOpResult,
+            std::forward<GetAlgebraicOpResultFnArg>(getAlgebraicOpResult),
+            std::forward<IsEndomorphismOpFnArg>(isEndomorphismOp),
+            std::forward<IsAlgebraicOpFnArg>(isAlgebraicOp),
+            detail::CreateAlgebraicOpForEndomorphismSimplification(),
+            std::forward<RewritePatternArgs>(args)...) {}
+};
+
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_SIMPLIFY_ENDOMORPHISM_H_
diff --git a/mlir/include/mlir/Transforms/HomomorphismSimplification.h b/mlir/include/mlir/Transforms/HomomorphismSimplification.h
new file mode 100644
index 0000000..d273260
--- /dev/null
+++ b/mlir/include/mlir/Transforms/HomomorphismSimplification.h
@@ -0,0 +1,188 @@
+//===- HomomorphismSimplification.h -----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_
+#define MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_
+
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include <iterator>
+#include <optional>
+#include <type_traits>
+#include <utility>
+
+namespace mlir {
+
+// If `h` is a homomorphism with respect to the source algebraic structure
+// induced by function `s` and the target algebraic structure induced by
+// function `t`, transforms `s(h(x1), h(x2), ..., h(xn))` into
+// `h(t(x1, x2, ..., xn))`.
+//
+// Functors:
+// ---------
+// `GetHomomorphismOpOperandFn`: `(Operation*) -> OpOperand*`
+// Returns the operand relevant to the homomorphism.
+// There may be other operands that are not relevant.
+//
+// `GetHomomorphismOpResultFn`: `(Operation*) -> OpResult`
+// Returns the result relevant to the homomorphism.
+//
+// `GetSourceAlgebraicOpOperandsFn`: `(Operation*, SmallVector<OpOperand*>&) ->
+// void` Populates into the vector the operands relevant to the homomorphism.
+//
+// `GetSourceAlgebraicOpResultFn`: `(Operation*) -> OpResult`
+// Return the result of the source algebraic operation relevant to the
+// homomorphism.
+//
+// `GetTargetAlgebraicOpResultFn`: `(Operation*) -> OpResult`
+// Return the result of the target algebraic operation relevant to the
+// homomorphism.
+//
+// `IsHomomorphismOpFn`: `(Operation*, std::optional<Operation*>) -> bool`
+// Check if the operation is an homomorphism of the required type.
+// Additionally if the optional is present checks if the operations are
+// compatible homomorphisms.
+//
+// `IsSourceAlgebraicOpFn`: `(Operation*) -> bool`
+// Check if the operation is an operation of the algebraic structure.
+//
+// `CreateTargetAlgebraicOpFn`: `(Operation*, IRMapping& operandsRemapping,
+// PatternRewriter &rewriter) -> Operation*`
+template <typename GetHomomorphismOpOperandFn,
+          typename GetHomomorphismOpResultFn,
+          typename GetSourceAlgebraicOpOperandsFn,
+          typename GetSourceAlgebraicOpResultFn,
+          typename GetTargetAlgebraicOpResultFn, typename IsHomomorphismOpFn,
+          typename IsSourceAlgebraicOpFn, typename CreateTargetAlgebraicOpFn>
+struct HomomorphismSimplification : public RewritePattern {
+  // Stores the functors; any remaining `args` are forwarded to the
+  // `RewritePattern` constructor.
+  template <typename GetHomomorphismOpOperandFnArg,
+            typename GetHomomorphismOpResultFnArg,
+            typename GetSourceAlgebraicOpOperandsFnArg,
+            typename GetSourceAlgebraicOpResultFnArg,
+            typename GetTargetAlgebraicOpResultFnArg,
+            typename IsHomomorphismOpFnArg, typename IsSourceAlgebraicOpFnArg,
+            typename CreateTargetAlgebraicOpFnArg,
+            typename... RewritePatternArgs>
+  HomomorphismSimplification(
+      GetHomomorphismOpOperandFnArg &&getHomomorphismOpOperand,
+      GetHomomorphismOpResultFnArg &&getHomomorphismOpResult,
+      GetSourceAlgebraicOpOperandsFnArg &&getSourceAlgebraicOpOperands,
+      GetSourceAlgebraicOpResultFnArg &&getSourceAlgebraicOpResult,
+      GetTargetAlgebraicOpResultFnArg &&getTargetAlgebraicOpResult,
+      IsHomomorphismOpFnArg &&isHomomorphismOp,
+      IsSourceAlgebraicOpFnArg &&isSourceAlgebraicOp,
+      CreateTargetAlgebraicOpFnArg &&createTargetAlgebraicOpFn,
+      RewritePatternArgs &&...args)
+      : RewritePattern(std::forward<RewritePatternArgs>(args)...),
+        getHomomorphismOpOperand(std::forward<GetHomomorphismOpOperandFnArg>(
+            getHomomorphismOpOperand)),
+        getHomomorphismOpResult(std::forward<GetHomomorphismOpResultFnArg>(
+            getHomomorphismOpResult)),
+        getSourceAlgebraicOpOperands(
+            std::forward<GetSourceAlgebraicOpOperandsFnArg>(
+                getSourceAlgebraicOpOperands)),
+        getSourceAlgebraicOpResult(
+            std::forward<GetSourceAlgebraicOpResultFnArg>(
+                getSourceAlgebraicOpResult)),
+        getTargetAlgebraicOpResult(
+            std::forward<GetTargetAlgebraicOpResultFnArg>(
+                getTargetAlgebraicOpResult)),
+        isHomomorphismOp(std::forward<IsHomomorphismOpFnArg>(isHomomorphismOp)),
+        isSourceAlgebraicOp(
+            std::forward<IsSourceAlgebraicOpFnArg>(isSourceAlgebraicOp)),
+        createTargetAlgebraicOpFn(std::forward<CreateTargetAlgebraicOpFnArg>(
+            createTargetAlgebraicOpFn)) {}
+
+  // Matches `op` as a source algebraic op whose relevant operands all come
+  // from compatible homomorphism ops and, on success, rewrites it.
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<OpOperand *> algebraicOpOperands;
+    if (failed(matchOp(op, algebraicOpOperands))) {
+      return failure();
+    }
+    return rewriteOp(op, algebraicOpOperands, rewriter);
+  }
+
+private:
+  // Succeeds when `sourceAlgebraicOp` is a source algebraic op with at least
+  // one relevant operand and every relevant operand is the relevant result of
+  // a homomorphism op compatible with the one defining the first operand.
+  // The relevant operands are returned in `sourceAlgebraicOpOperands`.
+  LogicalResult
+  matchOp(Operation *sourceAlgebraicOp,
+          SmallVector<OpOperand *> &sourceAlgebraicOpOperands) const {
+    if (!isSourceAlgebraicOp(sourceAlgebraicOp)) {
+      return failure();
+    }
+    sourceAlgebraicOpOperands.clear();
+    getSourceAlgebraicOpOperands(sourceAlgebraicOp, sourceAlgebraicOpOperands);
+    if (sourceAlgebraicOpOperands.empty()) {
+      return failure();
+    }
+
+    // The op defining the first operand acts as the reference that all other
+    // homomorphism ops must be compatible with.
+    Operation *firstHomomorphismOp =
+        sourceAlgebraicOpOperands.front()->get().getDefiningOp();
+    if (!firstHomomorphismOp ||
+        !isHomomorphismOp(firstHomomorphismOp, std::nullopt)) {
+      return failure();
+    }
+
+    for (OpOperand *operand : sourceAlgebraicOpOperands) {
+      Operation *homomorphismOp = operand->get().getDefiningOp();
+      if (!homomorphismOp ||
+          !isHomomorphismOp(homomorphismOp, firstHomomorphismOp)) {
+        return failure();
+      }
+      // The operand has to be the result that is relevant to the
+      // homomorphism. This is checked for every operand, not only the first:
+      // an op with several results could otherwise feed the algebraic op
+      // through an unrelated result and be rewritten incorrectly.
+      OpResult homomorphismOpResult = getHomomorphismOpResult(homomorphismOp);
+      if (homomorphismOpResult != operand->get()) {
+        return failure();
+      }
+    }
+    return success();
+  }
+
+  // Rewrites `s(h(x1), ..., h(xn))` into `h(t(x1, ..., xn))`.
+  // `sourceAlgebraicOpOperands` must be the non-empty operand list produced by
+  // a successful `matchOp`.
+  LogicalResult
+  rewriteOp(Operation *sourceAlgebraicOp,
+            const SmallVector<OpOperand *> &sourceAlgebraicOpOperands,
+            PatternRewriter &rewriter) const {
+    // Map each homomorphism result `h(xi)` to its input `xi` so that the
+    // target algebraic op is created directly on the inputs.
+    IRMapping irMapping;
+    for (OpOperand *operand : sourceAlgebraicOpOperands) {
+      Operation *homomorphismOp = operand->get().getDefiningOp();
+      irMapping.map(operand->get(),
+                    getHomomorphismOpOperand(homomorphismOp)->get());
+    }
+    Operation *targetAlgebraicOp =
+        createTargetAlgebraicOpFn(sourceAlgebraicOp, irMapping, rewriter);
+
+    // Clone the first homomorphism op with its relevant operand replaced by
+    // the result of the new target algebraic op.
+    irMapping.clear();
+    assert(!sourceAlgebraicOpOperands.empty());
+    Operation *firstHomomorphismOp =
+        sourceAlgebraicOpOperands[0]->get().getDefiningOp();
+    irMapping.map(getHomomorphismOpOperand(firstHomomorphismOp)->get(),
+                  getTargetAlgebraicOpResult(targetAlgebraicOp));
+    Operation *newHomomorphismOp =
+        rewriter.clone(*firstHomomorphismOp, irMapping);
+    // Only the uses of the relevant result are redirected; the source
+    // algebraic op itself is not erased here.
+    rewriter.replaceAllUsesWith(getSourceAlgebraicOpResult(sourceAlgebraicOp),
+                                getHomomorphismOpResult(newHomomorphismOp));
+    return success();
+  }
+
+  GetHomomorphismOpOperandFn getHomomorphismOpOperand;
+  GetHomomorphismOpResultFn getHomomorphismOpResult;
+  GetSourceAlgebraicOpOperandsFn getSourceAlgebraicOpOperands;
+  GetSourceAlgebraicOpResultFn getSourceAlgebraicOpResult;
+  GetTargetAlgebraicOpResultFn getTargetAlgebraicOpResult;
+  IsHomomorphismOpFn isHomomorphismOp;
+  IsSourceAlgebraicOpFn isSourceAlgebraicOp;
+  CreateTargetAlgebraicOpFn createTargetAlgebraicOpFn;
+};
+
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_
diff --git a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
index bcf45c4..044b867 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRMeshTransforms
+ Simplifications.cpp
ShardingPropagation.cpp
ADDITIONAL_HEADER_DIRS
@@ -9,6 +10,7 @@ add_mlir_dialect_library(MLIRMeshTransforms
MLIRShardingInterface
LINK_LIBS PUBLIC
+ MLIRArithDialect
MLIRFuncDialect
MLIRIR
MLIRMeshDialect
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
new file mode 100644
index 0000000..643bd7b
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
@@ -0,0 +1,39 @@
+//===- Simplifications.cpp - Mesh Simplifications ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+
+namespace mlir {
+namespace mesh {
+
+// Registers, for each op type in `AlgebraicOps` in the order given, the
+// all-reduce endomorphism simplification for the reduction kind `reduction`.
+template <typename... AlgebraicOps>
+static void addAllReduceEndomorphismPatterns(RewritePatternSet &patterns,
+                                             Partial reduction) {
+  (populateAllReduceEndomorphismSimplificationPatterns<AlgebraicOps>(
+       patterns, reduction),
+   ...);
+}
+
+// Adds all mesh simplification patterns to `patterns`: all-reduce
+// endomorphism simplifications for the arith sum, min and max ops.
+void populateSimplificationPatterns(RewritePatternSet &patterns) {
+  addAllReduceEndomorphismPatterns<arith::AddFOp, arith::AddIOp>(
+      patterns, Partial::Sum);
+
+  addAllReduceEndomorphismPatterns<arith::MinimumFOp, arith::MinSIOp,
+                                   arith::MinUIOp>(patterns, Partial::Min);
+
+  addAllReduceEndomorphismPatterns<arith::MaximumFOp, arith::MaxSIOp,
+                                   arith::MaxUIOp>(patterns, Partial::Max);
+
+  // TODO: add simplifications for all-gather and other collectives.
+}
+
+} // namespace mesh
+} // namespace mlir
diff --git a/mlir/test/Dialect/Mesh/simplifications.mlir b/mlir/test/Dialect/Mesh/simplifications.mlir
new file mode 100644
index 0000000..e716940
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/simplifications.mlir
@@ -0,0 +1,167 @@
+// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 4x2)
+mesh.cluster @mesh1(rank = 1, dim_sizes = 4)
+
+// Checks that `all_reduce(x) + all_reduce(y)` gets transformed to
+// `all_reduce(x + y)`.
+// Both all-reduces are on @mesh0 with mesh_axes = [0] and each of them is
+// used only by the addf.
+// CHECK-LABEL: func.func @all_reduce_arith_addf_endomorphism
+func.func @all_reduce_arith_addf_endomorphism(
+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+ %arg0: tensor<5xf32>,
+ // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+ %arg1: tensor<5xf32>) -> tensor<5xf32> {
+ %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+ : tensor<5xf32> -> tensor<5xf32>
+ %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+ : tensor<5xf32> -> tensor<5xf32>
+ // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ARG0]], %[[ARG1]]
+ %2 = arith.addf %0, %1 : tensor<5xf32>
+ // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]]
+ // CHECK: return %[[ALL_REDUCE_RES]]
+ return %2 : tensor<5xf32>
+}
+
+// The result of the addf is used twice, but each all-reduce is still used
+// only by the addf, so the simplification is expected to apply.
+// CHECK-LABEL: func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_result
+func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_result(
+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+ %arg0: tensor<5xf32>,
+ // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+ %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
+ %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+ : tensor<5xf32> -> tensor<5xf32>
+ %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+ : tensor<5xf32> -> tensor<5xf32>
+ // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ARG0]], %[[ARG1]]
+ %2 = arith.addf %0, %1 : tensor<5xf32>
+ // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]]
+ // CHECK: return %[[ALL_REDUCE_RES]], %[[ALL_REDUCE_RES]]
+ return %2, %2 : tensor<5xf32>, tensor<5xf32>
+}
+
+// Do not simplify if there is another use of one of the all-reduces.
+// Here the first all-reduce is both an operand of the addf and a returned
+// value, so the IR is expected to stay unchanged.
+// CHECK-LABEL: func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_all_reduce_result
+func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_all_reduce_result(
+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+ %arg0: tensor<5xf32>,
+ // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+ %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
+ // CHECK: %[[ALL_REDUCE_0_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]]
+ %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+ : tensor<5xf32> -> tensor<5xf32>
+ // CHECK: %[[ALL_REDUCE_1_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]]
+ %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+ : tensor<5xf32> -> tensor<5xf32>
+ // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE_0_RES]], %[[ALL_REDUCE_1_RES]]
+ %2 = arith.addf %0, %1 : tensor<5xf32>
+ // CHECK: return %[[ALL_REDUCE_0_RES]], %[[ADD_RES]]
+ return %0, %2 : tensor<5xf32>, tensor<5xf32>
+}
+
+// Do not simplify if the all-reduces are on different meshes
+// (@mesh0 and @mesh1).
+// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_mesh
+func.func @all_reduce_arith_addf_no_endomorphism_different_mesh(
+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+ %arg0: tensor<5xf32>,
+ // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+ %arg1: tensor<5xf32>) -> tensor<5xf32> {
+ // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0
+ %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+ : tensor<5xf32> -> tensor<5xf32>
+ // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh1
+ %1 = mesh.all_reduce %arg1 on @mesh1 mesh_axes = [0]
+ : tensor<5xf32> -> tensor<5xf32>
+ // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
+ %2 = arith.addf %0, %1 : tensor<5xf32>
+ // CHECK: return %[[ADD_RES]]
+ return %2 : tensor<5xf32>
+}
+
+// Do not simplify if the all-reduces are on the same mesh but reduce over
+// different mesh axes ([0] and [1]).
+// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_mesh_axes
+func.func @all_reduce_arith_addf_no_endomorphism_different_mesh_axes(
+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+ %arg0: tensor<5xf32>,
+ // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+ %arg1: tensor<5xf32>) -> tensor<5xf32> {
+ // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0]
+ %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+ : tensor<5xf32> -> tensor<5xf32>
+ // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [1]
+ %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [1]
+ : tensor<5xf32> -> tensor<5xf32>
+ // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
+ %2 = arith.addf %0, %1 : tensor<5xf32>
+ // CHECK: return %[[ADD_RES]]
+ return %2 : tensor<5xf32>
+}
+
+// Do not simplify if one of the all-reduces feeding the addf uses a `max`
+// reduction.
+// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_wrong_reduction_kind
+func.func @all_reduce_arith_addf_no_endomorphism_wrong_reduction_kind(
+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+ %arg0: tensor<5xf32>,
+ // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+ %arg1: tensor<5xf32>) -> tensor<5xf32> {
+ // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0] reduction = <max>
+ %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = <max>
+ : tensor<5xf32> -> tensor<5xf32>
+ // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [0]
+ %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+ : tensor<5xf32> -> tensor<5xf32>
+ // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
+ %2 = arith.addf %0, %1 : tensor<5xf32>
+ // CHECK: return %[[ADD_RES]]
+ return %2 : tensor<5xf32>
+}
+
+// Do not simplify if the all-reduces change the element type
+// (f32 input, f64 result).
+// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_operand_result_element_types
+func.func @all_reduce_arith_addf_no_endomorphism_different_operand_result_element_types(
+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+ %arg0: tensor<5xf32>,
+ // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+ %arg1: tensor<5xf32>) -> tensor<5xf64> {
+ // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0]
+ %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+ : tensor<5xf32> -> tensor<5xf64>
+ // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [0]
+ %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+ : tensor<5xf32> -> tensor<5xf64>
+ // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
+ %2 = arith.addf %0, %1 : tensor<5xf64>
+ // CHECK: return %[[ADD_RES]]
+ return %2 : tensor<5xf64>
+}
+
+// Checks that `min(all_reduce(x), all_reduce(y))` gets transformed to
+// `all_reduce(min(x, y))`.
+// CHECK-LABEL: func.func @all_reduce_arith_minimumf_endomorphism
+func.func @all_reduce_arith_minimumf_endomorphism(
+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+ %arg0: tensor<5xf32>,
+ // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+ %arg1: tensor<5xf32>) -> tensor<5xf32> {
+ %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = <min>
+ : tensor<5xf32> -> tensor<5xf32>
+ %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = <min>
+ : tensor<5xf32> -> tensor<5xf32>
+ // CHECK: %[[MIN_RES:[A-Za-z0-9_]*]] = arith.minimumf %[[ARG0]], %[[ARG1]]
+ %2 = arith.minimumf %0, %1 : tensor<5xf32>
+ // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[MIN_RES]] on @mesh0 mesh_axes = [0] reduction = <min>
+ // CHECK: return %[[ALL_REDUCE_RES]]
+ return %2 : tensor<5xf32>
+}
+
+// Checks that `minsi(all_reduce(x), all_reduce(y))` gets transformed to
+// `all_reduce(minsi(x, y))` for all-reduces with a `min` reduction.
+// CHECK-LABEL: func.func @all_reduce_arith_minsi_endomorphism
+func.func @all_reduce_arith_minsi_endomorphism(
+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xi32>
+ %arg0: tensor<5xi32>,
+ // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xi32>
+ %arg1: tensor<5xi32>) -> tensor<5xi32> {
+ %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = <min>
+ : tensor<5xi32> -> tensor<5xi32>
+ %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = <min>
+ : tensor<5xi32> -> tensor<5xi32>
+ // CHECK: %[[MIN_RES:[A-Za-z0-9_]*]] = arith.minsi %[[ARG0]], %[[ARG1]]
+ %2 = arith.minsi %0, %1 : tensor<5xi32>
+ // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[MIN_RES]] on @mesh0 mesh_axes = [0] reduction = <min>
+ // CHECK: return %[[ALL_REDUCE_RES]]
+ return %2 : tensor<5xi32>
+}
diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index 48bde69..30a17c2 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -9,6 +9,7 @@ add_subdirectory(Linalg)
add_subdirectory(LLVM)
add_subdirectory(Math)
add_subdirectory(MemRef)
+add_subdirectory(Mesh)
add_subdirectory(NVGPU)
add_subdirectory(SCF)
add_subdirectory(Shape)
diff --git a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
new file mode 100644
index 0000000..16b50bb
--- /dev/null
+++ b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
@@ -0,0 +1,13 @@
+# Test library built from TestSimplifications.cpp.
+# Exclude tests from libMLIR.so
+add_mlir_library(MLIRMeshTestSimplifications
+ TestSimplifications.cpp
+
+ EXCLUDE_FROM_LIBMLIR
+
+ LINK_LIBS PUBLIC
+ MLIRMeshDialect
+ MLIRMeshTransforms
+ MLIRPass
+ MLIRRewrite
+ MLIRTransformUtils
+ )
diff --git a/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp b/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp
new file mode 100644
index 0000000..93b1da5
--- /dev/null
+++ b/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp
@@ -0,0 +1,43 @@
+//===- TestSimplifications.cpp - Test simplification ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+// Test pass registered under the `test-mesh-simplifications` argument.
+// It is not restricted to a particular operation type (`OperationPass<>`).
+struct TestMeshSimplificationsPass
+ : public PassWrapper<TestMeshSimplificationsPass, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshSimplificationsPass)
+
+ // Defined out of line below.
+ void runOnOperation() override;
+ // Declares the arith and mesh dialects as dependencies of this pass.
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<arith::ArithDialect, mesh::MeshDialect>();
+ }
+ // Command-line argument and description of the pass.
+ StringRef getArgument() const final { return "test-mesh-simplifications"; }
+ StringRef getDescription() const final { return "Test mesh simplifications"; }
+};
+} // namespace
+
+// Collects the mesh simplification patterns and applies them greedily to the
+// operation the pass runs on.
+void TestMeshSimplificationsPass::runOnOperation() {
+  MLIRContext *context = &getContext();
+  RewritePatternSet simplificationPatterns(context);
+  mesh::populateSimplificationPatterns(simplificationPatterns);
+
+  // The outcome of the greedy driver is deliberately ignored.
+  LogicalResult status = applyPatternsAndFoldGreedily(
+      getOperation(), std::move(simplificationPatterns));
+  (void)status;
+}
+
+namespace mlir {
+namespace test {
+// Registers `TestMeshSimplificationsPass` with the pass registry; called from
+// `registerTestPasses()` in mlir-opt.
+void registerTestMeshSimplificationsPass() {
+ PassRegistration<TestMeshSimplificationsPass>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index 88a0562..bc8eed1 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -26,6 +26,7 @@ if(MLIR_INCLUDE_TESTS)
MLIRLoopLikeInterfaceTestPasses
MLIRMathTestPasses
MLIRMemRefTestPasses
+ MLIRMeshTestSimplifications
MLIRNVGPUTestPasses
MLIRSCFTestPasses
MLIRShapeTestPasses
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 3e3223b..c7cf1e5 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -118,6 +118,7 @@ void registerTestMathAlgebraicSimplificationPass();
void registerTestMathPolynomialApproximationPass();
void registerTestMemRefDependenceCheck();
void registerTestMemRefStrideCalculation();
+void registerTestMeshSimplificationsPass();
void registerTestNextAccessPass();
void registerTestOneToNTypeConversionPass();
void registerTestOpaqueLoc();
@@ -238,6 +239,7 @@ void registerTestPasses() {
mlir::test::registerTestMathPolynomialApproximationPass();
mlir::test::registerTestMemRefDependenceCheck();
mlir::test::registerTestMemRefStrideCalculation();
+ mlir::test::registerTestMeshSimplificationsPass();
mlir::test::registerTestNextAccessPass();
mlir::test::registerTestOneToNTypeConversionPass();
mlir::test::registerTestOpaqueLoc();