diff options
Diffstat (limited to 'mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp')
-rw-r--r-- | mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp | 45 |
1 files changed, 28 insertions, 17 deletions
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp index 7cbe0de..c4d8b0b 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp @@ -593,7 +593,6 @@ static SmallVector<MeshShardingAttr> getOperandShardings(Operation &op) { Operation *definingOp = operand.getDefiningOp(); assert(definingOp); ShardOp shardOp = llvm::cast<ShardOp>(definingOp); - assert(shardOp.getAnnotateForUsers()); return shardOp.getShard(); }); return res; @@ -615,34 +614,46 @@ static SmallVector<MeshShardingAttr> getResultShardings(Operation &op) { assert(result.hasOneUse()); Operation *userOp = *result.getUsers().begin(); ShardOp shardOp = llvm::cast<ShardOp>(userOp); - assert(!shardOp.getAnnotateForUsers()); return shardOp.getShard(); }); return res; } static LogicalResult -spmdizeOperation(Operation &op, IRMapping &spmdizationMap, +spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) { - ShardOp shardOp = llvm::dyn_cast<ShardOp>(op); - if (shardOp) { - if (!shardOp.getAnnotateForUsers()) { - return success(); - } - + Value targetSpmdValue; + + // Check if 2 shard ops are chained. If not there is no need for resharding + // as the source and target shared the same sharding. + ShardOp srcShardOp = + dyn_cast_or_null<ShardOp>(shardOp.getOperand().getDefiningOp()); + if (!srcShardOp) { + targetSpmdValue = spmdizationMap.lookup(shardOp.getOperand()); + } else { // Insert resharding. - ShardOp srcShardOp = - llvm::cast<ShardOp>(shardOp.getOperand().getDefiningOp()); - assert(!srcShardOp.getAnnotateForUsers()); + assert(!srcShardOp.getAnnotateForUsers() && shardOp.getAnnotateForUsers()); TypedValue<ShapedType> srcSpmdValue = spmdizationMap.lookup(srcShardOp.getOperand()) .cast<TypedValue<ShapedType>>(); - Value targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue, - symbolTableCollection); - assert(!spmdizationMap.contains(shardOp.getResult())); - spmdizationMap.map(shardOp.getResult(), targetSpmdValue); - return success(); + targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue, + symbolTableCollection); + } + + assert(!spmdizationMap.contains(shardOp.getResult())); + spmdizationMap.map(shardOp.getResult(), targetSpmdValue); + return success(); +} + +static LogicalResult +spmdizeOperation(Operation &op, IRMapping &spmdizationMap, + SymbolTableCollection &symbolTableCollection, + OpBuilder &builder) { + ShardOp shardOp = llvm::dyn_cast<ShardOp>(op); + if (shardOp) { + return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection, + builder); } SmallVector<Value> spmdizedOperands; |