Diffstat (limited to 'mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp')
-rw-r--r--  mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp  89
1 file changed, 86 insertions, 3 deletions
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index fe3d7c4..9acee5a 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -539,8 +539,9 @@ static bool areValuesCompatibleWithFullReplicationShardings(
if (std::size(values) != std::size(shardings)) {
return false;
}
- return llvm::all_of(llvm::zip(std::forward<ValueRange>(values),
- std::forward<MeshShardingAttrRage>(shardings)),
+ return llvm::all_of(llvm::zip_equal(
+ std::forward<ValueRange>(values),
+ std::forward<MeshShardingAttrRage>(shardings)),
[](auto valueAndSharding) {
return isValueCompatibleWithFullReplicationSharding(
std::get<0>(valueAndSharding),
@@ -563,6 +564,88 @@ void mesh::spmdizeFullyReplicatedOperation(
builder.clone(op, spmdizationMap);
}
+static void updateMeshAxisAssignmentForLoopIterators(
+ ArrayRef<MeshAxis> meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr,
+ SmallVector<std::optional<SmallVector<MeshAxis>>>
+ &meshAxesAssignmentForLoopIterators) {
+ AffineDimExpr affineDimExpr = cast<AffineDimExpr>(indexingExpr);
+ unsigned loopIteratorIdx = affineDimExpr.getPosition();
+ if (meshAxesAssignmentForLoopIterators[loopIteratorIdx]) {
+ assert(llvm::equal(meshAxesAssignmentForTensorAxis,
+ *meshAxesAssignmentForLoopIterators[loopIteratorIdx]));
+ } else {
+ meshAxesAssignmentForLoopIterators[loopIteratorIdx] =
+ llvm::to_vector(meshAxesAssignmentForTensorAxis);
+ }
+}
+
+ShardingArray mesh::getMeshAxisAssignmentForLoopIterators(
+ ArrayRef<MeshShardingAttr> operandShardings,
+ ArrayRef<MeshShardingAttr> resultShardings,
+ ArrayRef<utils::IteratorType> loopIteratorTypes,
+ ArrayRef<AffineMap> indexingMaps) {
+ SmallVector<std::optional<SmallVector<MeshAxis>>>
+ meshAxisAssignmentForLoopIterators(loopIteratorTypes.size());
+ SmallVector<MeshShardingAttr> operatorAndResultShardings;
+ operatorAndResultShardings.reserve(operandShardings.size() +
+ resultShardings.size());
+ llvm::append_range(operatorAndResultShardings, operandShardings);
+ for (auto [sharding, affineMap] :
+ llvm::zip_equal(operatorAndResultShardings, indexingMaps)) {
+ if (!sharding) {
+ continue;
+ }
+ for (auto [meshAxesAssignmentForTensorAxis, indexingExpr] :
+ llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) {
+ updateMeshAxisAssignmentForLoopIterators(
+ meshAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr,
+ meshAxisAssignmentForLoopIterators);
+ }
+ // Missing trailing split axes means replication on those tensor dimensions.
+ for (unsigned i = sharding.getSplitAxes().size();
+ i < affineMap.getNumResults(); ++i) {
+ updateMeshAxisAssignmentForLoopIterators(
+ {}, affineMap.getResults()[i], meshAxisAssignmentForLoopIterators);
+ }
+ }
+
+ ShardingArray res;
+ llvm::transform(meshAxisAssignmentForLoopIterators, std::back_inserter(res),
+ [](std::optional<SmallVector<MeshAxis>> &axes) {
+ if (!axes) {
+ return SmallVector<MeshAxis>();
+ };
+ return std::move(*axes);
+ });
+ return res;
+}
+
+bool mesh::isAtLeastOneReductionIteratorSharded(
+ ArrayRef<utils::IteratorType> loopIteratorTypes,
+ ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
+ for (auto [loopIteratorType, meshAxisAssignment] :
+ llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
+ if (loopIteratorType == utils::IteratorType::reduction &&
+ !meshAxisAssignment.empty()) {
+ return true;
+ }
+ }
+ return false;
+}
+
+SmallVector<MeshAxis> mesh::getReductionMeshAxes(
+ ArrayRef<utils::IteratorType> loopIteratorTypes,
+ ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
+ SmallVector<MeshAxis> meshAxes;
+ for (auto [loopIteratorType, meshAxisAssignment] :
+ llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
+ if (loopIteratorType == utils::IteratorType::reduction) {
+ llvm::append_range(meshAxes, meshAxisAssignment);
+ }
+ }
+ return meshAxes;
+}
+
void mesh::spmdizeTriviallyShardableOperation(
Operation &op, ArrayRef<Value> spmdizedOperands,
ArrayRef<MeshShardingAttr> operandShardings,
@@ -572,7 +655,7 @@ void mesh::spmdizeTriviallyShardableOperation(
Operation *newOp = builder.clone(op, spmdizationMap);
// Set the result types to the sharded counterparts.
for (auto [oldResult, newResult, sharding] :
- llvm::zip(op.getResults(), newOp->getResults(), resultShardings)) {
+ llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) {
newResult.setType(shardType(newResult.getType(),
getMesh(&op, sharding.getMesh(), symbolTable),
sharding));
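
A minimal usage sketch (not part of this patch) of how the three helpers added above might compose inside an op's spmdization hook. It is written as if it lived next to the definitions above, so it reuses the same unqualified names from that file (ShardingArray, MeshAxis, MeshShardingAttr, utils::IteratorType); getAxesNeedingAllReduce is a hypothetical wrapper, and its arguments are assumed to come from the op's existing sharding and indexing-map queries.

// Hypothetical wrapper, not in the patch: decide which mesh axes, if any,
// the spmdized results must be combined over.
static SmallVector<MeshAxis> getAxesNeedingAllReduce(
    ArrayRef<MeshShardingAttr> operandShardings,
    ArrayRef<MeshShardingAttr> resultShardings,
    ArrayRef<utils::IteratorType> loopIteratorTypes,
    ArrayRef<AffineMap> indexingMaps) {
  // Project the per-tensor split axes onto the loop iterators.
  ShardingArray assignment = mesh::getMeshAxisAssignmentForLoopIterators(
      operandShardings, resultShardings, loopIteratorTypes, indexingMaps);
  // If no reduction iterator is sharded, every device already holds a
  // complete result and nothing needs to be combined.
  if (!mesh::isAtLeastOneReductionIteratorSharded(loopIteratorTypes,
                                                  assignment))
    return {};
  // Otherwise the results are partial, distributed over exactly the mesh
  // axes assigned to the reduction iterators.
  return mesh::getReductionMeshAxes(loopIteratorTypes, assignment);
}

For a matmul-like op with iterator types (parallel, parallel, reduction) whose reduction dimension is mapped to mesh axis 0 by the operand shardings, this sketch would be expected to return {0}, i.e. the axis over which a subsequent all-reduce of the partial results would run.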