//===- ShardingInterfaceImpl.cpp ------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h" #include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h" #include "mlir/IR/DialectRegistry.h" using namespace mlir; using namespace mlir::arith; using namespace mlir::shard; namespace { // Sharding of arith.constant // RankedTensor constants can be sharded like any other tensor. // %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> // %sharding = shard.sharding @grid4x4 split_axes = [[0]] : !shard.sharding // Scalar constants are always replicated and need no sharding annotation. struct ConstantShardingInterface : public ShardingInterface::ExternalModel { SmallVector getLoopIteratorTypes(Operation *op) const { auto ndims = 0; if (auto type = dyn_cast(op->getResult(0).getType())) { ndims = type.getRank(); } return SmallVector(ndims, utils::IteratorType::parallel); } SmallVector getIndexingMaps(Operation *op) const { if (auto type = dyn_cast(op->getResult(0).getType())) { return SmallVector(1, {AffineMap::getMultiDimIdentityMap( type.getRank(), op->getContext())}); } return {}; } // Indicate failure if no result sharding exists. // Otherwise mirror result sharding if it is a tensor constant. // Otherwise return replication option. FailureOr getShardingOption(Operation *op, ArrayRef operandShardings, ArrayRef resultShardings) const { assert(resultShardings.size() == 1 && "Expecting exactly one result sharding for arith.constant"); auto resultSharding = resultShardings[0]; if (!resultSharding) { return failure(); } if (auto type = dyn_cast(op->getResult(0).getType())) { ShardingArray axesArray(resultSharding.getSplitAxes().size()); for (auto [i, axes] : llvm::enumerate(resultSharding.getSplitAxes())) { axesArray[i].append(axes.asArrayRef().begin(), axes.asArrayRef().end()); } return ShardingOption(axesArray, resultSharding.getGridAttr()); } return ShardingOption({}, resultSharding.getGridAttr()); } LogicalResult partition(Operation *op, ArrayRef partitiondOperands, ArrayRef operandShardings, ArrayRef resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder) const { auto cOp = cast(op); if (auto value = dyn_cast(cOp.getValue())) { if (!value.isSplat() || !resultShardings[0]) { // Currently non-splat constants are not supported. return failure(); } auto sharding = resultShardings[0]; auto newType = cast(shardType( cOp.getType(), getGrid(op, sharding.getGridAttr(), symbolTable), sharding)); auto newValue = value.resizeSplat(newType); auto newOp = ConstantOp::create(builder, op->getLoc(), newType, newValue); partitionMap.map(op->getResult(0), newOp.getResult()); partitionMap.map(op, newOp.getOperation()); } else { // `clone` will populate the mapping of old to new results. (void)builder.clone(*op, partitionMap); } return success(); } }; } // namespace void mlir::arith::registerShardingInterfaceExternalModels( DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, ArithDialect *dialect) { ConstantOp::template attachInterface(*ctx); }); }