diff options
author | Boian Petkantchin <boian.petkantchin@amd.com> | 2024-03-07 17:05:44 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-07 17:05:44 -0800 |
commit | fb582b6ace781ff6991775d6dcd4df98aa16698f (patch) | |
tree | 539ee352d2aab5951330e0105443c0f7082e723a | |
parent | 9d3bf9b639eafeded82c6be295031262735d1dac (diff) | |
download | llvm-fb582b6ace781ff6991775d6dcd4df98aa16698f.zip llvm-fb582b6ace781ff6991775d6dcd4df98aa16698f.tar.gz llvm-fb582b6ace781ff6991775d6dcd4df98aa16698f.tar.bz2 |
[mlir] Implement Mesh's ShardingInterface for Linalg ops (#82284)
Allows linalg structured operations to be handled during spmdization and
sharding propagation.
There is only support for projected permutation indexing maps.
19 files changed, 754 insertions, 19 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/AllInterfaces.h b/mlir/include/mlir/Dialect/Linalg/Transforms/AllInterfaces.h new file mode 100644 index 0000000..a69751e --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/AllInterfaces.h @@ -0,0 +1,26 @@ +//===- AllInterfaces.h - ----------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines a common entry point for registering all external +// interface implementations to the linalg dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LINALG_TRANSFORMS_ALLINTERFACES_H +#define MLIR_DIALECT_LINALG_TRANSFORMS_ALLINTERFACES_H + +namespace mlir { +class DialectRegistry; + +namespace linalg { +void registerAllDialectInterfaceImplementations(DialectRegistry ®istry); +} // namespace linalg + +} // namespace mlir + +#endif // MLIR_DIALECT_LINALG_TRANSFORMS_ALLINTERFACES_H diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h new file mode 100644 index 0000000..c57501e --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h @@ -0,0 +1,20 @@ +//===- MeshShardingInterfaceImpl.h ----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H +#define MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H + +namespace mlir { +class DialectRegistry; + +namespace linalg { +void registerMeshShardingInterfaceExternalModels(DialectRegistry ®istry); +} // namespace linalg +} // namespace mlir + +#endif // MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td index fc2acc7..9d9b589 100644 --- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td +++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td @@ -46,6 +46,12 @@ def Mesh_ReductionKind : I32EnumAttr<"ReductionKind", I32EnumAttrCase<"Sum", 1, "sum">, I32EnumAttrCase<"Max", 2, "max">, I32EnumAttrCase<"Min", 3, "min">, + I32EnumAttrCase<"Product", 4, "product">, + // Arithmetic mean. + I32EnumAttrCase<"Average", 5, "average">, + I32EnumAttrCase<"BitwiseAnd", 6, "bitwise_and">, + I32EnumAttrCase<"BitwiseOr", 7, "bitwise_or">, + I32EnumAttrCase<"BitwiseXor", 8, "bitwise_xor">, I32EnumAttrCase<"Generic", 100, "generic"> ]> { let genSpecializedAttr = 0; diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td index b9cd15e..8e1e475 100644 --- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td +++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td @@ -353,6 +353,10 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [ attr-dict `:` type($input) `->` type($result) }]; let hasCanonicalizer = 1; + let builders = [ + OpBuilder<(ins "Value":$input, "StringRef":$mesh, + "ArrayRef<MeshAxis>":$meshAxes, "ReductionKind":$reduction)> + ]; } def Mesh_AllSliceOp : Mesh_CollectiveCommunicationOpBase<"all_slice", [ diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h index ffc9b6f..ab4df2a 100644 --- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h @@ -22,6 +22,24 @@ class SymbolTableCollection; namespace mesh { +// Retrieve the mesh axes corresponding to each operation loop iterator based +// on the provided shardings for the op's operands and results. +// Assumes that the indexingMaps are projected permutations. +ShardingArray getMeshAxisAssignmentForLoopIterators( + ArrayRef<MeshShardingAttr> operandShardings, + ArrayRef<MeshShardingAttr> resultShardings, + ArrayRef<utils::IteratorType> loopIteratorTypes, + ArrayRef<AffineMap> indexingMaps); + +bool isAtLeastOneReductionIteratorSharded( + ArrayRef<utils::IteratorType> loopIteratorTypes, + ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators); + +// Get the set of mesh axes that correspond to reduction loop iterators. +SmallVector<MeshAxis> getReductionMeshAxes( + ArrayRef<utils::IteratorType> loopIteratorTypes, + ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators); + // Inserts a clone of the operation that has all ranked tensor // arguments/results sharded. void spmdizeTriviallyShardableOperation( diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h index aeab289..be82e2a 100644 --- a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h @@ -13,6 +13,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" +#include "llvm/ADT/ArrayRef.h" namespace mlir { class RewritePatternSet; @@ -37,6 +38,11 @@ TypedValue<IndexType> createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes, ImplicitLocOpBuilder &builder); +// Get process linear index along the given mesh axes. +TypedValue<IndexType> createProcessLinearIndex(StringRef mesh, + ArrayRef<MeshAxis> meshAxes, + ImplicitLocOpBuilder &builder); + } // namespace mesh } // namespace mlir diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h index 50f6f6d..6c8a170 100644 --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -216,6 +216,14 @@ public: {TypeID::get<ConcreteT>(), InterfaceT::getInterfaceID()}); } + // Declare the same interface for multiple types. + // Example: + // declarePromisedInterfaces<FunctionOpInterface, MyFuncType1, MyFuncType2>() + template <typename InterfaceT, typename... ConcreteT> + void declarePromisedInterfaces() { + (declarePromisedInterface<ConcreteT, InterfaceT>(), ...); + } + /// Checks if the given interface, which is attempting to be used, is a /// promised interface of this dialect that has yet to be implemented. If so, /// emits a fatal error. `interfaceName` is an optional string that contains a diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h index 838bd03..21775e1 100644 --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -43,10 +43,7 @@ #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h" -#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h" -#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" +#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h" #include "mlir/Dialect/MLProgram/IR/MLProgram.h" #include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/MPI/IR/MPI.h" @@ -157,10 +154,7 @@ inline void registerAllDialects(DialectRegistry ®istry) { cf::registerBufferizableOpInterfaceExternalModels(registry); cf::registerBufferDeallocationOpInterfaceExternalModels(registry); gpu::registerBufferDeallocationOpInterfaceExternalModels(registry); - linalg::registerBufferizableOpInterfaceExternalModels(registry); - linalg::registerSubsetOpInterfaceExternalModels(registry); - linalg::registerTilingInterfaceExternalModels(registry); - linalg::registerValueBoundsOpInterfaceExternalModels(registry); + linalg::registerAllDialectInterfaceImplementations(registry); memref::registerAllocationOpInterfaceExternalModels(registry); memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry); memref::registerValueBoundsOpInterfaceExternalModels(registry); diff --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt index f0ac189..c187563 100644 --- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt @@ -25,6 +25,7 @@ add_mlir_dialect_library(MLIRLinalgDialect MLIRInferTypeOpInterface MLIRIR MLIRParser + MLIRShardingInterface MLIRSideEffectInterfaces MLIRSparseTensorDialect MLIRSCFDialect diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp index 5069d43..027058d 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" @@ -118,6 +119,12 @@ void mlir::linalg::LinalgDialect::initialize() { >(namedStructuredOpRegionBuilders); addInterfaces<LinalgInlinerInterface>(); + + declarePromisedInterface<GenericOp, mesh::ShardingInterface>(); + declarePromisedInterfaces<mesh::ShardingInterface, +#define GET_OP_LIST +#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" + >(); } LogicalResult LinalgDialect::verifyOperationAttribute(Operation *op, diff --git a/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp b/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp new file mode 100644 index 0000000..281d9f2 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp @@ -0,0 +1,24 @@ +//===- AllInterfaces.cpp - ------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h" + +#include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h" +#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h" +#include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h" +#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" + +void mlir::linalg::registerAllDialectInterfaceImplementations( + DialectRegistry ®istry) { + registerBufferizableOpInterfaceExternalModels(registry); + registerMeshShardingInterfaceExternalModels(registry); + registerSubsetOpInterfaceExternalModels(registry); + registerTilingInterfaceExternalModels(registry); + registerValueBoundsOpInterfaceExternalModels(registry); +} diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt index 4f47e3b..513c54d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_dialect_library(MLIRLinalgTransforms + AllInterfaces.cpp BubbleUpExtractSlice.cpp BufferizableOpInterfaceImpl.cpp Bufferize.cpp @@ -21,6 +22,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms InlineScalarOperands.cpp Interchange.cpp Loops.cpp + MeshShardingInterfaceImpl.cpp NamedOpConversions.cpp Padding.cpp Promotion.cpp @@ -61,12 +63,15 @@ add_mlir_dialect_library(MLIRLinalgTransforms MLIRIR MLIRMemRefDialect MLIRMemRefTransforms + MLIRMeshDialect + MLIRMeshTransforms MLIRLinalgDialect MLIRLinalgUtils MLIRSCFDialect MLIRSCFTransforms MLIRSCFUtils MLIRPass + MLIRShardingInterface MLIRSubsetOpInterface MLIRSparseTensorDialect MLIRTensorDialect diff --git a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp new file mode 100644 index 0000000..7ac45dc --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp @@ -0,0 +1,352 @@ +//===- MeshShardingInterfaceImpl.cpp --------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h" + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/Dialect/Mesh/IR/MeshOps.h" +#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" +#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h" +#include "mlir/Dialect/Mesh/Transforms/Transforms.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Value.h" +#include "mlir/Interfaces/TilingInterface.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" +#include <iterator> +#include <optional> +#include <utility> + +namespace mlir::linalg { + +using MeshAxis = mesh::MeshAxis; +using ReductionKind = mesh::ReductionKind; +using MeshShardingAttr = mesh::MeshShardingAttr; +using ShardingArray = mesh::ShardingArray; +using MeshOp = mesh::MeshOp; + +// Returns the corresponding mesh reduction kind for the given arith op. +static ReductionKind getReductionKind(Operation *op) { + return llvm::TypeSwitch<Operation *, ReductionKind>(op) + // Floating-point operations. + .Case([](arith::AddFOp op) { return ReductionKind::Sum; }) + .Case([](arith::MulFOp op) { return ReductionKind::Product; }) + // TODO: handle maxnumf and minnumf. + .Case([](arith::MaximumFOp op) { return ReductionKind::Max; }) + .Case([](arith::MinimumFOp op) { return ReductionKind::Min; }) + // Integer operations. + .Case([](arith::AddIOp op) { return ReductionKind::Sum; }) + .Case([](arith::OrIOp op) { return ReductionKind::BitwiseOr; }) + .Case([](arith::XOrIOp op) { return ReductionKind::BitwiseXor; }) + .Case([](arith::AndIOp op) { return ReductionKind::Sum; }) + // TODO: handle signless, signed and unsigned types properly. + // It is assumed that the element type of the collective operands and + // result drive the meaning of the reduction kind, whether it is signed + // or unsigned. + // The reduction op inside the linalg op may have different result type + // from the element type of the linalg op's result. + // Also signed and unsigned Arith dialect ops may accept signed, unsigned + // or signless operands. + // Maybe expand the reduction kinds. + .Case([](arith::MaxUIOp op) { return ReductionKind::Max; }) + .Case([](arith::MinUIOp op) { return ReductionKind::Min; }) + .Case([](arith::MaxSIOp op) { return ReductionKind::Max; }) + .Case([](arith::MinSIOp op) { return ReductionKind::Min; }) + .Case([](arith::MulIOp op) { return ReductionKind::Product; }) + .Default([](Operation *op) { return ReductionKind::Generic; }); +} + +static std::optional<Operation *> getCombinerOp(LinalgOp op) { + SmallVector<Operation *> combinerOps; + Value reducedValue = matchReduction(op.getRegionOutputArgs(), 0, combinerOps); + if (!reducedValue || combinerOps.size() != 1) { + return std::nullopt; + } + + return combinerOps[0]; +} + +static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) { + std::optional<Operation *> reductionOp = getCombinerOp(op); + if (!reductionOp) { + return ReductionKind::Generic; + } + Type resultElementType = + llvm::cast<RankedTensorType>(op->getResult(0).getType()).getElementType(); + // TODO: handle case when result type of the reduction op does not match the + // element type of the result tensor. + // Would it makes sense at all? + assert(resultElementType == reductionOp.value()->getResult(0).getType()); + return getReductionKind(reductionOp.value()); +} + +static MeshOp getMesh(Operation *op, + ArrayRef<MeshShardingAttr> operandShardings, + ArrayRef<MeshShardingAttr> resultShardings, + SymbolTableCollection &symbolTable) { + for (MeshShardingAttr sharding : operandShardings) { + if (sharding) { + return mesh::getMesh(op, sharding.getMesh(), symbolTable); + } + } + + for (MeshShardingAttr sharding : resultShardings) { + if (sharding) { + return mesh::getMesh(op, sharding.getMesh(), symbolTable); + } + } + + assert(false); +} + +// Choose the operand based on the current process index along the reduction +// mesh axes. +// We need to use the initial value only once to avoid including it in the +// reduction multiple times. +// In each process group only the leading process with linear index 0 would use +// the original operand. +// The other processes would use the reduction operation neutral tensor. +static Value createDestinationPassingStyleInitOperand( + LinalgOp op, Value spmdizedOperand, ArrayRef<MeshAxis> reductionMeshAxes, + MeshOp meshOp, ImplicitLocOpBuilder &builder) { + Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex( + meshOp.getSymName(), reductionMeshAxes, builder); + Value zero = builder.create<arith::ConstantIndexOp>(0); + Value isLeadProcess = builder.create<arith::CmpIOp>( + builder.getI1Type(), arith::CmpIPredicate::eq, + processLinearIndexInReductionGroup, zero); + scf::IfOp ifOp = builder.create<scf::IfOp>(spmdizedOperand.getType(), + isLeadProcess, true, true); + // Then block. + { + OpBuilder::InsertionGuard insertionGuard(builder); + builder.setInsertionPointToEnd(&ifOp.getThenRegion().front()); + builder.create<scf::YieldOp>(spmdizedOperand); + } + + // Else block. + { + OpBuilder::InsertionGuard insertionGuard(builder); + builder.setInsertionPointToEnd(&ifOp.getElseRegion().front()); + SmallVector<OpFoldResult> shape = + tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand); + PartialReductionOpInterface partialReductionIface = + llvm::cast<PartialReductionOpInterface>(op.getOperation()); + FailureOr<Operation *> reductionNeutralTensorOp = + partialReductionIface.generateInitialTensorForPartialReduction( + builder, builder.getLoc(), shape, {}); + assert(succeeded(reductionNeutralTensorOp)); + builder.create<scf::YieldOp>( + reductionNeutralTensorOp.value()->getResult(0)); + } + return ifOp.getResult(0); +} + +// Create the DPS init operands for the spmdized Linalg op. +// Return all the new spmdized operands. +static SmallVector<Value> createDestinationPassingStyleInitOperands( + LinalgOp op, MeshOp meshOp, ArrayRef<Value> spmdizedOperands, + ArrayRef<MeshAxis> reductionMeshAxes, IRMapping &spmdizationMap, + ImplicitLocOpBuilder &builder) { + // TODO: add support for multiple destination passing style initial value + // operands. + // PartialReductionOpInterface::generateInitialTensorForPartialReduction + // needs to also support multiple DPS initial operands. + SmallVector<Value> newOperands = llvm::to_vector(spmdizedOperands); + auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber(); + Value spmdizedInitOperand = + spmdizationMap.lookup(op->getOperands()[operandIdx]); + newOperands[operandIdx] = createDestinationPassingStyleInitOperand( + op, spmdizedInitOperand, reductionMeshAxes, meshOp, builder); + return newOperands; +} + +static void createAllReduceForResultWithoutPartialSharding( + Value unshardedLinalgOpResult, ArrayRef<MeshAxis> opReductionMeshAxes, + MeshShardingAttr resultSharding, ReductionKind reductionKind, + IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder) { + SmallVector<MeshAxis> allReduceMeshAxes; + llvm::copy_if(opReductionMeshAxes, std::back_inserter(allReduceMeshAxes), + [&resultSharding](MeshAxis axis) { + return !llvm::is_contained(resultSharding.getPartialAxes(), + axis); + }); + if (allReduceMeshAxes.empty()) { + return; + } + + Value spmdizedLinalgOpResult = spmdizationMap.lookup(unshardedLinalgOpResult); + Value reducedValue = builder.create<mesh::AllReduceOp>( + spmdizedLinalgOpResult, resultSharding.getMesh().getValue(), + allReduceMeshAxes, reductionKind); + spmdizationMap.map(unshardedLinalgOpResult, reducedValue); +} + +static void createAllReduceForResultsWithoutPartialShardings( + LinalgOp unshardedOp, ArrayRef<MeshAxis> opReductionMeshAxes, + ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap, + ImplicitLocOpBuilder &builder) { + ReductionKind reductionKind = getReductionKindOfLinalgOp(unshardedOp); + for (auto [unshardedLinalgOpResult, resultSharding] : + llvm::zip_equal(unshardedOp->getResults(), resultShardings)) { + createAllReduceForResultWithoutPartialSharding( + unshardedLinalgOpResult, opReductionMeshAxes, resultSharding, + reductionKind, spmdizationMap, builder); + } +} + +static void spmdizeLinalgOpWithShardedReduction( + LinalgOp op, ArrayRef<Value> spmdizedOperands, + ArrayRef<MeshShardingAttr> operandShardings, + ArrayRef<MeshShardingAttr> resultShardings, + ArrayRef<utils::IteratorType> loopIteratorTypes, + ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators, + IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, + ImplicitLocOpBuilder &builder) { + MeshOp mesh = getMesh(op, operandShardings, resultShardings, symbolTable); + SmallVector<MeshAxis> reductionMeshAxes = mesh::getReductionMeshAxes( + loopIteratorTypes, meshAxisAssignmentForLoopIterators); + SmallVector<Value> spmdizedLinalgOpOperands = + createDestinationPassingStyleInitOperands(op, mesh, spmdizedOperands, + reductionMeshAxes, + spmdizationMap, builder); + // We must not change the operand mappings of the original spmdizationMap as + // they are the mappings for the whole spmdization blob and may be used by + // others. + IRMapping internalSpmdizationMap; + for (auto [unshardedOperand, spmdizedOperand] : + llvm::zip_equal(op->getOperands(), spmdizedLinalgOpOperands)) { + internalSpmdizationMap.map(unshardedOperand, spmdizedOperand); + } + spmdizeTriviallyShardableOperation( + *op, spmdizedLinalgOpOperands, operandShardings, resultShardings, + internalSpmdizationMap, symbolTable, builder); + for (Value result : op->getResults()) { + spmdizationMap.map(result, internalSpmdizationMap.lookup(result)); + } + + // Handle partial shardings. + createAllReduceForResultsWithoutPartialShardings( + op, reductionMeshAxes, resultShardings, spmdizationMap, builder); +} + +namespace { + +// ShardingInterface for ops that implement LinalgStructuredInterface. +// The supported ops are only those where the indexing maps are projected +// permutations. +template <typename Op> +struct StructuredOpShardingInterface + : public mesh::ShardingInterface::ExternalModel< + StructuredOpShardingInterface<Op>, Op> { + SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const { + return llvm::cast<LinalgOp>(op).getIteratorTypesArray(); + } + + SmallVector<AffineMap> getIndexingMaps(Operation *op) const { + LinalgOp linalgOp = llvm::cast<LinalgOp>(op); + SmallVector<AffineMap> res = linalgOp.getIndexingMapsArray(); + + // Results must have the same indexing as destination passing style initial + // operands. + for (int64_t i = 0; i < linalgOp.getNumDpsInits(); ++i) { + res.push_back(res[linalgOp.getDpsInitOperand(i)->getOperandNumber()]); + } + + return res; + } + + LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands, + ArrayRef<MeshShardingAttr> operandShardings, + ArrayRef<MeshShardingAttr> resultShardings, + IRMapping &spmdizationMap, + SymbolTableCollection &symbolTable, + OpBuilder &builder) const { + LinalgOp linalgOp = llvm::cast<LinalgOp>(op); + + SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray(); + bool allIndexingMapsAreProjectedPermutation = + llvm::all_of(indexingMaps, [](AffineMap map) { + return map.isProjectedPermutation(); + }); + if (!allIndexingMapsAreProjectedPermutation) { + // TODO: handle non-projected permutations. + return op->emitOpError() + << "supports indexing maps that are only projected permutation."; + } + + SmallVector<utils::IteratorType> loopIteratorTypes = + linalgOp.getIteratorTypesArray(); + ShardingArray meshAxisAssignmentForLoopIterators = + getMeshAxisAssignmentForLoopIterators(operandShardings, resultShardings, + loopIteratorTypes, indexingMaps); + if (mesh::isAtLeastOneReductionIteratorSharded( + loopIteratorTypes, meshAxisAssignmentForLoopIterators)) { + ImplicitLocOpBuilder implicitLocBuilder(op->getLoc(), builder); + spmdizeLinalgOpWithShardedReduction( + linalgOp, spmdizedOperands, operandShardings, resultShardings, + loopIteratorTypes, meshAxisAssignmentForLoopIterators, spmdizationMap, + symbolTable, implicitLocBuilder); + } else { + spmdizeTriviallyShardableOperation(*op, spmdizedOperands, + operandShardings, resultShardings, + spmdizationMap, symbolTable, builder); + } + + return success(); + } +}; + +} // namespace + +template <typename OpType> +static void registerOne(MLIRContext *ctx) { + OpType::template attachInterface<StructuredOpShardingInterface<OpType>>(*ctx); +} + +/// Variadic helper function. +template <typename... OpTypes> +static void registerAll(MLIRContext *ctx) { + (registerOne<OpTypes>(ctx), ...); +} + +void registerMeshShardingInterfaceExternalModels(DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *dialect) { + DialectRegistry registry; + registry.insert<affine::AffineDialect, arith::ArithDialect, scf::SCFDialect, + tensor::TensorDialect>(); + ctx->appendDialectRegistry(registry); + for (StringRef name : registry.getDialectNames()) + ctx->getOrLoadDialect(name); + + registerOne<linalg::GenericOp>(ctx); + registerAll< +#define GET_OP_LIST +#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" + >(ctx); + }); +} + +} // namespace mlir::linalg diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index 8b3119f..bd870d4 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -275,14 +275,6 @@ struct LinalgOpPartialReductionInterface ArrayRef<int64_t> oldShape = linalgOp.getShape(linalgOp.getDpsInitOperand(0)); - // Extend tile size vector to the rank of the output tensor. - SmallVector<Value> tileSizeVector = - getValueOrCreateConstantIndexOp(b, loc, sizes); - if (tileSizeVector.size() < oldShape.size()) { - auto zero = b.create<arith::ConstantIndexOp>(loc, 0); - tileSizeVector.append(oldShape.size() - tileSizeVector.size(), zero); - } - // Calculate the new shape, we insert the new dimensions based on the index // of the reduction dimensions. SmallVector<int64_t> newOutputShape; diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp index 50163880..03f11ad 100644 --- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp +++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp @@ -647,6 +647,13 @@ void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns, patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context); } +void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState, + Value input, StringRef mesh, + ArrayRef<MeshAxis> meshAxes, ReductionKind reduction) { + build(odsBuilder, odsState, input.getType(), mesh, meshAxes, input, + reduction); +} + void AllReduceOp::getAsmResultNames( function_ref<void(Value, StringRef)> setNameFn) { setNameFn(getResult(), "all_reduce"); diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp index fe3d7c4..9acee5a 100644 --- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp +++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp @@ -539,8 +539,9 @@ static bool areValuesCompatibleWithFullReplicationShardings( if (std::size(values) != std::size(shardings)) { return false; } - return llvm::all_of(llvm::zip(std::forward<ValueRange>(values), - std::forward<MeshShardingAttrRage>(shardings)), + return llvm::all_of(llvm::zip_equal( + std::forward<ValueRange>(values), + std::forward<MeshShardingAttrRage>(shardings)), [](auto valueAndSharding) { return isValueCompatibleWithFullReplicationSharding( std::get<0>(valueAndSharding), @@ -563,6 +564,88 @@ void mesh::spmdizeFullyReplicatedOperation( builder.clone(op, spmdizationMap); } +static void updateMeshAxisAssignmentForLoopIterators( + ArrayRef<MeshAxis> meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr, + SmallVector<std::optional<SmallVector<MeshAxis>>> + &meshAxesAssignmentForLoopIterators) { + AffineDimExpr affineDimExpr = cast<AffineDimExpr>(indexingExpr); + unsigned loopIteratorIdx = affineDimExpr.getPosition(); + if (meshAxesAssignmentForLoopIterators[loopIteratorIdx]) { + assert(llvm::equal(meshAxesAssignmentForTensorAxis, + *meshAxesAssignmentForLoopIterators[loopIteratorIdx])); + } else { + meshAxesAssignmentForLoopIterators[loopIteratorIdx] = + llvm::to_vector(meshAxesAssignmentForTensorAxis); + } +} + +ShardingArray mesh::getMeshAxisAssignmentForLoopIterators( + ArrayRef<MeshShardingAttr> operandShardings, + ArrayRef<MeshShardingAttr> resultShardings, + ArrayRef<utils::IteratorType> loopIteratorTypes, + ArrayRef<AffineMap> indexingMaps) { + SmallVector<std::optional<SmallVector<MeshAxis>>> + meshAxisAssignmentForLoopIterators(loopIteratorTypes.size()); + SmallVector<MeshShardingAttr> operatorAndResultShardings; + operatorAndResultShardings.reserve(operandShardings.size() + + resultShardings.size()); + llvm::append_range(operatorAndResultShardings, operandShardings); + for (auto [sharding, affineMap] : + llvm::zip_equal(operatorAndResultShardings, indexingMaps)) { + if (!sharding) { + continue; + } + for (auto [meshAxesAssignmentForTensorAxis, indexingExpr] : + llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) { + updateMeshAxisAssignmentForLoopIterators( + meshAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr, + meshAxisAssignmentForLoopIterators); + } + // Missing trailing split axes means replication on those tensor dimensions. + for (unsigned i = sharding.getSplitAxes().size(); + i < affineMap.getNumResults(); ++i) { + updateMeshAxisAssignmentForLoopIterators( + {}, affineMap.getResults()[i], meshAxisAssignmentForLoopIterators); + } + } + + ShardingArray res; + llvm::transform(meshAxisAssignmentForLoopIterators, std::back_inserter(res), + [](std::optional<SmallVector<MeshAxis>> &axes) { + if (!axes) { + return SmallVector<MeshAxis>(); + }; + return std::move(*axes); + }); + return res; +} + +bool mesh::isAtLeastOneReductionIteratorSharded( + ArrayRef<utils::IteratorType> loopIteratorTypes, + ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) { + for (auto [loopIteratorType, meshAxisAssignment] : + llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) { + if (loopIteratorType == utils::IteratorType::reduction && + !meshAxisAssignment.empty()) { + return true; + } + } + return false; +} + +SmallVector<MeshAxis> mesh::getReductionMeshAxes( + ArrayRef<utils::IteratorType> loopIteratorTypes, + ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) { + SmallVector<MeshAxis> meshAxes; + for (auto [loopIteratorType, meshAxisAssignment] : + llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) { + if (loopIteratorType == utils::IteratorType::reduction) { + llvm::append_range(meshAxes, meshAxisAssignment); + } + } + return meshAxes; +} + void mesh::spmdizeTriviallyShardableOperation( Operation &op, ArrayRef<Value> spmdizedOperands, ArrayRef<MeshShardingAttr> operandShardings, @@ -572,7 +655,7 @@ void mesh::spmdizeTriviallyShardableOperation( Operation *newOp = builder.clone(op, spmdizationMap); // Set the result types to the sharded counterparts. for (auto [oldResult, newResult, sharding] : - llvm::zip(op.getResults(), newOp->getResults(), resultShardings)) { + llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) { newResult.setType(shardType(newResult.getType(), getMesh(&op, sharding.getMesh(), symbolTable), sharding)); diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp index d59b911..cb13ee4 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp @@ -208,4 +208,17 @@ createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes, .cast<TypedValue<IndexType>>(); } +TypedValue<IndexType> createProcessLinearIndex(StringRef mesh, + ArrayRef<MeshAxis> meshAxes, + ImplicitLocOpBuilder &builder) { + ResultRange processInGroupMultiIndex = + builder.create<ProcessMultiIndexOp>(mesh, meshAxes).getResults(); + Operation::result_range processGroupShape = + builder.create<MeshShapeOp>(mesh, meshAxes).getResult(); + OpFoldResult processInGroupLinearIndex = affine::linearizeIndex( + llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex), + llvm::to_vector_of<OpFoldResult>(processGroupShape), builder); + return cast<TypedValue<IndexType>>(processInGroupLinearIndex.get<Value>()); +} + } // namespace mlir::mesh diff --git a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir new file mode 100644 index 0000000..6d21def --- /dev/null +++ b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir @@ -0,0 +1,165 @@ +// RUN: mlir-opt \ +// RUN: --mesh-spmdization \ +// RUN: --test-constant-fold \ +// RUN: --split-input-file \ +// RUN: %s | FileCheck %s + +// CHECK: #[[$MAP_IDENTITY_1D:.*]] = affine_map<(d0) -> (d0)> +#map_identity_1d = affine_map<(d0) -> (d0)> + +mesh.mesh @mesh_1d(shape = 2) + +// CHECK-LABEL: func @elementwise_static_1d_mesh_static_1d_tensor +func.func @elementwise_static_1d_mesh_static_1d_tensor( + // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<1xi8>, + %in1: tensor<2xi8>, + // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<1xi8>, + %in2: tensor<2xi8>, + // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<1xi8> + %dps_out: tensor<2xi8> +// CHECK-SAME: -> tensor<1xi8> { +) -> tensor<2xi8> { + %in1_shared1 = mesh.shard %in1 to <@mesh_1d, [[0]]> : tensor<2xi8> + %in1_shared2 = mesh.shard %in1_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8> + %in2_shared1 = mesh.shard %in2 to <@mesh_1d, [[0]]> : tensor<2xi8> + %in2_shared2 = mesh.shard %in2_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8> + %dps_out_shared1 = mesh.shard %dps_out to <@mesh_1d, [[0]]> : tensor<2xi8> + %dps_out_shared2 = mesh.shard %dps_out_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8> + // CHECK: %[[RES:.*]] = linalg.generic { + // CHECK-SAME: indexing_maps = [#[[$MAP_IDENTITY_1D]], #[[$MAP_IDENTITY_1D]], #[[$MAP_IDENTITY_1D]]], + // CHECK-SAME: iterator_types = ["parallel"]} + // CHECK-SAME: ins(%[[IN1]], %[[IN2]] : tensor<1xi8>, tensor<1xi8>) + // CHECK-SAME: outs(%[[DPS_OUT]] : tensor<1xi8>) { + %res = linalg.generic { + indexing_maps = [#map_identity_1d, #map_identity_1d, #map_identity_1d], + iterator_types = ["parallel"] + } ins(%in1_shared2, %in2_shared2 : tensor<2xi8>, tensor<2xi8>) + outs(%dps_out_shared2 : tensor<2xi8>) { + ^bb0(%in1_scalar: i8, %in2_scalar: i8, %out: i8): + %res_scalar = arith.muli %in1_scalar, %in2_scalar : i8 + linalg.yield %res_scalar : i8 + } -> tensor<2xi8> + %res_shared1 = mesh.shard %res to <@mesh_1d, [[0]]> : tensor<2xi8> + %res_shared2 = mesh.shard %res_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8> + // CHECK: return %[[RES]] : tensor<1xi8> + return %res_shared2 : tensor<2xi8> +} + +// ----- + +mesh.mesh @mesh_1d(shape = 4) + +// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_parallel_iterator_sharding +func.func @matmul_1d_mesh_static_tensors_parallel_iterator_sharding( + // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<1x3xi8>, + %in1: tensor<4x3xi8>, +// CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<3x8xi8>, + %in2: tensor<3x8xi8>, +// CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<1x8xi8> + %dps_out: tensor<4x8xi8> +// CHECK-SAME: -> tensor<1x8xi8> { +) -> tensor<4x8xi8> { + %in1_shared1 = mesh.shard %in1 to <@mesh_1d, [[0]]> : tensor<4x3xi8> + %in1_shared2 = mesh.shard %in1_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<4x3xi8> + %in2_shared1 = mesh.shard %in2 to <@mesh_1d, [[]]> : tensor<3x8xi8> + %in2_shared2 = mesh.shard %in2_shared1 to <@mesh_1d, [[]]> annotate_for_users: tensor<3x8xi8> + %dps_out_shared1 = mesh.shard %dps_out to <@mesh_1d, [[0]]> : tensor<4x8xi8> + %dps_out_shared2 = mesh.shard %dps_out_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<4x8xi8> + // CHECK: %[[RES:.*]] = linalg.matmul + // CHECK-SAME: ins(%[[IN1]], %[[IN2]] : tensor<1x3xi8>, tensor<3x8xi8>) + // CHECK-SAME: outs(%[[DPS_OUT]] : tensor<1x8xi8>) + // CHECK-SAME: -> tensor<1x8xi8> + %res = linalg.matmul ins(%in1_shared2, %in2_shared2 : tensor<4x3xi8>, tensor<3x8xi8>) + outs(%dps_out_shared2 : tensor<4x8xi8>) -> tensor<4x8xi8> + %res_shared1 = mesh.shard %res to <@mesh_1d, [[0]]> : tensor<4x8xi8> + %res_shared2 = mesh.shard %res_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<4x8xi8> + // CHECK: return %[[RES]] : tensor<1x8xi8> + return %res_shared2 : tensor<4x8xi8> +} + +// ----- + +mesh.mesh @mesh_1d(shape = 3) + +// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding +func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding( + // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<4x2xi8>, + %in1: tensor<4x6xi8>, +// CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<2x8xi8>, + %in2: tensor<6x8xi8>, +// CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<4x8xi8> + %dps_out: tensor<4x8xi8> +// CHECK-SAME: -> tensor<4x8xi8> { +) -> tensor<4x8xi8> { + %in1_shared1 = mesh.shard %in1 to <@mesh_1d, [[], [0]]> : tensor<4x6xi8> + %in1_shared2 = mesh.shard %in1_shared1 to <@mesh_1d, [[], [0]]> annotate_for_users: tensor<4x6xi8> + %in2_shared1 = mesh.shard %in2 to <@mesh_1d, [[0]]> : tensor<6x8xi8> + %in2_shared2 = mesh.shard %in2_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<6x8xi8> + %dps_out_shared1 = mesh.shard %dps_out to <@mesh_1d, [[]]> : tensor<4x8xi8> + %dps_out_shared2 = mesh.shard %dps_out_shared1 to <@mesh_1d, [[]]> annotate_for_users: tensor<4x8xi8> + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[C0_I8:.*]] = arith.constant 0 : i8 + // CHECK-DAG: %[[PROCESS_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index + // CHECK-DAG: %[[MESH_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index + // CHECK: %[[DPS_INIT_OPERAND_CONDITION:.*]] = arith.cmpi eq, %[[PROCESS_IDX]], %[[C0]] : index + // CHECK: %[[DPS_INIT_OPERAND:.*]] = scf.if %[[DPS_INIT_OPERAND_CONDITION]] -> (tensor<4x8xi8>) { + // CHECK: scf.yield %[[DPS_OUT]] : tensor<4x8xi8> + // CHECK: } else { + // CHECK-DAG: %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<4x8xi8> + // CHECK: %[[NEUTRAL_ELEMENT_FILLED_TENSOR:.*]] = linalg.fill ins(%[[C0_I8]] : i8) + // CHECK-SAME: outs(%[[EMPTY_TENSOR]] : tensor<4x8xi8>) -> tensor<4x8xi8> + // CHECK: scf.yield %[[NEUTRAL_ELEMENT_FILLED_TENSOR]] : tensor<4x8xi8> + // CHECK: } + // CHECK: %[[SHARDED_MATMUL:.*]] = linalg.matmul ins(%[[IN1]], %[[IN2]] : tensor<4x2xi8>, tensor<2x8xi8>) + // CHECK-SAME: outs(%[[DPS_INIT_OPERAND]] : tensor<4x8xi8>) -> tensor<4x8xi8> + // CHECK: %[[ALL_REDUCED:.*]] = mesh.all_reduce %[[SHARDED_MATMUL]] on @mesh_1d mesh_axes = [0] : tensor<4x8xi8> -> tensor<4x8xi8> + %res = linalg.matmul ins(%in1_shared2, %in2_shared2 : tensor<4x6xi8>, tensor<6x8xi8>) + outs(%dps_out_shared2 : tensor<4x8xi8>) -> tensor<4x8xi8> + %res_shared1 = mesh.shard %res to <@mesh_1d, [[]]> : tensor<4x8xi8> + %res_shared2 = mesh.shard %res_shared1 to <@mesh_1d, [[]]> annotate_for_users: tensor<4x8xi8> + // CHECK: return %[[ALL_REDUCED]] : tensor<4x8xi8> + return %res_shared2 : tensor<4x8xi8> +} + +// ----- + +mesh.mesh @mesh_1d(shape = 3) + +// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding_with_partial_result +func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding_with_partial_result( + // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<4x2xi8>, + %in1: tensor<4x6xi8>, +// CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<2x8xi8>, + %in2: tensor<6x8xi8>, +// CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<4x8xi8> + %dps_out: tensor<4x8xi8> +// CHECK-SAME: -> tensor<4x8xi8> { +) -> tensor<4x8xi8> { + %in1_shared1 = mesh.shard %in1 to <@mesh_1d, [[], [0]]> : tensor<4x6xi8> + %in1_shared2 = mesh.shard %in1_shared1 to <@mesh_1d, [[], [0]]> annotate_for_users: tensor<4x6xi8> + %in2_shared1 = mesh.shard %in2 to <@mesh_1d, [[0]]> : tensor<6x8xi8> + %in2_shared2 = mesh.shard %in2_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<6x8xi8> + %dps_out_shared1 = mesh.shard %dps_out to <@mesh_1d, [[]]> : tensor<4x8xi8> + %dps_out_shared2 = mesh.shard %dps_out_shared1 to <@mesh_1d, [[]]> annotate_for_users: tensor<4x8xi8> + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[C0_I8:.*]] = arith.constant 0 : i8 + // CHECK-DAG: %[[PROCESS_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index + // CHECK-DAG: %[[MESH_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index + // CHECK: %[[DPS_INIT_OPERAND_CONDITION:.*]] = arith.cmpi eq, %[[PROCESS_IDX]], %[[C0]] : index + // CHECK: %[[DPS_INIT_OPERAND:.*]] = scf.if %[[DPS_INIT_OPERAND_CONDITION]] -> (tensor<4x8xi8>) { + // CHECK: scf.yield %[[DPS_OUT]] : tensor<4x8xi8> + // CHECK: } else { + // CHECK-DAG: %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<4x8xi8> + // CHECK: %[[NEUTRAL_ELEMENT_FILLED_TENSOR:.*]] = linalg.fill ins(%[[C0_I8]] : i8) + // CHECK-SAME: outs(%[[EMPTY_TENSOR]] : tensor<4x8xi8>) -> tensor<4x8xi8> + // CHECK: scf.yield %[[NEUTRAL_ELEMENT_FILLED_TENSOR]] : tensor<4x8xi8> + // CHECK: } + // CHECK: %[[SHARDED_MATMUL:.*]] = linalg.matmul ins(%[[IN1]], %[[IN2]] : tensor<4x2xi8>, tensor<2x8xi8>) + // CHECK-SAME: outs(%[[DPS_INIT_OPERAND]] : tensor<4x8xi8>) -> tensor<4x8xi8> + %res = linalg.matmul ins(%in1_shared2, %in2_shared2 : tensor<4x6xi8>, tensor<6x8xi8>) + outs(%dps_out_shared2 : tensor<4x8xi8>) -> tensor<4x8xi8> + %res_shared1 = mesh.shard %res to <@mesh_1d, [[]], partial = sum[0]> : tensor<4x8xi8> + %res_shared2 = mesh.shard %res_shared1 to <@mesh_1d, [[]], partial = sum[0]> annotate_for_users: tensor<4x8xi8> + // CHECK: return %[[SHARDED_MATMUL]] : tensor<4x8xi8> + return %res_shared2 : tensor<4x8xi8> +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 7a6bc2d..2cfe6184 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -10841,6 +10841,7 @@ cc_library( ":MemRefDialect", ":Parser", ":SCFDialect", + ":MeshShardingInterface", ":SideEffectInterfaces", ":SparseTensorDialect", ":Support", @@ -10994,10 +10995,13 @@ cc_library( ":MathDialect", ":MemRefDialect", ":MemRefTransforms", + ":MeshDialect", + ":MeshTransforms", ":Pass", ":SCFDialect", ":SCFTransforms", ":SCFUtils", + ":MeshShardingInterface", ":SparseTensorDialect", ":SubsetOpInterface", ":Support", |