aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
diff options
context:
space:
mode:
authorFlorian Mayer <fmayer@google.com>2024-02-23 11:31:14 -0800
committerFlorian Mayer <fmayer@google.com>2024-02-23 11:31:14 -0800
commit886b4bc97b0ed5a5e041a0117a584182fc7989c1 (patch)
tree43cdc0e15e12c298c09251dda38e834e7e778049 /mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
parentaf8afe08ee20a04b2ccb363cac66aa02cfaecd02 (diff)
parent8d536f83545f071948888983e2db25ce23a8302d (diff)
downloadllvm-886b4bc97b0ed5a5e041a0117a584182fc7989c1.zip
llvm-886b4bc97b0ed5a5e041a0117a584182fc7989c1.tar.gz
llvm-886b4bc97b0ed5a5e041a0117a584182fc7989c1.tar.bz2
Created using spr 1.3.4
Diffstat (limited to 'mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp')
-rw-r--r--mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp45
1 files changed, 28 insertions, 17 deletions
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 7cbe0de..c4d8b0b 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -593,7 +593,6 @@ static SmallVector<MeshShardingAttr> getOperandShardings(Operation &op) {
Operation *definingOp = operand.getDefiningOp();
assert(definingOp);
ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
- assert(shardOp.getAnnotateForUsers());
return shardOp.getShard();
});
return res;
@@ -615,34 +614,46 @@ static SmallVector<MeshShardingAttr> getResultShardings(Operation &op) {
assert(result.hasOneUse());
Operation *userOp = *result.getUsers().begin();
ShardOp shardOp = llvm::cast<ShardOp>(userOp);
- assert(!shardOp.getAnnotateForUsers());
return shardOp.getShard();
});
return res;
}
static LogicalResult
-spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
+spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
SymbolTableCollection &symbolTableCollection,
OpBuilder &builder) {
- ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
- if (shardOp) {
- if (!shardOp.getAnnotateForUsers()) {
- return success();
- }
-
+ Value targetSpmdValue;
+
+ // Check if 2 shard ops are chained. If not there is no need for resharding
+ // as the source and target shared the same sharding.
+ ShardOp srcShardOp =
+ dyn_cast_or_null<ShardOp>(shardOp.getOperand().getDefiningOp());
+ if (!srcShardOp) {
+ targetSpmdValue = spmdizationMap.lookup(shardOp.getOperand());
+ } else {
// Insert resharding.
- ShardOp srcShardOp =
- llvm::cast<ShardOp>(shardOp.getOperand().getDefiningOp());
- assert(!srcShardOp.getAnnotateForUsers());
+ assert(!srcShardOp.getAnnotateForUsers() && shardOp.getAnnotateForUsers());
TypedValue<ShapedType> srcSpmdValue =
spmdizationMap.lookup(srcShardOp.getOperand())
.cast<TypedValue<ShapedType>>();
- Value targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
- symbolTableCollection);
- assert(!spmdizationMap.contains(shardOp.getResult()));
- spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
- return success();
+ targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
+ symbolTableCollection);
+ }
+
+ assert(!spmdizationMap.contains(shardOp.getResult()));
+ spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
+ return success();
+}
+
+static LogicalResult
+spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
+ SymbolTableCollection &symbolTableCollection,
+ OpBuilder &builder) {
+ ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
+ if (shardOp) {
+ return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection,
+ builder);
}
SmallVector<Value> spmdizedOperands;