//===- ShardingInterfaceImpl.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/ShardingInterfaceImpl.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Shard/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"

#include <numeric>
#include <optional>

namespace mlir::linalg {

using GridAxis = shard::GridAxis;
using ReductionKind = shard::ReductionKind;
using Sharding = shard::Sharding;
using ShardingArray = shard::ShardingArray;
using GridOp = shard::GridOp;

// Returns the corresponding grid reduction kind for the given arith op.
static ReductionKind getReductionKind(Operation *op) {
  return llvm::TypeSwitch<Operation *, ReductionKind>(op)
      // Floating-point operations.
      .Case([](arith::AddFOp op) { return ReductionKind::Sum; })
      .Case([](arith::MulFOp op) { return ReductionKind::Product; })
      // TODO: handle maxnumf and minnumf.
      .Case([](arith::MaximumFOp op) { return ReductionKind::Max; })
      .Case([](arith::MinimumFOp op) { return ReductionKind::Min; })
      // Integer operations.
      .Case([](arith::AddIOp op) { return ReductionKind::Sum; })
      .Case([](arith::OrIOp op) { return ReductionKind::BitwiseOr; })
      .Case([](arith::XOrIOp op) { return ReductionKind::BitwiseXor; })
      .Case([](arith::AndIOp op) { return ReductionKind::BitwiseAnd; })
      // TODO: handle signless, signed and unsigned types properly.
      // It is assumed that the element type of the collective operands and
      // result drives the meaning of the reduction kind, whether it is
      // signed or unsigned.
      // The reduction op inside the linalg op may have a different result
      // type than the element type of the linalg op's result.
      // Also, signed and unsigned Arith dialect ops may accept signed,
      // unsigned or signless operands.
      // Maybe expand the reduction kinds.
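      // Illustrative example (the combiner shape below is an assumption for
      // exposition, not taken from a test): this switch is consulted on the
      // combiner op of a matched linalg reduction, so a body that combines
      // with arith.addf maps to ReductionKind::Sum:
      //
      //   ^bb0(%in: f32, %acc: f32):
      //     %sum = arith.addf %acc, %in : f32
      //     linalg.yield %sum : f32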
      .Case([](arith::MaxUIOp op) { return ReductionKind::Max; })
      .Case([](arith::MinUIOp op) { return ReductionKind::Min; })
      .Case([](arith::MaxSIOp op) { return ReductionKind::Max; })
      .Case([](arith::MinSIOp op) { return ReductionKind::Min; })
      .Case([](arith::MulIOp op) { return ReductionKind::Product; })
      .Default([](Operation *op) { return ReductionKind::Generic; });
}

static std::optional<Operation *> getCombinerOp(LinalgOp op) {
  SmallVector<Operation *> combinerOps;
  Value reducedValue =
      matchReduction(op.getRegionOutputArgs(), 0, combinerOps);
  if (!reducedValue || combinerOps.size() != 1) {
    return std::nullopt;
  }

  return combinerOps[0];
}

static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) {
  std::optional<Operation *> reductionOp = getCombinerOp(op);
  if (!reductionOp) {
    return ReductionKind::Generic;
  }

  [[maybe_unused]] Type resultElementType =
      llvm::cast<RankedTensorType>(op->getResult(0).getType())
          .getElementType();
  // TODO: handle the case when the result type of the reduction op does not
  // match the element type of the result tensor.
  // Would it make sense at all?
  assert(resultElementType == reductionOp.value()->getResult(0).getType());
  return getReductionKind(reductionOp.value());
}

static GridOp getGrid(Operation *op, ArrayRef<Sharding> operandShardings,
                      ArrayRef<Sharding> resultShardings,
                      SymbolTableCollection &symbolTable) {
  for (const Sharding &sharding : operandShardings) {
    if (sharding) {
      return shard::getGrid(op, sharding.getGridAttr(), symbolTable);
    }
  }

  for (const Sharding &sharding : resultShardings) {
    if (sharding) {
      return shard::getGrid(op, sharding.getGridAttr(), symbolTable);
    }
  }

  assert(false);
  return nullptr;
}

// Choose the operand based on the current process index along the reduction
// grid axes.
// We need to use the initial value only once to avoid including it in the
// reduction multiple times.
// In each process group, only the leading process with linear index 0 uses
// the original operand. The other processes use the neutral tensor of the
// reduction operation.
static Value createDestinationPassingStyleInitOperand(
    LinalgOp op, int operandNumber, Value partitionedOperand,
    ArrayRef<GridAxis> reductionGridAxes, GridOp gridOp,
    ImplicitLocOpBuilder &builder) {
  Value processLinearIndexInReductionGroup = shard::createProcessLinearIndex(
      gridOp.getSymName(), reductionGridAxes, builder);
  Value zero = arith::ConstantIndexOp::create(builder, 0);
  Value isLeadProcess = arith::CmpIOp::create(
      builder, builder.getI1Type(), arith::CmpIPredicate::eq,
      processLinearIndexInReductionGroup, zero);
  scf::IfOp ifOp = scf::IfOp::create(builder, partitionedOperand.getType(),
                                     isLeadProcess, true, true);
  // Then block.
  {
    OpBuilder::InsertionGuard insertionGuard(builder);
    builder.setInsertionPointToEnd(&ifOp.getThenRegion().front());
    scf::YieldOp::create(builder, partitionedOperand);
  }
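  // Illustrative shape of the IR created here (assuming an f32 add
  // reduction, whose neutral element is 0.0):
  //
  //   %init = scf.if %isLeadProcess -> (tensor<?x?xf32>) {
  //     scf.yield %partitionedOperand : tensor<?x?xf32>
  //   } else {
  //     %empty = tensor.empty(...) : tensor<?x?xf32>
  //     %cst = arith.constant 0.0 : f32
  //     %neutral = linalg.fill ins(%cst : f32)
  //                  outs(%empty : tensor<?x?xf32>) -> tensor<?x?xf32>
  //     scf.yield %neutral : tensor<?x?xf32>
  //   }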
  // Else block.
  {
    OpBuilder::InsertionGuard insertionGuard(builder);
    builder.setInsertionPointToEnd(&ifOp.getElseRegion().front());
    SmallVector<OpFoldResult> shape =
        tensor::getMixedSizes(builder, builder.getLoc(), partitionedOperand);

    SmallVector<Operation *> combinerOps;
    matchReduction(op.getRegionOutputArgs(), operandNumber, combinerOps);
    assert(combinerOps.size() == 1);
    std::optional<TypedAttr> neutralEl =
        arith::getNeutralElement(combinerOps[0]);

    Value init = tensor::EmptyOp::create(builder, op.getLoc(), shape,
                                         neutralEl.value().getType());
    Value constant =
        arith::ConstantOp::create(builder, op.getLoc(), neutralEl.value());
    Value fill = linalg::FillOp::create(builder, op.getLoc(), constant, init)
                     .getResult(0);

    scf::YieldOp::create(builder, fill);
  }
  return ifOp.getResult(0);
}

// Create the DPS init operands for the partitioned Linalg op.
// Return all the new partitioned operands.
static SmallVector<Value> createDestinationPassingStyleInitOperands(
    LinalgOp op, GridOp gridOp, ArrayRef<Value> partitionedOperands,
    ArrayRef<GridAxis> reductionGridAxes, IRMapping &partitionMap,
    ImplicitLocOpBuilder &builder) {
  // TODO: add support for multiple destination passing style initial value
  // operands.
  assert(op.getNumDpsInits() == 1 && "Multiple initial values not supported.");
  SmallVector<Value> newOperands = llvm::to_vector(partitionedOperands);
  auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber();
  Value partitionedInitOperand =
      partitionMap.lookup(op->getOperands()[operandIdx]);
  newOperands[operandIdx] = createDestinationPassingStyleInitOperand(
      op, 0, partitionedInitOperand, reductionGridAxes, gridOp, builder);
  return newOperands;
}

static void createAllReduceForResultsWithoutPartialShardings(
    LinalgOp unshardedOp, ArrayRef<GridAxis> opReductionGridAxes,
    ArrayRef<Sharding> resultShardings, IRMapping &partitionMap,
    ImplicitLocOpBuilder &builder) {
  ReductionKind reductionKind = getReductionKindOfLinalgOp(unshardedOp);
  for (auto [unshardedLinalgOpResult, resultSharding] :
       llvm::zip_equal(unshardedOp->getResults(), resultShardings)) {
    Value partitionedLinalgOpResult =
        partitionMap.lookup(unshardedLinalgOpResult);
    Value reducedValue = shard::AllReduceOp::create(
        builder, partitionedLinalgOpResult, resultSharding.getGrid(),
        opReductionGridAxes, reductionKind);
    partitionMap.map(unshardedLinalgOpResult, reducedValue);
  }
}

static void partitionLinalgOpWithShardedReduction(
    LinalgOp op, ArrayRef<Value> partitionedOperands,
    ArrayRef<Sharding> operandShardings, ArrayRef<Sharding> resultShardings,
    ArrayRef<utils::IteratorType> loopIteratorTypes,
    ArrayRef<SmallVector<GridAxis>> gridAxisAssignmentForLoopIterators,
    IRMapping &partitionMap, SymbolTableCollection &symbolTable,
    ImplicitLocOpBuilder &builder) {
  GridOp grid = getGrid(op, operandShardings, resultShardings, symbolTable);
  SmallVector<GridAxis> reductionGridAxes = shard::getReductionGridAxes(
      loopIteratorTypes, gridAxisAssignmentForLoopIterators);
  SmallVector<Value> partitionedLinalgOpOperands =
      createDestinationPassingStyleInitOperands(op, grid, partitionedOperands,
                                                reductionGridAxes,
                                                partitionMap, builder);
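  // For example (illustrative): for a linalg.matmul whose reduction loop is
  // sharded over one grid axis, each process runs the matmul on its local
  // operand shards, and the partial results are then combined below with a
  // shard.all_reduce (ReductionKind::Sum) over that axis.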
  // We must not change the operand mappings of the original partitionMap, as
  // they are the mappings for the whole partition blob and may be used by
  // others.
  IRMapping internalPartitionMap;
  for (auto [unshardedOperand, partitionedOperand] :
       llvm::zip_equal(op->getOperands(), partitionedLinalgOpOperands)) {
    internalPartitionMap.map(unshardedOperand, partitionedOperand);
  }
  partitionTriviallyShardableOperation(*op, partitionedLinalgOpOperands,
                                       operandShardings, resultShardings,
                                       internalPartitionMap, symbolTable,
                                       builder);
  for (Value result : op->getResults()) {
    partitionMap.map(result, internalPartitionMap.lookup(result));
  }

  // Handle partial shardings.
  createAllReduceForResultsWithoutPartialShardings(
      op, reductionGridAxes, resultShardings, partitionMap, builder);
}

namespace {

// ShardingInterface for ops that implement LinalgStructuredInterface.
// Only ops whose indexing maps are projected permutations are supported.
template <typename Op>
struct StructuredOpShardingInterface
    : public shard::ShardingInterface::ExternalModel<
          StructuredOpShardingInterface<Op>, Op> {
  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    return llvm::cast<LinalgOp>(op).getIteratorTypesArray();
  }

  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
    LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
    SmallVector<AffineMap> res = linalgOp.getIndexingMapsArray();

    // Results must have the same indexing as destination passing style
    // initial operands.
    for (int64_t i = 0; i < linalgOp.getNumDpsInits(); ++i) {
      res.push_back(res[linalgOp.getDpsInitOperand(i)->getOperandNumber()]);
    }

    return res;
  }

  SmallVector<ReductionKind>
  getReductionLoopIteratorKinds(Operation *op) const {
    LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes =
        linalgOp.getIteratorTypesArray();
    unsigned reductionItersCount = std::accumulate(
        iteratorTypes.begin(), iteratorTypes.end(), 0,
        [](unsigned count, utils::IteratorType iter) {
          return count + (iter == utils::IteratorType::reduction);
        });
    shard::ReductionKind reductionKind = getReductionKindOfLinalgOp(linalgOp);
    return SmallVector<ReductionKind>(reductionItersCount, reductionKind);
  }

  LogicalResult partition(Operation *op, ArrayRef<Value> partitionedOperands,
                          ArrayRef<Sharding> operandShardings,
                          ArrayRef<Sharding> resultShardings,
                          IRMapping &partitionMap,
                          SymbolTableCollection &symbolTable,
                          OpBuilder &builder) const {
    LinalgOp linalgOp = llvm::cast<LinalgOp>(op);

    SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
    bool allIndexingMapsAreProjectedPermutation =
        llvm::all_of(indexingMaps, [](AffineMap map) {
          return map.isProjectedPermutation();
        });
    if (!allIndexingMapsAreProjectedPermutation) {
      // TODO: handle non-projected permutations.
      return op->emitOpError() << "only supports indexing maps that are "
                                  "projected permutations.";
    }

    SmallVector<utils::IteratorType> loopIteratorTypes =
        linalgOp.getIteratorTypesArray();
    ShardingArray gridAxisAssignmentForLoopIterators =
        getGridAxisAssignmentForLoopIterators(
            operandShardings, resultShardings, loopIteratorTypes,
            indexingMaps);
    if (shard::isAtLeastOneReductionIteratorSharded(
            loopIteratorTypes, gridAxisAssignmentForLoopIterators)) {
      ImplicitLocOpBuilder implicitLocBuilder(op->getLoc(), builder);
      partitionLinalgOpWithShardedReduction(
          linalgOp, partitionedOperands, operandShardings, resultShardings,
          loopIteratorTypes, gridAxisAssignmentForLoopIterators, partitionMap,
          symbolTable, implicitLocBuilder);
    } else {
      partitionTriviallyShardableOperation(*op, partitionedOperands,
                                           operandShardings, resultShardings,
                                           partitionMap, symbolTable, builder);
    }

    return success();
  }
};

} // namespace

/// Attach the interface model to a single op type.
template <typename OpType>
static void registerOne(MLIRContext *ctx) {
  OpType::template attachInterface<StructuredOpShardingInterface<OpType>>(
      *ctx);
}
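// For example, registerOne<linalg::MatmulOp>(ctx) attaches
// StructuredOpShardingInterface<linalg::MatmulOp> to linalg.matmul.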
/// Variadic helper function.
template <typename... OpTypes>
static void registerAll(MLIRContext *ctx) {
  (registerOne<OpTypes>(ctx), ...);
}

void registerShardingInterfaceExternalModels(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *dialect) {
    DialectRegistry registry;
    registry.insert<affine::AffineDialect, arith::ArithDialect,
                    scf::SCFDialect, tensor::TensorDialect>();
    ctx->appendDialectRegistry(registry);
    for (StringRef name : registry.getDialectNames())
      ctx->getOrLoadDialect(name);

    registerOne<linalg::GenericOp>(ctx);
    registerAll<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
        >(ctx);
  });
}

} // namespace mlir::linalg
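// Example usage (an illustrative sketch; the surrounding setup code is
// assumed, not part of this file):
//
//   DialectRegistry registry;
//   registry.insert<linalg::LinalgDialect>();
//   linalg::registerShardingInterfaceExternalModels(registry);
//   MLIRContext context(registry);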