diff options
author | Florian Mayer <fmayer@google.com> | 2024-02-23 11:31:14 -0800 |
---|---|---|
committer | Florian Mayer <fmayer@google.com> | 2024-02-23 11:31:14 -0800 |
commit | 886b4bc97b0ed5a5e041a0117a584182fc7989c1 (patch) | |
tree | 43cdc0e15e12c298c09251dda38e834e7e778049 /mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp | |
parent | af8afe08ee20a04b2ccb363cac66aa02cfaecd02 (diff) | |
parent | 8d536f83545f071948888983e2db25ce23a8302d (diff) | |
download | llvm-886b4bc97b0ed5a5e041a0117a584182fc7989c1.zip llvm-886b4bc97b0ed5a5e041a0117a584182fc7989c1.tar.gz llvm-886b4bc97b0ed5a5e041a0117a584182fc7989c1.tar.bz2 |
Created using spr 1.3.4
Diffstat (limited to 'mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp')
-rw-r--r-- | mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp | 45 |
1 files changed, 28 insertions, 17 deletions
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp index 7cbe0de..c4d8b0b 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp @@ -593,7 +593,6 @@ static SmallVector<MeshShardingAttr> getOperandShardings(Operation &op) { Operation *definingOp = operand.getDefiningOp(); assert(definingOp); ShardOp shardOp = llvm::cast<ShardOp>(definingOp); - assert(shardOp.getAnnotateForUsers()); return shardOp.getShard(); }); return res; @@ -615,34 +614,46 @@ static SmallVector<MeshShardingAttr> getResultShardings(Operation &op) { assert(result.hasOneUse()); Operation *userOp = *result.getUsers().begin(); ShardOp shardOp = llvm::cast<ShardOp>(userOp); - assert(!shardOp.getAnnotateForUsers()); return shardOp.getShard(); }); return res; } static LogicalResult -spmdizeOperation(Operation &op, IRMapping &spmdizationMap, +spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) { - ShardOp shardOp = llvm::dyn_cast<ShardOp>(op); - if (shardOp) { - if (!shardOp.getAnnotateForUsers()) { - return success(); - } - + Value targetSpmdValue; + + // Check if 2 shard ops are chained. If not there is no need for resharding + // as the source and target shared the same sharding. + ShardOp srcShardOp = + dyn_cast_or_null<ShardOp>(shardOp.getOperand().getDefiningOp()); + if (!srcShardOp) { + targetSpmdValue = spmdizationMap.lookup(shardOp.getOperand()); + } else { // Insert resharding. - ShardOp srcShardOp = - llvm::cast<ShardOp>(shardOp.getOperand().getDefiningOp()); - assert(!srcShardOp.getAnnotateForUsers()); + assert(!srcShardOp.getAnnotateForUsers() && shardOp.getAnnotateForUsers()); TypedValue<ShapedType> srcSpmdValue = spmdizationMap.lookup(srcShardOp.getOperand()) .cast<TypedValue<ShapedType>>(); - Value targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue, - symbolTableCollection); - assert(!spmdizationMap.contains(shardOp.getResult())); - spmdizationMap.map(shardOp.getResult(), targetSpmdValue); - return success(); + targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue, + symbolTableCollection); + } + + assert(!spmdizationMap.contains(shardOp.getResult())); + spmdizationMap.map(shardOp.getResult(), targetSpmdValue); + return success(); +} + +static LogicalResult +spmdizeOperation(Operation &op, IRMapping &spmdizationMap, + SymbolTableCollection &symbolTableCollection, + OpBuilder &builder) { + ShardOp shardOp = llvm::dyn_cast<ShardOp>(op); + if (shardOp) { + return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection, + builder); } SmallVector<Value> spmdizedOperands; |