aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBoian Petkantchin <boian.petkantchin@amd.com>2024-03-07 17:05:44 -0800
committerGitHub <noreply@github.com>2024-03-07 17:05:44 -0800
commitfb582b6ace781ff6991775d6dcd4df98aa16698f (patch)
tree539ee352d2aab5951330e0105443c0f7082e723a
parent9d3bf9b639eafeded82c6be295031262735d1dac (diff)
downloadllvm-fb582b6ace781ff6991775d6dcd4df98aa16698f.zip
llvm-fb582b6ace781ff6991775d6dcd4df98aa16698f.tar.gz
llvm-fb582b6ace781ff6991775d6dcd4df98aa16698f.tar.bz2
[mlir] Implement Mesh's ShardingInterface for Linalg ops (#82284)
Allows linalg structured operations to be handled during spmdization and sharding propagation. There is only support for projected permutation indexing maps.
-rw-r--r--mlir/include/mlir/Dialect/Linalg/Transforms/AllInterfaces.h26
-rw-r--r--mlir/include/mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h20
-rw-r--r--mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td6
-rw-r--r--mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td4
-rw-r--r--mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h18
-rw-r--r--mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h6
-rw-r--r--mlir/include/mlir/IR/Dialect.h8
-rw-r--r--mlir/include/mlir/InitAllDialects.h10
-rw-r--r--mlir/lib/Dialect/Linalg/IR/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp7
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp24
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt5
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp352
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp8
-rw-r--r--mlir/lib/Dialect/Mesh/IR/MeshOps.cpp7
-rw-r--r--mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp89
-rw-r--r--mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp13
-rw-r--r--mlir/test/Dialect/Linalg/mesh-spmdization.mlir165
-rw-r--r--utils/bazel/llvm-project-overlay/mlir/BUILD.bazel4
19 files changed, 754 insertions, 19 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/AllInterfaces.h b/mlir/include/mlir/Dialect/Linalg/Transforms/AllInterfaces.h
new file mode 100644
index 0000000..a69751e
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/AllInterfaces.h
@@ -0,0 +1,26 @@
+//===- AllInterfaces.h - ----------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a common entry point for registering all external
+// interface implementations to the linalg dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_TRANSFORMS_ALLINTERFACES_H
+#define MLIR_DIALECT_LINALG_TRANSFORMS_ALLINTERFACES_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace linalg {
+void registerAllDialectInterfaceImplementations(DialectRegistry &registry);
+} // namespace linalg
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_TRANSFORMS_ALLINTERFACES_H
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h
new file mode 100644
index 0000000..c57501e
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h
@@ -0,0 +1,20 @@
+//===- MeshShardingInterfaceImpl.h ----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H
+#define MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace linalg {
+void registerMeshShardingInterfaceExternalModels(DialectRegistry &registry);
+} // namespace linalg
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index fc2acc7..9d9b589 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -46,6 +46,12 @@ def Mesh_ReductionKind : I32EnumAttr<"ReductionKind",
I32EnumAttrCase<"Sum", 1, "sum">,
I32EnumAttrCase<"Max", 2, "max">,
I32EnumAttrCase<"Min", 3, "min">,
+ I32EnumAttrCase<"Product", 4, "product">,
+ // Arithmetic mean.
+ I32EnumAttrCase<"Average", 5, "average">,
+ I32EnumAttrCase<"BitwiseAnd", 6, "bitwise_and">,
+ I32EnumAttrCase<"BitwiseOr", 7, "bitwise_or">,
+ I32EnumAttrCase<"BitwiseXor", 8, "bitwise_xor">,
I32EnumAttrCase<"Generic", 100, "generic">
]> {
let genSpecializedAttr = 0;
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index b9cd15e..8e1e475 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -353,6 +353,10 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
attr-dict `:` type($input) `->` type($result)
}];
let hasCanonicalizer = 1;
+ let builders = [
+ OpBuilder<(ins "Value":$input, "StringRef":$mesh,
+ "ArrayRef<MeshAxis>":$meshAxes, "ReductionKind":$reduction)>
+ ];
}
def Mesh_AllSliceOp : Mesh_CollectiveCommunicationOpBase<"all_slice", [
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
index ffc9b6f..ab4df2a 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
@@ -22,6 +22,24 @@ class SymbolTableCollection;
namespace mesh {
+// Retrieve the mesh axes corresponding to each operation loop iterator based
+// on the provided shardings for the op's operands and results.
+// Assumes that the indexingMaps are projected permutations.
+ShardingArray getMeshAxisAssignmentForLoopIterators(
+ ArrayRef<MeshShardingAttr> operandShardings,
+ ArrayRef<MeshShardingAttr> resultShardings,
+ ArrayRef<utils::IteratorType> loopIteratorTypes,
+ ArrayRef<AffineMap> indexingMaps);
+
+bool isAtLeastOneReductionIteratorSharded(
+ ArrayRef<utils::IteratorType> loopIteratorTypes,
+ ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);
+
+// Get the set of mesh axes that correspond to reduction loop iterators.
+SmallVector<MeshAxis> getReductionMeshAxes(
+ ArrayRef<utils::IteratorType> loopIteratorTypes,
+ ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);
+
// Inserts a clone of the operation that has all ranked tensor
// arguments/results sharded.
void spmdizeTriviallyShardableOperation(
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
index aeab289..be82e2a 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
@@ -13,6 +13,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ArrayRef.h"
namespace mlir {
class RewritePatternSet;
@@ -37,6 +38,11 @@ TypedValue<IndexType>
createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
ImplicitLocOpBuilder &builder);
+// Get process linear index along the given mesh axes.
+TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
+ ArrayRef<MeshAxis> meshAxes,
+ ImplicitLocOpBuilder &builder);
+
} // namespace mesh
} // namespace mlir
diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
index 50f6f6d..6c8a170 100644
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -216,6 +216,14 @@ public:
{TypeID::get<ConcreteT>(), InterfaceT::getInterfaceID()});
}
+ // Declare the same interface for multiple types.
+ // Example:
+ // declarePromisedInterfaces<FunctionOpInterface, MyFuncType1, MyFuncType2>()
+ template <typename InterfaceT, typename... ConcreteT>
+ void declarePromisedInterfaces() {
+ (declarePromisedInterface<ConcreteT, InterfaceT>(), ...);
+ }
+
/// Checks if the given interface, which is attempting to be used, is a
/// promised interface of this dialect that has yet to be implemented. If so,
/// emits a fatal error. `interfaceName` is an optional string that contains a
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 838bd03..21775e1 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -43,10 +43,7 @@
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h"
-#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
-#include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h"
-#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/MPI/IR/MPI.h"
@@ -157,10 +154,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
cf::registerBufferizableOpInterfaceExternalModels(registry);
cf::registerBufferDeallocationOpInterfaceExternalModels(registry);
gpu::registerBufferDeallocationOpInterfaceExternalModels(registry);
- linalg::registerBufferizableOpInterfaceExternalModels(registry);
- linalg::registerSubsetOpInterfaceExternalModels(registry);
- linalg::registerTilingInterfaceExternalModels(registry);
- linalg::registerValueBoundsOpInterfaceExternalModels(registry);
+ linalg::registerAllDialectInterfaceImplementations(registry);
memref::registerAllocationOpInterfaceExternalModels(registry);
memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
memref::registerValueBoundsOpInterfaceExternalModels(registry);
diff --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
index f0ac189..c187563 100644
--- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
@@ -25,6 +25,7 @@ add_mlir_dialect_library(MLIRLinalgDialect
MLIRInferTypeOpInterface
MLIRIR
MLIRParser
+ MLIRShardingInterface
MLIRSideEffectInterfaces
MLIRSparseTensorDialect
MLIRSCFDialect
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp
index 5069d43..027058d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
@@ -118,6 +119,12 @@ void mlir::linalg::LinalgDialect::initialize() {
>(namedStructuredOpRegionBuilders);
addInterfaces<LinalgInlinerInterface>();
+
+ declarePromisedInterface<GenericOp, mesh::ShardingInterface>();
+ declarePromisedInterfaces<mesh::ShardingInterface,
+#define GET_OP_LIST
+#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
+ >();
}
LogicalResult LinalgDialect::verifyOperationAttribute(Operation *op,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp b/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp
new file mode 100644
index 0000000..281d9f2
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp
@@ -0,0 +1,24 @@
+//===- AllInterfaces.cpp - ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h"
+
+#include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
+
+void mlir::linalg::registerAllDialectInterfaceImplementations(
+ DialectRegistry &registry) {
+ registerBufferizableOpInterfaceExternalModels(registry);
+ registerMeshShardingInterfaceExternalModels(registry);
+ registerSubsetOpInterfaceExternalModels(registry);
+ registerTilingInterfaceExternalModels(registry);
+ registerValueBoundsOpInterfaceExternalModels(registry);
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 4f47e3b..513c54d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRLinalgTransforms
+ AllInterfaces.cpp
BubbleUpExtractSlice.cpp
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
@@ -21,6 +22,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
InlineScalarOperands.cpp
Interchange.cpp
Loops.cpp
+ MeshShardingInterfaceImpl.cpp
NamedOpConversions.cpp
Padding.cpp
Promotion.cpp
@@ -61,12 +63,15 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRIR
MLIRMemRefDialect
MLIRMemRefTransforms
+ MLIRMeshDialect
+ MLIRMeshTransforms
MLIRLinalgDialect
MLIRLinalgUtils
MLIRSCFDialect
MLIRSCFTransforms
MLIRSCFUtils
MLIRPass
+ MLIRShardingInterface
MLIRSubsetOpInterface
MLIRSparseTensorDialect
MLIRTensorDialect
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
new file mode 100644
index 0000000..7ac45dc
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
@@ -0,0 +1,352 @@
+//===- MeshShardingInterfaceImpl.cpp --------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h"
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include <iterator>
+#include <optional>
+#include <utility>
+
+namespace mlir::linalg {
+
+using MeshAxis = mesh::MeshAxis;
+using ReductionKind = mesh::ReductionKind;
+using MeshShardingAttr = mesh::MeshShardingAttr;
+using ShardingArray = mesh::ShardingArray;
+using MeshOp = mesh::MeshOp;
+
+// Returns the corresponding mesh reduction kind for the given arith op.
+static ReductionKind getReductionKind(Operation *op) {
+ return llvm::TypeSwitch<Operation *, ReductionKind>(op)
+ // Floating-point operations.
+ .Case([](arith::AddFOp op) { return ReductionKind::Sum; })
+ .Case([](arith::MulFOp op) { return ReductionKind::Product; })
+ // TODO: handle maxnumf and minnumf.
+ .Case([](arith::MaximumFOp op) { return ReductionKind::Max; })
+ .Case([](arith::MinimumFOp op) { return ReductionKind::Min; })
+ // Integer operations.
+ .Case([](arith::AddIOp op) { return ReductionKind::Sum; })
+ .Case([](arith::OrIOp op) { return ReductionKind::BitwiseOr; })
+ .Case([](arith::XOrIOp op) { return ReductionKind::BitwiseXor; })
+ .Case([](arith::AndIOp op) { return ReductionKind::Sum; })
+ // TODO: handle signless, signed and unsigned types properly.
+ // It is assumed that the element type of the collective operands and
+ // result drive the meaning of the reduction kind, whether it is signed
+ // or unsigned.
+ // The reduction op inside the linalg op may have different result type
+ // from the element type of the linalg op's result.
+ // Also signed and unsigned Arith dialect ops may accept signed, unsigned
+ // or signless operands.
+ // Maybe expand the reduction kinds.
+ .Case([](arith::MaxUIOp op) { return ReductionKind::Max; })
+ .Case([](arith::MinUIOp op) { return ReductionKind::Min; })
+ .Case([](arith::MaxSIOp op) { return ReductionKind::Max; })
+ .Case([](arith::MinSIOp op) { return ReductionKind::Min; })
+ .Case([](arith::MulIOp op) { return ReductionKind::Product; })
+ .Default([](Operation *op) { return ReductionKind::Generic; });
+}
+
+static std::optional<Operation *> getCombinerOp(LinalgOp op) {
+ SmallVector<Operation *> combinerOps;
+ Value reducedValue = matchReduction(op.getRegionOutputArgs(), 0, combinerOps);
+ if (!reducedValue || combinerOps.size() != 1) {
+ return std::nullopt;
+ }
+
+ return combinerOps[0];
+}
+
+static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) {
+ std::optional<Operation *> reductionOp = getCombinerOp(op);
+ if (!reductionOp) {
+ return ReductionKind::Generic;
+ }
+ Type resultElementType =
+ llvm::cast<RankedTensorType>(op->getResult(0).getType()).getElementType();
+ // TODO: handle case when result type of the reduction op does not match the
+ // element type of the result tensor.
+ // Would it makes sense at all?
+ assert(resultElementType == reductionOp.value()->getResult(0).getType());
+ return getReductionKind(reductionOp.value());
+}
+
+static MeshOp getMesh(Operation *op,
+ ArrayRef<MeshShardingAttr> operandShardings,
+ ArrayRef<MeshShardingAttr> resultShardings,
+ SymbolTableCollection &symbolTable) {
+ for (MeshShardingAttr sharding : operandShardings) {
+ if (sharding) {
+ return mesh::getMesh(op, sharding.getMesh(), symbolTable);
+ }
+ }
+
+ for (MeshShardingAttr sharding : resultShardings) {
+ if (sharding) {
+ return mesh::getMesh(op, sharding.getMesh(), symbolTable);
+ }
+ }
+
+ assert(false);
+}
+
+// Choose the operand based on the current process index along the reduction
+// mesh axes.
+// We need to use the initial value only once to avoid including it in the
+// reduction multiple times.
+// In each process group only the leading process with linear index 0 would use
+// the original operand.
+// The other processes would use the reduction operation neutral tensor.
+static Value createDestinationPassingStyleInitOperand(
+ LinalgOp op, Value spmdizedOperand, ArrayRef<MeshAxis> reductionMeshAxes,
+ MeshOp meshOp, ImplicitLocOpBuilder &builder) {
+ Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex(
+ meshOp.getSymName(), reductionMeshAxes, builder);
+ Value zero = builder.create<arith::ConstantIndexOp>(0);
+ Value isLeadProcess = builder.create<arith::CmpIOp>(
+ builder.getI1Type(), arith::CmpIPredicate::eq,
+ processLinearIndexInReductionGroup, zero);
+ scf::IfOp ifOp = builder.create<scf::IfOp>(spmdizedOperand.getType(),
+ isLeadProcess, true, true);
+ // Then block.
+ {
+ OpBuilder::InsertionGuard insertionGuard(builder);
+ builder.setInsertionPointToEnd(&ifOp.getThenRegion().front());
+ builder.create<scf::YieldOp>(spmdizedOperand);
+ }
+
+ // Else block.
+ {
+ OpBuilder::InsertionGuard insertionGuard(builder);
+ builder.setInsertionPointToEnd(&ifOp.getElseRegion().front());
+ SmallVector<OpFoldResult> shape =
+ tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);
+ PartialReductionOpInterface partialReductionIface =
+ llvm::cast<PartialReductionOpInterface>(op.getOperation());
+ FailureOr<Operation *> reductionNeutralTensorOp =
+ partialReductionIface.generateInitialTensorForPartialReduction(
+ builder, builder.getLoc(), shape, {});
+ assert(succeeded(reductionNeutralTensorOp));
+ builder.create<scf::YieldOp>(
+ reductionNeutralTensorOp.value()->getResult(0));
+ }
+ return ifOp.getResult(0);
+}
+
+// Create the DPS init operands for the spmdized Linalg op.
+// Return all the new spmdized operands.
+static SmallVector<Value> createDestinationPassingStyleInitOperands(
+ LinalgOp op, MeshOp meshOp, ArrayRef<Value> spmdizedOperands,
+ ArrayRef<MeshAxis> reductionMeshAxes, IRMapping &spmdizationMap,
+ ImplicitLocOpBuilder &builder) {
+ // TODO: add support for multiple destination passing style initial value
+ // operands.
+ // PartialReductionOpInterface::generateInitialTensorForPartialReduction
+ // needs to also support multiple DPS initial operands.
+ SmallVector<Value> newOperands = llvm::to_vector(spmdizedOperands);
+ auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber();
+ Value spmdizedInitOperand =
+ spmdizationMap.lookup(op->getOperands()[operandIdx]);
+ newOperands[operandIdx] = createDestinationPassingStyleInitOperand(
+ op, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
+ return newOperands;
+}
+
+static void createAllReduceForResultWithoutPartialSharding(
+ Value unshardedLinalgOpResult, ArrayRef<MeshAxis> opReductionMeshAxes,
+ MeshShardingAttr resultSharding, ReductionKind reductionKind,
+ IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder) {
+ SmallVector<MeshAxis> allReduceMeshAxes;
+ llvm::copy_if(opReductionMeshAxes, std::back_inserter(allReduceMeshAxes),
+ [&resultSharding](MeshAxis axis) {
+ return !llvm::is_contained(resultSharding.getPartialAxes(),
+ axis);
+ });
+ if (allReduceMeshAxes.empty()) {
+ return;
+ }
+
+ Value spmdizedLinalgOpResult = spmdizationMap.lookup(unshardedLinalgOpResult);
+ Value reducedValue = builder.create<mesh::AllReduceOp>(
+ spmdizedLinalgOpResult, resultSharding.getMesh().getValue(),
+ allReduceMeshAxes, reductionKind);
+ spmdizationMap.map(unshardedLinalgOpResult, reducedValue);
+}
+
+static void createAllReduceForResultsWithoutPartialShardings(
+ LinalgOp unshardedOp, ArrayRef<MeshAxis> opReductionMeshAxes,
+ ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+ ImplicitLocOpBuilder &builder) {
+ ReductionKind reductionKind = getReductionKindOfLinalgOp(unshardedOp);
+ for (auto [unshardedLinalgOpResult, resultSharding] :
+ llvm::zip_equal(unshardedOp->getResults(), resultShardings)) {
+ createAllReduceForResultWithoutPartialSharding(
+ unshardedLinalgOpResult, opReductionMeshAxes, resultSharding,
+ reductionKind, spmdizationMap, builder);
+ }
+}
+
+static void spmdizeLinalgOpWithShardedReduction(
+ LinalgOp op, ArrayRef<Value> spmdizedOperands,
+ ArrayRef<MeshShardingAttr> operandShardings,
+ ArrayRef<MeshShardingAttr> resultShardings,
+ ArrayRef<utils::IteratorType> loopIteratorTypes,
+ ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators,
+ IRMapping &spmdizationMap, SymbolTableCollection &symbolTable,
+ ImplicitLocOpBuilder &builder) {
+ MeshOp mesh = getMesh(op, operandShardings, resultShardings, symbolTable);
+ SmallVector<MeshAxis> reductionMeshAxes = mesh::getReductionMeshAxes(
+ loopIteratorTypes, meshAxisAssignmentForLoopIterators);
+ SmallVector<Value> spmdizedLinalgOpOperands =
+ createDestinationPassingStyleInitOperands(op, mesh, spmdizedOperands,
+ reductionMeshAxes,
+ spmdizationMap, builder);
+ // We must not change the operand mappings of the original spmdizationMap as
+ // they are the mappings for the whole spmdization blob and may be used by
+ // others.
+ IRMapping internalSpmdizationMap;
+ for (auto [unshardedOperand, spmdizedOperand] :
+ llvm::zip_equal(op->getOperands(), spmdizedLinalgOpOperands)) {
+ internalSpmdizationMap.map(unshardedOperand, spmdizedOperand);
+ }
+ spmdizeTriviallyShardableOperation(
+ *op, spmdizedLinalgOpOperands, operandShardings, resultShardings,
+ internalSpmdizationMap, symbolTable, builder);
+ for (Value result : op->getResults()) {
+ spmdizationMap.map(result, internalSpmdizationMap.lookup(result));
+ }
+
+ // Handle partial shardings.
+ createAllReduceForResultsWithoutPartialShardings(
+ op, reductionMeshAxes, resultShardings, spmdizationMap, builder);
+}
+
+namespace {
+
+// ShardingInterface for ops that implement LinalgStructuredInterface.
+// The supported ops are only those where the indexing maps are projected
+// permutations.
+template <typename Op>
+struct StructuredOpShardingInterface
+ : public mesh::ShardingInterface::ExternalModel<
+ StructuredOpShardingInterface<Op>, Op> {
+ SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
+ return llvm::cast<LinalgOp>(op).getIteratorTypesArray();
+ }
+
+ SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+ LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
+ SmallVector<AffineMap> res = linalgOp.getIndexingMapsArray();
+
+ // Results must have the same indexing as destination passing style initial
+ // operands.
+ for (int64_t i = 0; i < linalgOp.getNumDpsInits(); ++i) {
+ res.push_back(res[linalgOp.getDpsInitOperand(i)->getOperandNumber()]);
+ }
+
+ return res;
+ }
+
+ LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
+ ArrayRef<MeshShardingAttr> operandShardings,
+ ArrayRef<MeshShardingAttr> resultShardings,
+ IRMapping &spmdizationMap,
+ SymbolTableCollection &symbolTable,
+ OpBuilder &builder) const {
+ LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
+
+ SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
+ bool allIndexingMapsAreProjectedPermutation =
+ llvm::all_of(indexingMaps, [](AffineMap map) {
+ return map.isProjectedPermutation();
+ });
+ if (!allIndexingMapsAreProjectedPermutation) {
+ // TODO: handle non-projected permutations.
+ return op->emitOpError()
+ << "supports indexing maps that are only projected permutation.";
+ }
+
+ SmallVector<utils::IteratorType> loopIteratorTypes =
+ linalgOp.getIteratorTypesArray();
+ ShardingArray meshAxisAssignmentForLoopIterators =
+ getMeshAxisAssignmentForLoopIterators(operandShardings, resultShardings,
+ loopIteratorTypes, indexingMaps);
+ if (mesh::isAtLeastOneReductionIteratorSharded(
+ loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
+ ImplicitLocOpBuilder implicitLocBuilder(op->getLoc(), builder);
+ spmdizeLinalgOpWithShardedReduction(
+ linalgOp, spmdizedOperands, operandShardings, resultShardings,
+ loopIteratorTypes, meshAxisAssignmentForLoopIterators, spmdizationMap,
+ symbolTable, implicitLocBuilder);
+ } else {
+ spmdizeTriviallyShardableOperation(*op, spmdizedOperands,
+ operandShardings, resultShardings,
+ spmdizationMap, symbolTable, builder);
+ }
+
+ return success();
+ }
+};
+
+} // namespace
+
+template <typename OpType>
+static void registerOne(MLIRContext *ctx) {
+ OpType::template attachInterface<StructuredOpShardingInterface<OpType>>(*ctx);
+}
+
+/// Variadic helper function.
+template <typename... OpTypes>
+static void registerAll(MLIRContext *ctx) {
+ (registerOne<OpTypes>(ctx), ...);
+}
+
+void registerMeshShardingInterfaceExternalModels(DialectRegistry &registry) {
+ registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *dialect) {
+ DialectRegistry registry;
+ registry.insert<affine::AffineDialect, arith::ArithDialect, scf::SCFDialect,
+ tensor::TensorDialect>();
+ ctx->appendDialectRegistry(registry);
+ for (StringRef name : registry.getDialectNames())
+ ctx->getOrLoadDialect(name);
+
+ registerOne<linalg::GenericOp>(ctx);
+ registerAll<
+#define GET_OP_LIST
+#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
+ >(ctx);
+ });
+}
+
+} // namespace mlir::linalg
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 8b3119f..bd870d4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -275,14 +275,6 @@ struct LinalgOpPartialReductionInterface
ArrayRef<int64_t> oldShape =
linalgOp.getShape(linalgOp.getDpsInitOperand(0));
- // Extend tile size vector to the rank of the output tensor.
- SmallVector<Value> tileSizeVector =
- getValueOrCreateConstantIndexOp(b, loc, sizes);
- if (tileSizeVector.size() < oldShape.size()) {
- auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
- tileSizeVector.append(oldShape.size() - tileSizeVector.size(), zero);
- }
-
// Calculate the new shape, we insert the new dimensions based on the index
// of the reduction dimensions.
SmallVector<int64_t> newOutputShape;
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 50163880..03f11ad 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -647,6 +647,13 @@ void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
}
+void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ Value input, StringRef mesh,
+ ArrayRef<MeshAxis> meshAxes, ReductionKind reduction) {
+ build(odsBuilder, odsState, input.getType(), mesh, meshAxes, input,
+ reduction);
+}
+
void AllReduceOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "all_reduce");
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index fe3d7c4..9acee5a 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -539,8 +539,9 @@ static bool areValuesCompatibleWithFullReplicationShardings(
if (std::size(values) != std::size(shardings)) {
return false;
}
- return llvm::all_of(llvm::zip(std::forward<ValueRange>(values),
- std::forward<MeshShardingAttrRage>(shardings)),
+ return llvm::all_of(llvm::zip_equal(
+ std::forward<ValueRange>(values),
+ std::forward<MeshShardingAttrRage>(shardings)),
[](auto valueAndSharding) {
return isValueCompatibleWithFullReplicationSharding(
std::get<0>(valueAndSharding),
@@ -563,6 +564,88 @@ void mesh::spmdizeFullyReplicatedOperation(
builder.clone(op, spmdizationMap);
}
+static void updateMeshAxisAssignmentForLoopIterators(
+ ArrayRef<MeshAxis> meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr,
+ SmallVector<std::optional<SmallVector<MeshAxis>>>
+ &meshAxesAssignmentForLoopIterators) {
+ AffineDimExpr affineDimExpr = cast<AffineDimExpr>(indexingExpr);
+ unsigned loopIteratorIdx = affineDimExpr.getPosition();
+ if (meshAxesAssignmentForLoopIterators[loopIteratorIdx]) {
+ assert(llvm::equal(meshAxesAssignmentForTensorAxis,
+ *meshAxesAssignmentForLoopIterators[loopIteratorIdx]));
+ } else {
+ meshAxesAssignmentForLoopIterators[loopIteratorIdx] =
+ llvm::to_vector(meshAxesAssignmentForTensorAxis);
+ }
+}
+
+ShardingArray mesh::getMeshAxisAssignmentForLoopIterators(
+ ArrayRef<MeshShardingAttr> operandShardings,
+ ArrayRef<MeshShardingAttr> resultShardings,
+ ArrayRef<utils::IteratorType> loopIteratorTypes,
+ ArrayRef<AffineMap> indexingMaps) {
+ SmallVector<std::optional<SmallVector<MeshAxis>>>
+ meshAxisAssignmentForLoopIterators(loopIteratorTypes.size());
+ SmallVector<MeshShardingAttr> operatorAndResultShardings;
+ operatorAndResultShardings.reserve(operandShardings.size() +
+ resultShardings.size());
+ llvm::append_range(operatorAndResultShardings, operandShardings);
+ for (auto [sharding, affineMap] :
+ llvm::zip_equal(operatorAndResultShardings, indexingMaps)) {
+ if (!sharding) {
+ continue;
+ }
+ for (auto [meshAxesAssignmentForTensorAxis, indexingExpr] :
+ llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) {
+ updateMeshAxisAssignmentForLoopIterators(
+ meshAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr,
+ meshAxisAssignmentForLoopIterators);
+ }
+ // Missing trailing split axes means replication on those tensor dimensions.
+ for (unsigned i = sharding.getSplitAxes().size();
+ i < affineMap.getNumResults(); ++i) {
+ updateMeshAxisAssignmentForLoopIterators(
+ {}, affineMap.getResults()[i], meshAxisAssignmentForLoopIterators);
+ }
+ }
+
+ ShardingArray res;
+ llvm::transform(meshAxisAssignmentForLoopIterators, std::back_inserter(res),
+ [](std::optional<SmallVector<MeshAxis>> &axes) {
+ if (!axes) {
+ return SmallVector<MeshAxis>();
+ };
+ return std::move(*axes);
+ });
+ return res;
+}
+
+bool mesh::isAtLeastOneReductionIteratorSharded(
+ ArrayRef<utils::IteratorType> loopIteratorTypes,
+ ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
+ for (auto [loopIteratorType, meshAxisAssignment] :
+ llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
+ if (loopIteratorType == utils::IteratorType::reduction &&
+ !meshAxisAssignment.empty()) {
+ return true;
+ }
+ }
+ return false;
+}
+
+SmallVector<MeshAxis> mesh::getReductionMeshAxes(
+ ArrayRef<utils::IteratorType> loopIteratorTypes,
+ ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
+ SmallVector<MeshAxis> meshAxes;
+ for (auto [loopIteratorType, meshAxisAssignment] :
+ llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
+ if (loopIteratorType == utils::IteratorType::reduction) {
+ llvm::append_range(meshAxes, meshAxisAssignment);
+ }
+ }
+ return meshAxes;
+}
+
void mesh::spmdizeTriviallyShardableOperation(
Operation &op, ArrayRef<Value> spmdizedOperands,
ArrayRef<MeshShardingAttr> operandShardings,
@@ -572,7 +655,7 @@ void mesh::spmdizeTriviallyShardableOperation(
Operation *newOp = builder.clone(op, spmdizationMap);
// Set the result types to the sharded counterparts.
for (auto [oldResult, newResult, sharding] :
- llvm::zip(op.getResults(), newOp->getResults(), resultShardings)) {
+ llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) {
newResult.setType(shardType(newResult.getType(),
getMesh(&op, sharding.getMesh(), symbolTable),
sharding));
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
index d59b911..cb13ee4 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
@@ -208,4 +208,17 @@ createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
.cast<TypedValue<IndexType>>();
}
+TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
+ ArrayRef<MeshAxis> meshAxes,
+ ImplicitLocOpBuilder &builder) {
+ ResultRange processInGroupMultiIndex =
+ builder.create<ProcessMultiIndexOp>(mesh, meshAxes).getResults();
+ Operation::result_range processGroupShape =
+ builder.create<MeshShapeOp>(mesh, meshAxes).getResult();
+ OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
+ llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
+ llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
+ return cast<TypedValue<IndexType>>(processInGroupLinearIndex.get<Value>());
+}
+
} // namespace mlir::mesh
diff --git a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
new file mode 100644
index 0000000..6d21def
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
@@ -0,0 +1,165 @@
+// RUN: mlir-opt \
+// RUN: --mesh-spmdization \
+// RUN: --test-constant-fold \
+// RUN: --split-input-file \
+// RUN: %s | FileCheck %s
+
+// CHECK: #[[$MAP_IDENTITY_1D:.*]] = affine_map<(d0) -> (d0)>
+#map_identity_1d = affine_map<(d0) -> (d0)>
+
+mesh.mesh @mesh_1d(shape = 2)
+
+// CHECK-LABEL: func @elementwise_static_1d_mesh_static_1d_tensor
+func.func @elementwise_static_1d_mesh_static_1d_tensor(
+ // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<1xi8>,
+ %in1: tensor<2xi8>,
+ // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<1xi8>,
+ %in2: tensor<2xi8>,
+ // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<1xi8>
+ %dps_out: tensor<2xi8>
+// CHECK-SAME: -> tensor<1xi8> {
+) -> tensor<2xi8> {
+ %in1_shared1 = mesh.shard %in1 to <@mesh_1d, [[0]]> : tensor<2xi8>
+ %in1_shared2 = mesh.shard %in1_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+ %in2_shared1 = mesh.shard %in2 to <@mesh_1d, [[0]]> : tensor<2xi8>
+ %in2_shared2 = mesh.shard %in2_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+ %dps_out_shared1 = mesh.shard %dps_out to <@mesh_1d, [[0]]> : tensor<2xi8>
+ %dps_out_shared2 = mesh.shard %dps_out_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+ // CHECK: %[[RES:.*]] = linalg.generic {
+ // CHECK-SAME: indexing_maps = [#[[$MAP_IDENTITY_1D]], #[[$MAP_IDENTITY_1D]], #[[$MAP_IDENTITY_1D]]],
+ // CHECK-SAME: iterator_types = ["parallel"]}
+ // CHECK-SAME: ins(%[[IN1]], %[[IN2]] : tensor<1xi8>, tensor<1xi8>)
+ // CHECK-SAME: outs(%[[DPS_OUT]] : tensor<1xi8>) {
+ %res = linalg.generic {
+ indexing_maps = [#map_identity_1d, #map_identity_1d, #map_identity_1d],
+ iterator_types = ["parallel"]
+ } ins(%in1_shared2, %in2_shared2 : tensor<2xi8>, tensor<2xi8>)
+ outs(%dps_out_shared2 : tensor<2xi8>) {
+ ^bb0(%in1_scalar: i8, %in2_scalar: i8, %out: i8):
+ %res_scalar = arith.muli %in1_scalar, %in2_scalar : i8
+ linalg.yield %res_scalar : i8
+ } -> tensor<2xi8>
+ %res_shared1 = mesh.shard %res to <@mesh_1d, [[0]]> : tensor<2xi8>
+ %res_shared2 = mesh.shard %res_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+ // CHECK: return %[[RES]] : tensor<1xi8>
+ return %res_shared2 : tensor<2xi8>
+}
+
+// -----
+
+mesh.mesh @mesh_1d(shape = 4)
+
+// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_parallel_iterator_sharding
+func.func @matmul_1d_mesh_static_tensors_parallel_iterator_sharding(
+ // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<1x3xi8>,
+ %in1: tensor<4x3xi8>,
+// CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<3x8xi8>,
+ %in2: tensor<3x8xi8>,
+// CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<1x8xi8>
+ %dps_out: tensor<4x8xi8>
+// CHECK-SAME: -> tensor<1x8xi8> {
+) -> tensor<4x8xi8> {
+ %in1_shared1 = mesh.shard %in1 to <@mesh_1d, [[0]]> : tensor<4x3xi8>
+ %in1_shared2 = mesh.shard %in1_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<4x3xi8>
+ %in2_shared1 = mesh.shard %in2 to <@mesh_1d, [[]]> : tensor<3x8xi8>
+ %in2_shared2 = mesh.shard %in2_shared1 to <@mesh_1d, [[]]> annotate_for_users: tensor<3x8xi8>
+ %dps_out_shared1 = mesh.shard %dps_out to <@mesh_1d, [[0]]> : tensor<4x8xi8>
+ %dps_out_shared2 = mesh.shard %dps_out_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<4x8xi8>
+ // CHECK: %[[RES:.*]] = linalg.matmul
+ // CHECK-SAME: ins(%[[IN1]], %[[IN2]] : tensor<1x3xi8>, tensor<3x8xi8>)
+ // CHECK-SAME: outs(%[[DPS_OUT]] : tensor<1x8xi8>)
+ // CHECK-SAME: -> tensor<1x8xi8>
+ %res = linalg.matmul ins(%in1_shared2, %in2_shared2 : tensor<4x3xi8>, tensor<3x8xi8>)
+ outs(%dps_out_shared2 : tensor<4x8xi8>) -> tensor<4x8xi8>
+ %res_shared1 = mesh.shard %res to <@mesh_1d, [[0]]> : tensor<4x8xi8>
+ %res_shared2 = mesh.shard %res_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<4x8xi8>
+ // CHECK: return %[[RES]] : tensor<1x8xi8>
+ return %res_shared2 : tensor<4x8xi8>
+}
+
+// -----
+
+mesh.mesh @mesh_1d(shape = 3)
+
+// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding
+func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding(
+ // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<4x2xi8>,
+ %in1: tensor<4x6xi8>,
+// CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<2x8xi8>,
+ %in2: tensor<6x8xi8>,
+// CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<4x8xi8>
+ %dps_out: tensor<4x8xi8>
+// CHECK-SAME: -> tensor<4x8xi8> {
+) -> tensor<4x8xi8> {
+ %in1_shared1 = mesh.shard %in1 to <@mesh_1d, [[], [0]]> : tensor<4x6xi8>
+ %in1_shared2 = mesh.shard %in1_shared1 to <@mesh_1d, [[], [0]]> annotate_for_users: tensor<4x6xi8>
+ %in2_shared1 = mesh.shard %in2 to <@mesh_1d, [[0]]> : tensor<6x8xi8>
+ %in2_shared2 = mesh.shard %in2_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<6x8xi8>
+ %dps_out_shared1 = mesh.shard %dps_out to <@mesh_1d, [[]]> : tensor<4x8xi8>
+ %dps_out_shared2 = mesh.shard %dps_out_shared1 to <@mesh_1d, [[]]> annotate_for_users: tensor<4x8xi8>
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C0_I8:.*]] = arith.constant 0 : i8
+ // CHECK-DAG: %[[PROCESS_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
+ // CHECK-DAG: %[[MESH_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index
+ // CHECK: %[[DPS_INIT_OPERAND_CONDITION:.*]] = arith.cmpi eq, %[[PROCESS_IDX]], %[[C0]] : index
+ // CHECK: %[[DPS_INIT_OPERAND:.*]] = scf.if %[[DPS_INIT_OPERAND_CONDITION]] -> (tensor<4x8xi8>) {
+ // CHECK: scf.yield %[[DPS_OUT]] : tensor<4x8xi8>
+ // CHECK: } else {
+ // CHECK-DAG: %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<4x8xi8>
+ // CHECK: %[[NEUTRAL_ELEMENT_FILLED_TENSOR:.*]] = linalg.fill ins(%[[C0_I8]] : i8)
+ // CHECK-SAME: outs(%[[EMPTY_TENSOR]] : tensor<4x8xi8>) -> tensor<4x8xi8>
+ // CHECK: scf.yield %[[NEUTRAL_ELEMENT_FILLED_TENSOR]] : tensor<4x8xi8>
+ // CHECK: }
+ // CHECK: %[[SHARDED_MATMUL:.*]] = linalg.matmul ins(%[[IN1]], %[[IN2]] : tensor<4x2xi8>, tensor<2x8xi8>)
+ // CHECK-SAME: outs(%[[DPS_INIT_OPERAND]] : tensor<4x8xi8>) -> tensor<4x8xi8>
+ // CHECK: %[[ALL_REDUCED:.*]] = mesh.all_reduce %[[SHARDED_MATMUL]] on @mesh_1d mesh_axes = [0] : tensor<4x8xi8> -> tensor<4x8xi8>
+ %res = linalg.matmul ins(%in1_shared2, %in2_shared2 : tensor<4x6xi8>, tensor<6x8xi8>)
+ outs(%dps_out_shared2 : tensor<4x8xi8>) -> tensor<4x8xi8>
+ %res_shared1 = mesh.shard %res to <@mesh_1d, [[]]> : tensor<4x8xi8>
+ %res_shared2 = mesh.shard %res_shared1 to <@mesh_1d, [[]]> annotate_for_users: tensor<4x8xi8>
+ // CHECK: return %[[ALL_REDUCED]] : tensor<4x8xi8>
+ return %res_shared2 : tensor<4x8xi8>
+}
+
+// -----
+
+mesh.mesh @mesh_1d(shape = 3)
+
+// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding_with_partial_result
+func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding_with_partial_result(
+ // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<4x2xi8>,
+ %in1: tensor<4x6xi8>,
+// CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<2x8xi8>,
+ %in2: tensor<6x8xi8>,
+// CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<4x8xi8>
+ %dps_out: tensor<4x8xi8>
+// CHECK-SAME: -> tensor<4x8xi8> {
+) -> tensor<4x8xi8> {
+ %in1_shared1 = mesh.shard %in1 to <@mesh_1d, [[], [0]]> : tensor<4x6xi8>
+ %in1_shared2 = mesh.shard %in1_shared1 to <@mesh_1d, [[], [0]]> annotate_for_users: tensor<4x6xi8>
+ %in2_shared1 = mesh.shard %in2 to <@mesh_1d, [[0]]> : tensor<6x8xi8>
+ %in2_shared2 = mesh.shard %in2_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<6x8xi8>
+ %dps_out_shared1 = mesh.shard %dps_out to <@mesh_1d, [[]]> : tensor<4x8xi8>
+ %dps_out_shared2 = mesh.shard %dps_out_shared1 to <@mesh_1d, [[]]> annotate_for_users: tensor<4x8xi8>
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C0_I8:.*]] = arith.constant 0 : i8
+ // CHECK-DAG: %[[PROCESS_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
+ // CHECK-DAG: %[[MESH_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index
+ // CHECK: %[[DPS_INIT_OPERAND_CONDITION:.*]] = arith.cmpi eq, %[[PROCESS_IDX]], %[[C0]] : index
+ // CHECK: %[[DPS_INIT_OPERAND:.*]] = scf.if %[[DPS_INIT_OPERAND_CONDITION]] -> (tensor<4x8xi8>) {
+ // CHECK: scf.yield %[[DPS_OUT]] : tensor<4x8xi8>
+ // CHECK: } else {
+ // CHECK-DAG: %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<4x8xi8>
+ // CHECK: %[[NEUTRAL_ELEMENT_FILLED_TENSOR:.*]] = linalg.fill ins(%[[C0_I8]] : i8)
+ // CHECK-SAME: outs(%[[EMPTY_TENSOR]] : tensor<4x8xi8>) -> tensor<4x8xi8>
+ // CHECK: scf.yield %[[NEUTRAL_ELEMENT_FILLED_TENSOR]] : tensor<4x8xi8>
+ // CHECK: }
+ // CHECK: %[[SHARDED_MATMUL:.*]] = linalg.matmul ins(%[[IN1]], %[[IN2]] : tensor<4x2xi8>, tensor<2x8xi8>)
+ // CHECK-SAME: outs(%[[DPS_INIT_OPERAND]] : tensor<4x8xi8>) -> tensor<4x8xi8>
+ %res = linalg.matmul ins(%in1_shared2, %in2_shared2 : tensor<4x6xi8>, tensor<6x8xi8>)
+ outs(%dps_out_shared2 : tensor<4x8xi8>) -> tensor<4x8xi8>
+ %res_shared1 = mesh.shard %res to <@mesh_1d, [[]], partial = sum[0]> : tensor<4x8xi8>
+ %res_shared2 = mesh.shard %res_shared1 to <@mesh_1d, [[]], partial = sum[0]> annotate_for_users: tensor<4x8xi8>
+ // CHECK: return %[[SHARDED_MATMUL]] : tensor<4x8xi8>
+ return %res_shared2 : tensor<4x8xi8>
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 7a6bc2d..2cfe6184 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -10841,6 +10841,7 @@ cc_library(
":MemRefDialect",
":Parser",
":SCFDialect",
+ ":MeshShardingInterface",
":SideEffectInterfaces",
":SparseTensorDialect",
":Support",
@@ -10994,10 +10995,13 @@ cc_library(
":MathDialect",
":MemRefDialect",
":MemRefTransforms",
+ ":MeshDialect",
+ ":MeshTransforms",
":Pass",
":SCFDialect",
":SCFTransforms",
":SCFUtils",
+ ":MeshShardingInterface",
":SparseTensorDialect",
":SubsetOpInterface",
":Support",