From 4f7ab789bf43b49914815bdf4e4c3703f92e781d Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Thu, 22 Feb 2024 11:06:14 -0800 Subject: [mlir][mesh] add support in spmdization for incomplete sharding annotations (#82442) Don't require that `mesh.shard` operations come in pairs. If there is only a single `mesh.shard` operation we assume that the producer result and consumer operand have the same sharding. --- mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp | 45 +++++++++++++++--------- 1 file changed, 28 insertions(+), 17 deletions(-) (limited to 'mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp') diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp index 7cbe0de..c4d8b0b 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp @@ -593,7 +593,6 @@ static SmallVector getOperandShardings(Operation &op) { Operation *definingOp = operand.getDefiningOp(); assert(definingOp); ShardOp shardOp = llvm::cast(definingOp); - assert(shardOp.getAnnotateForUsers()); return shardOp.getShard(); }); return res; @@ -615,34 +614,46 @@ static SmallVector getResultShardings(Operation &op) { assert(result.hasOneUse()); Operation *userOp = *result.getUsers().begin(); ShardOp shardOp = llvm::cast(userOp); - assert(!shardOp.getAnnotateForUsers()); return shardOp.getShard(); }); return res; } static LogicalResult -spmdizeOperation(Operation &op, IRMapping &spmdizationMap, +spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) { - ShardOp shardOp = llvm::dyn_cast(op); - if (shardOp) { - if (!shardOp.getAnnotateForUsers()) { - return success(); - } - + Value targetSpmdValue; + + // Check if 2 shard ops are chained. If not there is no need for resharding + // as the source and target shared the same sharding. + ShardOp srcShardOp = + dyn_cast_or_null(shardOp.getOperand().getDefiningOp()); + if (!srcShardOp) { + targetSpmdValue = spmdizationMap.lookup(shardOp.getOperand()); + } else { // Insert resharding. - ShardOp srcShardOp = - llvm::cast(shardOp.getOperand().getDefiningOp()); - assert(!srcShardOp.getAnnotateForUsers()); + assert(!srcShardOp.getAnnotateForUsers() && shardOp.getAnnotateForUsers()); TypedValue srcSpmdValue = spmdizationMap.lookup(srcShardOp.getOperand()) .cast>(); - Value targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue, - symbolTableCollection); - assert(!spmdizationMap.contains(shardOp.getResult())); - spmdizationMap.map(shardOp.getResult(), targetSpmdValue); - return success(); + targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue, + symbolTableCollection); + } + + assert(!spmdizationMap.contains(shardOp.getResult())); + spmdizationMap.map(shardOp.getResult(), targetSpmdValue); + return success(); +} + +static LogicalResult +spmdizeOperation(Operation &op, IRMapping &spmdizationMap, + SymbolTableCollection &symbolTableCollection, + OpBuilder &builder) { + ShardOp shardOp = llvm::dyn_cast(op); + if (shardOp) { + return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection, + builder); } SmallVector spmdizedOperands; -- cgit v1.1