Diffstat (limited to 'mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp')
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp  353
1 file changed, 353 insertions(+), 0 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
new file mode 100644
index 0000000..146e880
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
@@ -0,0 +1,353 @@
+//===- MeshShardingInterfaceImpl.cpp --------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h"
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include <iterator>
+#include <optional>
+#include <utility>
+
+namespace mlir::linalg {
+
+using MeshAxis = mesh::MeshAxis;
+using ReductionKind = mesh::ReductionKind;
+using MeshShardingAttr = mesh::MeshShardingAttr;
+using ShardingArray = mesh::ShardingArray;
+using MeshOp = mesh::MeshOp;
+
+// Returns the corresponding mesh reduction kind for the given arith op.
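+// For example, if the combiner inside a linalg reduction is arith.addf, the
+// resulting kind is Sum, which is later used when emitting mesh.all_reduce.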
+static ReductionKind getReductionKind(Operation *op) {
+ return llvm::TypeSwitch<Operation *, ReductionKind>(op)
+ // Floating-point operations.
+ .Case([](arith::AddFOp op) { return ReductionKind::Sum; })
+ .Case([](arith::MulFOp op) { return ReductionKind::Product; })
+ // TODO: handle maxnumf and minnumf.
+ .Case([](arith::MaximumFOp op) { return ReductionKind::Max; })
+ .Case([](arith::MinimumFOp op) { return ReductionKind::Min; })
+ // Integer operations.
+ .Case([](arith::AddIOp op) { return ReductionKind::Sum; })
+ .Case([](arith::OrIOp op) { return ReductionKind::BitwiseOr; })
+ .Case([](arith::XOrIOp op) { return ReductionKind::BitwiseXor; })
+      .Case([](arith::AndIOp op) { return ReductionKind::BitwiseAnd; })
+      // TODO: handle signless, signed and unsigned types properly.
+      // It is assumed that the element type of the collective operands and
+      // result drives the meaning of the reduction kind, i.e. whether it is
+      // signed or unsigned.
+      // The reduction op inside the linalg op may have a result type that
+      // differs from the element type of the linalg op's result.
+      // Also, signed and unsigned Arith dialect ops may accept signed,
+      // unsigned, or signless operands.
+      // Maybe expand the reduction kinds.
+ .Case([](arith::MaxUIOp op) { return ReductionKind::Max; })
+ .Case([](arith::MinUIOp op) { return ReductionKind::Min; })
+ .Case([](arith::MaxSIOp op) { return ReductionKind::Max; })
+ .Case([](arith::MinSIOp op) { return ReductionKind::Min; })
+ .Case([](arith::MulIOp op) { return ReductionKind::Product; })
+ .Default([](Operation *op) { return ReductionKind::Generic; });
+}
+
+static std::optional<Operation *> getCombinerOp(LinalgOp op) {
+ SmallVector<Operation *> combinerOps;
+ Value reducedValue = matchReduction(op.getRegionOutputArgs(), 0, combinerOps);
+ if (!reducedValue || combinerOps.size() != 1) {
+ return std::nullopt;
+ }
+
+ return combinerOps[0];
+}
+
+static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) {
+ std::optional<Operation *> reductionOp = getCombinerOp(op);
+ if (!reductionOp) {
+ return ReductionKind::Generic;
+ }
+ [[maybe_unused]] Type resultElementType =
+ llvm::cast<RankedTensorType>(op->getResult(0).getType()).getElementType();
+  // TODO: handle the case when the result type of the reduction op does not
+  // match the element type of the result tensor.
+  // Would it make sense at all?
+ assert(resultElementType == reductionOp.value()->getResult(0).getType());
+ return getReductionKind(reductionOp.value());
+}
+
+static MeshOp getMesh(Operation *op,
+ ArrayRef<MeshShardingAttr> operandShardings,
+ ArrayRef<MeshShardingAttr> resultShardings,
+ SymbolTableCollection &symbolTable) {
+ for (MeshShardingAttr sharding : operandShardings) {
+ if (sharding) {
+ return mesh::getMesh(op, sharding.getMesh(), symbolTable);
+ }
+ }
+
+ for (MeshShardingAttr sharding : resultShardings) {
+ if (sharding) {
+ return mesh::getMesh(op, sharding.getMesh(), symbolTable);
+ }
+ }
+
+  assert(false && "no mesh sharding found on any operand or result");
+ return nullptr;
+}
+
+// Choose the operand based on the current process index along the reduction
+// mesh axes.
+// The initial value must be used only once, otherwise it would be included in
+// the reduction multiple times.
+// In each process group, only the leading process (the one with linear index
+// 0) uses the original operand; all other processes use a tensor holding the
+// neutral element of the reduction operation.
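+//
+// As a rough sketch (not verbatim builder output; %spmdized_operand and
+// %neutral_tensor are illustrative names), the IR built below looks like:
+//
+//   %idx  = <linear index of this process within the reduction mesh axes>
+//   %zero = arith.constant 0 : index
+//   %lead = arith.cmpi eq, %idx, %zero : index
+//   %init = scf.if %lead -> (tensor<...>) {
+//     scf.yield %spmdized_operand : tensor<...>
+//   } else {
+//     scf.yield %neutral_tensor : tensor<...>
+//   }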
+static Value createDestinationPassingStyleInitOperand(
+ LinalgOp op, Value spmdizedOperand, ArrayRef<MeshAxis> reductionMeshAxes,
+ MeshOp meshOp, ImplicitLocOpBuilder &builder) {
+ Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex(
+ meshOp.getSymName(), reductionMeshAxes, builder);
+ Value zero = builder.create<arith::ConstantIndexOp>(0);
+ Value isLeadProcess = builder.create<arith::CmpIOp>(
+ builder.getI1Type(), arith::CmpIPredicate::eq,
+ processLinearIndexInReductionGroup, zero);
+ scf::IfOp ifOp = builder.create<scf::IfOp>(spmdizedOperand.getType(),
+ isLeadProcess, true, true);
+ // Then block.
+ {
+ OpBuilder::InsertionGuard insertionGuard(builder);
+ builder.setInsertionPointToEnd(&ifOp.getThenRegion().front());
+ builder.create<scf::YieldOp>(spmdizedOperand);
+ }
+
+ // Else block.
+ {
+ OpBuilder::InsertionGuard insertionGuard(builder);
+ builder.setInsertionPointToEnd(&ifOp.getElseRegion().front());
+ SmallVector<OpFoldResult> shape =
+ tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);
+ PartialReductionOpInterface partialReductionIface =
+ llvm::cast<PartialReductionOpInterface>(op.getOperation());
+ FailureOr<Operation *> reductionNeutralTensorOp =
+ partialReductionIface.generateInitialTensorForPartialReduction(
+ builder, builder.getLoc(), shape, {});
+ assert(succeeded(reductionNeutralTensorOp));
+ builder.create<scf::YieldOp>(
+ reductionNeutralTensorOp.value()->getResult(0));
+ }
+ return ifOp.getResult(0);
+}
+
+// Create the DPS init operands for the spmdized Linalg op.
+// Return all the new spmdized operands.
+static SmallVector<Value> createDestinationPassingStyleInitOperands(
+ LinalgOp op, MeshOp meshOp, ArrayRef<Value> spmdizedOperands,
+ ArrayRef<MeshAxis> reductionMeshAxes, IRMapping &spmdizationMap,
+ ImplicitLocOpBuilder &builder) {
+ // TODO: add support for multiple destination passing style initial value
+ // operands.
+ // PartialReductionOpInterface::generateInitialTensorForPartialReduction
+ // needs to also support multiple DPS initial operands.
+ SmallVector<Value> newOperands = llvm::to_vector(spmdizedOperands);
+ auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber();
+ Value spmdizedInitOperand =
+ spmdizationMap.lookup(op->getOperands()[operandIdx]);
+ newOperands[operandIdx] = createDestinationPassingStyleInitOperand(
+ op, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
+ return newOperands;
+}
+
+static void createAllReduceForResultWithoutPartialSharding(
+ Value unshardedLinalgOpResult, ArrayRef<MeshAxis> opReductionMeshAxes,
+ MeshShardingAttr resultSharding, ReductionKind reductionKind,
+ IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder) {
+ SmallVector<MeshAxis> allReduceMeshAxes;
+ llvm::copy_if(opReductionMeshAxes, std::back_inserter(allReduceMeshAxes),
+ [&resultSharding](MeshAxis axis) {
+ return !llvm::is_contained(resultSharding.getPartialAxes(),
+ axis);
+ });
+ if (allReduceMeshAxes.empty()) {
+ return;
+ }
+
+ Value spmdizedLinalgOpResult = spmdizationMap.lookup(unshardedLinalgOpResult);
+ Value reducedValue = builder.create<mesh::AllReduceOp>(
+ spmdizedLinalgOpResult, resultSharding.getMesh().getValue(),
+ allReduceMeshAxes, reductionKind);
+ spmdizationMap.map(unshardedLinalgOpResult, reducedValue);
+}
+
+static void createAllReduceForResultsWithoutPartialShardings(
+ LinalgOp unshardedOp, ArrayRef<MeshAxis> opReductionMeshAxes,
+ ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+ ImplicitLocOpBuilder &builder) {
+ ReductionKind reductionKind = getReductionKindOfLinalgOp(unshardedOp);
+ for (auto [unshardedLinalgOpResult, resultSharding] :
+ llvm::zip_equal(unshardedOp->getResults(), resultShardings)) {
+ createAllReduceForResultWithoutPartialSharding(
+ unshardedLinalgOpResult, opReductionMeshAxes, resultSharding,
+ reductionKind, spmdizationMap, builder);
+ }
+}
+
+static void spmdizeLinalgOpWithShardedReduction(
+ LinalgOp op, ArrayRef<Value> spmdizedOperands,
+ ArrayRef<MeshShardingAttr> operandShardings,
+ ArrayRef<MeshShardingAttr> resultShardings,
+ ArrayRef<utils::IteratorType> loopIteratorTypes,
+ ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators,
+ IRMapping &spmdizationMap, SymbolTableCollection &symbolTable,
+ ImplicitLocOpBuilder &builder) {
+ MeshOp mesh = getMesh(op, operandShardings, resultShardings, symbolTable);
+ SmallVector<MeshAxis> reductionMeshAxes = mesh::getReductionMeshAxes(
+ loopIteratorTypes, meshAxisAssignmentForLoopIterators);
+ SmallVector<Value> spmdizedLinalgOpOperands =
+ createDestinationPassingStyleInitOperands(op, mesh, spmdizedOperands,
+ reductionMeshAxes,
+ spmdizationMap, builder);
+  // We must not change the operand mappings of the original spmdizationMap, as
+  // they are the mappings for the entire spmdization and may be used by other
+  // operations.
+ IRMapping internalSpmdizationMap;
+ for (auto [unshardedOperand, spmdizedOperand] :
+ llvm::zip_equal(op->getOperands(), spmdizedLinalgOpOperands)) {
+ internalSpmdizationMap.map(unshardedOperand, spmdizedOperand);
+ }
+ spmdizeTriviallyShardableOperation(
+ *op, spmdizedLinalgOpOperands, operandShardings, resultShardings,
+ internalSpmdizationMap, symbolTable, builder);
+ for (Value result : op->getResults()) {
+ spmdizationMap.map(result, internalSpmdizationMap.lookup(result));
+ }
+
+ // Handle partial shardings.
+ createAllReduceForResultsWithoutPartialShardings(
+ op, reductionMeshAxes, resultShardings, spmdizationMap, builder);
+}
+
+namespace {
+
+// ShardingInterface external model for ops that implement
+// LinalgStructuredInterface.
+// Only ops whose indexing maps are all projected permutations are supported.
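+// For example, a map such as (d0, d1, d2) -> (d0, d2) is a projected
+// permutation (a permutation of the dimensions with some of them dropped),
+// while (d0, d1) -> (d0 + d1) is not; ops using the latter are rejected in
+// spmdize() below.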
+template <typename Op>
+struct StructuredOpShardingInterface
+ : public mesh::ShardingInterface::ExternalModel<
+ StructuredOpShardingInterface<Op>, Op> {
+ SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
+ return llvm::cast<LinalgOp>(op).getIteratorTypesArray();
+ }
+
+ SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+ LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
+ SmallVector<AffineMap> res = linalgOp.getIndexingMapsArray();
+
+ // Results must have the same indexing as destination passing style initial
+ // operands.
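+    // For example, for a matmul-like op with maps
+    //   {(d0, d1, d2) -> (d0, d2), (d0, d1, d2) -> (d2, d1),
+    //    (d0, d1, d2) -> (d0, d1)},
+    // the loop below appends another copy of (d0, d1, d2) -> (d0, d1), so the
+    // result shares the indexing of its DPS init operand.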
+ for (int64_t i = 0; i < linalgOp.getNumDpsInits(); ++i) {
+ res.push_back(res[linalgOp.getDpsInitOperand(i)->getOperandNumber()]);
+ }
+
+ return res;
+ }
+
+ LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
+ ArrayRef<MeshShardingAttr> operandShardings,
+ ArrayRef<MeshShardingAttr> resultShardings,
+ IRMapping &spmdizationMap,
+ SymbolTableCollection &symbolTable,
+ OpBuilder &builder) const {
+ LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
+
+ SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
+ bool allIndexingMapsAreProjectedPermutation =
+ llvm::all_of(indexingMaps, [](AffineMap map) {
+ return map.isProjectedPermutation();
+ });
+ if (!allIndexingMapsAreProjectedPermutation) {
+ // TODO: handle non-projected permutations.
+      return op->emitOpError()
+             << "only supports indexing maps that are projected permutations";
+ }
+
+ SmallVector<utils::IteratorType> loopIteratorTypes =
+ linalgOp.getIteratorTypesArray();
+ ShardingArray meshAxisAssignmentForLoopIterators =
+ getMeshAxisAssignmentForLoopIterators(operandShardings, resultShardings,
+ loopIteratorTypes, indexingMaps);
+ if (mesh::isAtLeastOneReductionIteratorSharded(
+ loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
+ ImplicitLocOpBuilder implicitLocBuilder(op->getLoc(), builder);
+ spmdizeLinalgOpWithShardedReduction(
+ linalgOp, spmdizedOperands, operandShardings, resultShardings,
+ loopIteratorTypes, meshAxisAssignmentForLoopIterators, spmdizationMap,
+ symbolTable, implicitLocBuilder);
+ } else {
+ spmdizeTriviallyShardableOperation(*op, spmdizedOperands,
+ operandShardings, resultShardings,
+ spmdizationMap, symbolTable, builder);
+ }
+
+ return success();
+ }
+};
+
+} // namespace
+
+/// Registers the StructuredOpShardingInterface external model for one op type.
+template <typename OpType>
+static void registerOne(MLIRContext *ctx) {
+  OpType::template attachInterface<StructuredOpShardingInterface<OpType>>(*ctx);
+}
+
+/// Variadic helper that registers the external model for each listed op type.
+template <typename... OpTypes>
+static void registerAll(MLIRContext *ctx) {
+ (registerOne<OpTypes>(ctx), ...);
+}
+
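+// A minimal usage sketch (the setup below is an assumption about a typical
+// client, not part of this file's API surface):
+//
+//   DialectRegistry registry;
+//   registry.insert<linalg::LinalgDialect>();
+//   linalg::registerMeshShardingInterfaceExternalModels(registry);
+//   MLIRContext context(registry);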
+void registerMeshShardingInterfaceExternalModels(DialectRegistry &registry) {
+ registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *dialect) {
+ DialectRegistry registry;
+ registry.insert<affine::AffineDialect, arith::ArithDialect, scf::SCFDialect,
+ tensor::TensorDialect>();
+ ctx->appendDialectRegistry(registry);
+ for (StringRef name : registry.getDialectNames())
+ ctx->getOrLoadDialect(name);
+
+ registerOne<linalg::GenericOp>(ctx);
+ registerAll<
+#define GET_OP_LIST
+#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
+ >(ctx);
+ });
+}
+
+} // namespace mlir::linalg