diff options
Diffstat (limited to 'mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp')
-rw-r--r-- | mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp | 113 |
1 files changed, 110 insertions, 3 deletions
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index fe3d7c4..9acee5a 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -539,8 +539,9 @@ static bool areValuesCompatibleWithFullReplicationShardings(
   if (std::size(values) != std::size(shardings)) {
     return false;
   }
-  return llvm::all_of(llvm::zip(std::forward<ValueRange>(values),
-                                std::forward<MeshShardingAttrRage>(shardings)),
+  return llvm::all_of(llvm::zip_equal(
+                          std::forward<ValueRange>(values),
+                          std::forward<MeshShardingAttrRage>(shardings)),
                       [](auto valueAndSharding) {
                         return isValueCompatibleWithFullReplicationSharding(
                             std::get<0>(valueAndSharding),
@@ -563,6 +564,112 @@ void mesh::spmdizeFullyReplicatedOperation(
   builder.clone(op, spmdizationMap);
 }
 
+/// Records `meshAxesAssignmentForTensorAxis` as the mesh axes of the loop
+/// iterator that `indexingExpr` refers to.
+///
+/// `indexingExpr` is cast to `AffineDimExpr` unconditionally and its position
+/// selects the slot in `meshAxesAssignmentForLoopIterators`. An empty slot is
+/// filled with a copy of the assignment; an already filled slot is only
+/// asserted to hold an equal assignment.
+static void updateMeshAxisAssignmentForLoopIterators(
+    ArrayRef<MeshAxis> meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr,
+    SmallVector<std::optional<SmallVector<MeshAxis>>>
+        &meshAxesAssignmentForLoopIterators) {
+  AffineDimExpr affineDimExpr = cast<AffineDimExpr>(indexingExpr);
+  unsigned loopIteratorIdx = affineDimExpr.getPosition();
+  if (meshAxesAssignmentForLoopIterators[loopIteratorIdx]) {
+    assert(llvm::equal(meshAxesAssignmentForTensorAxis,
+                       *meshAxesAssignmentForLoopIterators[loopIteratorIdx]) &&
+           "conflicting mesh axes assigned to the same loop iterator");
+  } else {
+    meshAxesAssignmentForLoopIterators[loopIteratorIdx] =
+        llvm::to_vector(meshAxesAssignmentForTensorAxis);
+  }
+}
+
+/// Derives, for every loop iterator, the mesh axes assigned to it.
+///
+/// Shardings and `indexingMaps` are walked in lockstep (`zip_equal`); null
+/// shardings are skipped. For each tensor dimension, the split axes of the
+/// sharding are recorded for the loop iterator named by the matching indexing
+/// map result. Dimensions past the end of the split axes are recorded with an
+/// empty axis list. Loop iterators that are never recorded also get an empty
+/// list in the returned array.
+///
+/// NOTE(review): `resultShardings` only feeds the `reserve` below and is never
+/// appended, so `indexingMaps` is zipped against `operandShardings` alone.
+/// Confirm with the callers that this is intended.
+ShardingArray mesh::getMeshAxisAssignmentForLoopIterators(
+    ArrayRef<MeshShardingAttr> operandShardings,
+    ArrayRef<MeshShardingAttr> resultShardings,
+    ArrayRef<utils::IteratorType> loopIteratorTypes,
+    ArrayRef<AffineMap> indexingMaps) {
+  SmallVector<std::optional<SmallVector<MeshAxis>>>
+      meshAxisAssignmentForLoopIterators(loopIteratorTypes.size());
+  SmallVector<MeshShardingAttr> operatorAndResultShardings;
+  operatorAndResultShardings.reserve(operandShardings.size() +
+                                     resultShardings.size());
+  llvm::append_range(operatorAndResultShardings, operandShardings);
+  for (auto [sharding, affineMap] :
+       llvm::zip_equal(operatorAndResultShardings, indexingMaps)) {
+    if (!sharding) {
+      continue;
+    }
+    for (auto [meshAxesAssignmentForTensorAxis, indexingExpr] :
+         llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) {
+      updateMeshAxisAssignmentForLoopIterators(
+          meshAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr,
+          meshAxisAssignmentForLoopIterators);
+    }
+    // Missing trailing split axes means replication on those tensor dimensions.
+    for (unsigned i = sharding.getSplitAxes().size();
+         i < affineMap.getNumResults(); ++i) {
+      updateMeshAxisAssignmentForLoopIterators(
+          {}, affineMap.getResults()[i], meshAxisAssignmentForLoopIterators);
+    }
+  }
+
+  ShardingArray res;
+  llvm::transform(meshAxisAssignmentForLoopIterators, std::back_inserter(res),
+                  [](std::optional<SmallVector<MeshAxis>> &axes) {
+                    if (!axes) {
+                      return SmallVector<MeshAxis>();
+                    }
+                    return std::move(*axes);
+                  });
+  return res;
+}
+
+/// Returns true if any loop iterator of type `reduction` has a non-empty mesh
+/// axis assignment. The two lists are walked in lockstep with `zip_equal`.
+bool mesh::isAtLeastOneReductionIteratorSharded(
+    ArrayRef<utils::IteratorType> loopIteratorTypes,
+    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
+  for (auto [loopIteratorType, meshAxisAssignment] :
+       llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
+    if (loopIteratorType == utils::IteratorType::reduction &&
+        !meshAxisAssignment.empty()) {
+      return true;
+    }
+  }
+  return false;
+}
+
+/// Collects the mesh axes assigned to all `reduction` loop iterators,
+/// concatenated in loop-iterator order.
+SmallVector<MeshAxis> mesh::getReductionMeshAxes(
+    ArrayRef<utils::IteratorType> loopIteratorTypes,
+    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
+  SmallVector<MeshAxis> meshAxes;
+  for (auto [loopIteratorType, meshAxisAssignment] :
+       llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
+    if (loopIteratorType == utils::IteratorType::reduction) {
+      llvm::append_range(meshAxes, meshAxisAssignment);
+    }
+  }
+  return meshAxes;
+}
+
 void mesh::spmdizeTriviallyShardableOperation(
     Operation &op, ArrayRef<Value> spmdizedOperands,
     ArrayRef<MeshShardingAttr> operandShardings,
@@ -572,7 +679,7 @@ void mesh::spmdizeTriviallyShardableOperation(
   Operation *newOp = builder.clone(op, spmdizationMap);
   // Set the result types to the sharded counterparts.
   for (auto [oldResult, newResult, sharding] :
-       llvm::zip(op.getResults(), newOp->getResults(), resultShardings)) {
+       llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) {
     newResult.setType(shardType(newResult.getType(),
                                 getMesh(&op, sharding.getMesh(), symbolTable),
                                 sharding));