diff options
Diffstat (limited to 'mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp')
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp | 353 |
1 files changed, 353 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp new file mode 100644 index 0000000..146e880 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp @@ -0,0 +1,353 @@ +//===- MeshShardingInterfaceImpl.cpp --------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h" + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/Dialect/Mesh/IR/MeshOps.h" +#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" +#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h" +#include "mlir/Dialect/Mesh/Transforms/Transforms.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Value.h" +#include "mlir/Interfaces/TilingInterface.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" +#include <iterator> +#include <optional> +#include <utility> + +namespace mlir::linalg { + +using MeshAxis = mesh::MeshAxis; +using ReductionKind = mesh::ReductionKind; +using MeshShardingAttr = mesh::MeshShardingAttr; +using ShardingArray = mesh::ShardingArray; +using MeshOp = mesh::MeshOp; + +// Returns the corresponding mesh reduction kind for the given arith op. +static ReductionKind getReductionKind(Operation *op) { + return llvm::TypeSwitch<Operation *, ReductionKind>(op) + // Floating-point operations. + .Case([](arith::AddFOp op) { return ReductionKind::Sum; }) + .Case([](arith::MulFOp op) { return ReductionKind::Product; }) + // TODO: handle maxnumf and minnumf. + .Case([](arith::MaximumFOp op) { return ReductionKind::Max; }) + .Case([](arith::MinimumFOp op) { return ReductionKind::Min; }) + // Integer operations. + .Case([](arith::AddIOp op) { return ReductionKind::Sum; }) + .Case([](arith::OrIOp op) { return ReductionKind::BitwiseOr; }) + .Case([](arith::XOrIOp op) { return ReductionKind::BitwiseXor; }) + .Case([](arith::AndIOp op) { return ReductionKind::Sum; }) + // TODO: handle signless, signed and unsigned types properly. + // It is assumed that the element type of the collective operands and + // result drive the meaning of the reduction kind, whether it is signed + // or unsigned. + // The reduction op inside the linalg op may have different result type + // from the element type of the linalg op's result. + // Also signed and unsigned Arith dialect ops may accept signed, unsigned + // or signless operands. + // Maybe expand the reduction kinds. + .Case([](arith::MaxUIOp op) { return ReductionKind::Max; }) + .Case([](arith::MinUIOp op) { return ReductionKind::Min; }) + .Case([](arith::MaxSIOp op) { return ReductionKind::Max; }) + .Case([](arith::MinSIOp op) { return ReductionKind::Min; }) + .Case([](arith::MulIOp op) { return ReductionKind::Product; }) + .Default([](Operation *op) { return ReductionKind::Generic; }); +} + +static std::optional<Operation *> getCombinerOp(LinalgOp op) { + SmallVector<Operation *> combinerOps; + Value reducedValue = matchReduction(op.getRegionOutputArgs(), 0, combinerOps); + if (!reducedValue || combinerOps.size() != 1) { + return std::nullopt; + } + + return combinerOps[0]; +} + +static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) { + std::optional<Operation *> reductionOp = getCombinerOp(op); + if (!reductionOp) { + return ReductionKind::Generic; + } + [[maybe_unused]] Type resultElementType = + llvm::cast<RankedTensorType>(op->getResult(0).getType()).getElementType(); + // TODO: handle case when result type of the reduction op does not match the + // element type of the result tensor. + // Would it makes sense at all? + assert(resultElementType == reductionOp.value()->getResult(0).getType()); + return getReductionKind(reductionOp.value()); +} + +static MeshOp getMesh(Operation *op, + ArrayRef<MeshShardingAttr> operandShardings, + ArrayRef<MeshShardingAttr> resultShardings, + SymbolTableCollection &symbolTable) { + for (MeshShardingAttr sharding : operandShardings) { + if (sharding) { + return mesh::getMesh(op, sharding.getMesh(), symbolTable); + } + } + + for (MeshShardingAttr sharding : resultShardings) { + if (sharding) { + return mesh::getMesh(op, sharding.getMesh(), symbolTable); + } + } + + assert(false); + return nullptr; +} + +// Choose the operand based on the current process index along the reduction +// mesh axes. +// We need to use the initial value only once to avoid including it in the +// reduction multiple times. +// In each process group only the leading process with linear index 0 would use +// the original operand. +// The other processes would use the reduction operation neutral tensor. +static Value createDestinationPassingStyleInitOperand( + LinalgOp op, Value spmdizedOperand, ArrayRef<MeshAxis> reductionMeshAxes, + MeshOp meshOp, ImplicitLocOpBuilder &builder) { + Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex( + meshOp.getSymName(), reductionMeshAxes, builder); + Value zero = builder.create<arith::ConstantIndexOp>(0); + Value isLeadProcess = builder.create<arith::CmpIOp>( + builder.getI1Type(), arith::CmpIPredicate::eq, + processLinearIndexInReductionGroup, zero); + scf::IfOp ifOp = builder.create<scf::IfOp>(spmdizedOperand.getType(), + isLeadProcess, true, true); + // Then block. + { + OpBuilder::InsertionGuard insertionGuard(builder); + builder.setInsertionPointToEnd(&ifOp.getThenRegion().front()); + builder.create<scf::YieldOp>(spmdizedOperand); + } + + // Else block. + { + OpBuilder::InsertionGuard insertionGuard(builder); + builder.setInsertionPointToEnd(&ifOp.getElseRegion().front()); + SmallVector<OpFoldResult> shape = + tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand); + PartialReductionOpInterface partialReductionIface = + llvm::cast<PartialReductionOpInterface>(op.getOperation()); + FailureOr<Operation *> reductionNeutralTensorOp = + partialReductionIface.generateInitialTensorForPartialReduction( + builder, builder.getLoc(), shape, {}); + assert(succeeded(reductionNeutralTensorOp)); + builder.create<scf::YieldOp>( + reductionNeutralTensorOp.value()->getResult(0)); + } + return ifOp.getResult(0); +} + +// Create the DPS init operands for the spmdized Linalg op. +// Return all the new spmdized operands. +static SmallVector<Value> createDestinationPassingStyleInitOperands( + LinalgOp op, MeshOp meshOp, ArrayRef<Value> spmdizedOperands, + ArrayRef<MeshAxis> reductionMeshAxes, IRMapping &spmdizationMap, + ImplicitLocOpBuilder &builder) { + // TODO: add support for multiple destination passing style initial value + // operands. + // PartialReductionOpInterface::generateInitialTensorForPartialReduction + // needs to also support multiple DPS initial operands. + SmallVector<Value> newOperands = llvm::to_vector(spmdizedOperands); + auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber(); + Value spmdizedInitOperand = + spmdizationMap.lookup(op->getOperands()[operandIdx]); + newOperands[operandIdx] = createDestinationPassingStyleInitOperand( + op, spmdizedInitOperand, reductionMeshAxes, meshOp, builder); + return newOperands; +} + +static void createAllReduceForResultWithoutPartialSharding( + Value unshardedLinalgOpResult, ArrayRef<MeshAxis> opReductionMeshAxes, + MeshShardingAttr resultSharding, ReductionKind reductionKind, + IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder) { + SmallVector<MeshAxis> allReduceMeshAxes; + llvm::copy_if(opReductionMeshAxes, std::back_inserter(allReduceMeshAxes), + [&resultSharding](MeshAxis axis) { + return !llvm::is_contained(resultSharding.getPartialAxes(), + axis); + }); + if (allReduceMeshAxes.empty()) { + return; + } + + Value spmdizedLinalgOpResult = spmdizationMap.lookup(unshardedLinalgOpResult); + Value reducedValue = builder.create<mesh::AllReduceOp>( + spmdizedLinalgOpResult, resultSharding.getMesh().getValue(), + allReduceMeshAxes, reductionKind); + spmdizationMap.map(unshardedLinalgOpResult, reducedValue); +} + +static void createAllReduceForResultsWithoutPartialShardings( + LinalgOp unshardedOp, ArrayRef<MeshAxis> opReductionMeshAxes, + ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap, + ImplicitLocOpBuilder &builder) { + ReductionKind reductionKind = getReductionKindOfLinalgOp(unshardedOp); + for (auto [unshardedLinalgOpResult, resultSharding] : + llvm::zip_equal(unshardedOp->getResults(), resultShardings)) { + createAllReduceForResultWithoutPartialSharding( + unshardedLinalgOpResult, opReductionMeshAxes, resultSharding, + reductionKind, spmdizationMap, builder); + } +} + +static void spmdizeLinalgOpWithShardedReduction( + LinalgOp op, ArrayRef<Value> spmdizedOperands, + ArrayRef<MeshShardingAttr> operandShardings, + ArrayRef<MeshShardingAttr> resultShardings, + ArrayRef<utils::IteratorType> loopIteratorTypes, + ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators, + IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, + ImplicitLocOpBuilder &builder) { + MeshOp mesh = getMesh(op, operandShardings, resultShardings, symbolTable); + SmallVector<MeshAxis> reductionMeshAxes = mesh::getReductionMeshAxes( + loopIteratorTypes, meshAxisAssignmentForLoopIterators); + SmallVector<Value> spmdizedLinalgOpOperands = + createDestinationPassingStyleInitOperands(op, mesh, spmdizedOperands, + reductionMeshAxes, + spmdizationMap, builder); + // We must not change the operand mappings of the original spmdizationMap as + // they are the mappings for the whole spmdization blob and may be used by + // others. + IRMapping internalSpmdizationMap; + for (auto [unshardedOperand, spmdizedOperand] : + llvm::zip_equal(op->getOperands(), spmdizedLinalgOpOperands)) { + internalSpmdizationMap.map(unshardedOperand, spmdizedOperand); + } + spmdizeTriviallyShardableOperation( + *op, spmdizedLinalgOpOperands, operandShardings, resultShardings, + internalSpmdizationMap, symbolTable, builder); + for (Value result : op->getResults()) { + spmdizationMap.map(result, internalSpmdizationMap.lookup(result)); + } + + // Handle partial shardings. + createAllReduceForResultsWithoutPartialShardings( + op, reductionMeshAxes, resultShardings, spmdizationMap, builder); +} + +namespace { + +// ShardingInterface for ops that implement LinalgStructuredInterface. +// The supported ops are only those where the indexing maps are projected +// permutations. +template <typename Op> +struct StructuredOpShardingInterface + : public mesh::ShardingInterface::ExternalModel< + StructuredOpShardingInterface<Op>, Op> { + SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const { + return llvm::cast<LinalgOp>(op).getIteratorTypesArray(); + } + + SmallVector<AffineMap> getIndexingMaps(Operation *op) const { + LinalgOp linalgOp = llvm::cast<LinalgOp>(op); + SmallVector<AffineMap> res = linalgOp.getIndexingMapsArray(); + + // Results must have the same indexing as destination passing style initial + // operands. + for (int64_t i = 0; i < linalgOp.getNumDpsInits(); ++i) { + res.push_back(res[linalgOp.getDpsInitOperand(i)->getOperandNumber()]); + } + + return res; + } + + LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands, + ArrayRef<MeshShardingAttr> operandShardings, + ArrayRef<MeshShardingAttr> resultShardings, + IRMapping &spmdizationMap, + SymbolTableCollection &symbolTable, + OpBuilder &builder) const { + LinalgOp linalgOp = llvm::cast<LinalgOp>(op); + + SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray(); + bool allIndexingMapsAreProjectedPermutation = + llvm::all_of(indexingMaps, [](AffineMap map) { + return map.isProjectedPermutation(); + }); + if (!allIndexingMapsAreProjectedPermutation) { + // TODO: handle non-projected permutations. + return op->emitOpError() + << "supports indexing maps that are only projected permutation."; + } + + SmallVector<utils::IteratorType> loopIteratorTypes = + linalgOp.getIteratorTypesArray(); + ShardingArray meshAxisAssignmentForLoopIterators = + getMeshAxisAssignmentForLoopIterators(operandShardings, resultShardings, + loopIteratorTypes, indexingMaps); + if (mesh::isAtLeastOneReductionIteratorSharded( + loopIteratorTypes, meshAxisAssignmentForLoopIterators)) { + ImplicitLocOpBuilder implicitLocBuilder(op->getLoc(), builder); + spmdizeLinalgOpWithShardedReduction( + linalgOp, spmdizedOperands, operandShardings, resultShardings, + loopIteratorTypes, meshAxisAssignmentForLoopIterators, spmdizationMap, + symbolTable, implicitLocBuilder); + } else { + spmdizeTriviallyShardableOperation(*op, spmdizedOperands, + operandShardings, resultShardings, + spmdizationMap, symbolTable, builder); + } + + return success(); + } +}; + +} // namespace + +template <typename OpType> +static void registerOne(MLIRContext *ctx) { + OpType::template attachInterface<StructuredOpShardingInterface<OpType>>(*ctx); +} + +/// Variadic helper function. +template <typename... OpTypes> +static void registerAll(MLIRContext *ctx) { + (registerOne<OpTypes>(ctx), ...); +} + +void registerMeshShardingInterfaceExternalModels(DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *dialect) { + DialectRegistry registry; + registry.insert<affine::AffineDialect, arith::ArithDialect, scf::SCFDialect, + tensor::TensorDialect>(); + ctx->appendDialectRegistry(registry); + for (StringRef name : registry.getDialectNames()) + ctx->getOrLoadDialect(name); + + registerOne<linalg::GenericOp>(ctx); + registerAll< +#define GET_OP_LIST +#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" + >(ctx); + }); +} + +} // namespace mlir::linalg |