diff options
Diffstat (limited to 'mlir/lib/Dialect/Mesh/IR/MeshOps.cpp')
-rw-r--r-- | mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 213 |
1 files changed, 160 insertions, 53 deletions
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp index d276755..de4f58d 100644 --- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp +++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp @@ -9,8 +9,10 @@ #include "mlir/Dialect/Mesh/IR/MeshOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Location.h" @@ -59,8 +61,6 @@ static SmallVector<T> &canonicalizeSetAsVector(SmallVector<T> &vec) { return vec; } -using MeshAxis = int16_t; - namespace { struct DimensionSize { @@ -114,6 +114,56 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value, // Mesh utilities //===----------------------------------------------------------------------===// +static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, + SymbolTableCollection &symbolTable) { + mesh::ClusterOp mesh = + symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(op, meshSymbol); + if (!mesh) { + return op->emitError() << "Undefined required mesh symbol \"" + << meshSymbol.getValue() << "\"."; + } + + return mesh; +} + +template <typename It> +bool isUnique(It begin, It end) { + if (begin == end) { + return true; + } + It next = std::next(begin); + if (next == end) { + return true; + } + for (; next != end; ++next, ++begin) { + if (*begin == *next) { + return false; + } + } + return true; +} + +static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes, + ClusterOp mesh) { + SmallVector<MeshAxis> sorted = llvm::to_vector(axes); + llvm::sort(sorted); + if (!isUnique(sorted.begin(), sorted.end())) { + return emitError(loc) << "Mesh axes contains duplicate elements."; + } + + MeshAxis rank = mesh.getRank(); + for (auto axis : axes) { + if (axis >= rank || axis < 0) { + return emitError(loc) + << "0-based mesh axis index " << axis + << " is out of bounds. The referenced mesh \"" << mesh.getSymName() + << "\" is of rank " << rank << "."; + } + } + + return success(); +} + bool mesh::isReductionLoop(IteratorType iType) { return iType != IteratorType::Parallel && iType != IteratorType::Invalid; } @@ -173,7 +223,45 @@ SmallVector<int64_t> ClusterOp::canonicalDimSizes() { } //===----------------------------------------------------------------------===// -// mesh.shard op +// mesh.cluster_shape op +//===----------------------------------------------------------------------===// + +LogicalResult +ClusterShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable); + if (failed(mesh)) { + return failure(); + } + if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) { + return failure(); + } + + size_t expectedResultsCount = + getAxes().empty() ? mesh->getRank() : getAxes().size(); + if (getResult().size() != expectedResultsCount) { + return emitError() << "Unexpected number of results " << getResult().size() + << ". Expected " << expectedResultsCount << "."; + } + + return success(); +} + +void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState, + ClusterOp mesh) { + build(odsBuilder, odsState, + SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()), + mesh.getSymName(), MeshAxesAttr()); +} + +void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState, + StringRef mesh, ArrayRef<MeshAxis> axes) { + build(odsBuilder, odsState, + SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh, + MeshAxesAttr::get(odsBuilder.getContext(), axes)); +} + +//===----------------------------------------------------------------------===// +// mesh.shard attr //===----------------------------------------------------------------------===// LogicalResult @@ -205,6 +293,75 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError, return success(); } +bool MeshShardingAttr::operator==(Attribute rhs) const { + MeshShardingAttr rhsAsMeshShardingAttr = rhs.dyn_cast<MeshShardingAttr>(); + return rhsAsMeshShardingAttr && *this == rhsAsMeshShardingAttr; +} + +bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const { + if (getCluster() != rhs.getCluster() || + getPartialAxes() != rhs.getPartialAxes()) { + return false; + } + + if (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) { + return false; + } + + auto minSize = std::min(getSplitAxes().size(), rhs.getSplitAxes().size()); + if (!llvm::equal(llvm::make_range(getSplitAxes().begin(), + getSplitAxes().begin() + minSize), + llvm::make_range(rhs.getSplitAxes().begin(), + rhs.getSplitAxes().begin() + minSize))) { + return false; + } + + return llvm::all_of(llvm::make_range(getSplitAxes().begin() + minSize, + getSplitAxes().end()), + std::mem_fn(&DenseI32ArrayAttr::empty)) && + llvm::all_of(llvm::make_range(rhs.getSplitAxes().begin() + minSize, + rhs.getSplitAxes().end()), + std::mem_fn(&DenseI32ArrayAttr::empty)); +} + +//===----------------------------------------------------------------------===// +// mesh.process_index op +//===----------------------------------------------------------------------===// + +LogicalResult +ProcessIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable); + if (failed(mesh)) { + return failure(); + } + if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) { + return failure(); + } + + size_t expectedResultsCount = + getAxes().empty() ? mesh->getRank() : getAxes().size(); + if (getResult().size() != expectedResultsCount) { + return emitError() << "Unexpected number of results " << getResult().size() + << ". Expected " << expectedResultsCount << "."; + } + + return success(); +} + +void ProcessIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState, + ClusterOp mesh) { + build(odsBuilder, odsState, + SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()), + mesh.getSymName(), MeshAxesAttr()); +} + +void ProcessIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState, + StringRef mesh, ArrayRef<MeshAxis> axes) { + build(odsBuilder, odsState, + SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh, + MeshAxesAttr::get(odsBuilder.getContext(), axes)); +} + //===----------------------------------------------------------------------===// // collective communication ops //===----------------------------------------------------------------------===// @@ -258,56 +415,6 @@ static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName, return success(); } -static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, - SymbolTableCollection &symbolTable) { - mesh::ClusterOp mesh = - symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(op, meshSymbol); - if (!mesh) { - return op->emitError() << "Undefined required mesh symbol \"" - << meshSymbol.getValue() << "\"."; - } - - return mesh; -} - -template <typename It> -bool isUnique(It begin, It end) { - if (begin == end) { - return true; - } - It next = std::next(begin); - if (next == end) { - return true; - } - for (; next != end; ++next, ++begin) { - if (*begin == *next) { - return false; - } - } - return true; -} - -static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes, - ClusterOp mesh) { - SmallVector<MeshAxis> sorted = llvm::to_vector(axes); - llvm::sort(sorted); - if (!isUnique(sorted.begin(), sorted.end())) { - return emitError(loc) << "Mesh axes contains duplicate elements."; - } - - MeshAxis rank = mesh.getRank(); - for (auto axis : axes) { - if (axis >= rank || axis < 0) { - return emitError(loc) - << "0-based mesh axis index " << axis - << " is out of bounds. The referenced mesh \"" << mesh.getSymName() - << "\" is of rank " << rank << "."; - } - } - - return success(); -} - template <typename Op> static FailureOr<ClusterOp> getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) { |