aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Mesh/IR/MeshOps.cpp')
-rw-r--r--mlir/lib/Dialect/Mesh/IR/MeshOps.cpp213
1 files changed, 160 insertions, 53 deletions
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index d276755..de4f58d 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -9,8 +9,10 @@
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Location.h"
@@ -59,8 +61,6 @@ static SmallVector<T> &canonicalizeSetAsVector(SmallVector<T> &vec) {
return vec;
}
-using MeshAxis = int16_t;
-
namespace {
struct DimensionSize {
@@ -114,6 +114,56 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
// Mesh utilities
//===----------------------------------------------------------------------===//
+static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
+ SymbolTableCollection &symbolTable) {
+ mesh::ClusterOp mesh =
+ symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(op, meshSymbol);
+ if (!mesh) {
+ return op->emitError() << "Undefined required mesh symbol \""
+ << meshSymbol.getValue() << "\".";
+ }
+
+ return mesh;
+}
+
+template <typename It>
+bool isUnique(It begin, It end) {
+ if (begin == end) {
+ return true;
+ }
+ It next = std::next(begin);
+ if (next == end) {
+ return true;
+ }
+ for (; next != end; ++next, ++begin) {
+ if (*begin == *next) {
+ return false;
+ }
+ }
+ return true;
+}
+
+static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
+ ClusterOp mesh) {
+ SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
+ llvm::sort(sorted);
+ if (!isUnique(sorted.begin(), sorted.end())) {
+ return emitError(loc) << "Mesh axes contains duplicate elements.";
+ }
+
+ MeshAxis rank = mesh.getRank();
+ for (auto axis : axes) {
+ if (axis >= rank || axis < 0) {
+ return emitError(loc)
+ << "0-based mesh axis index " << axis
+ << " is out of bounds. The referenced mesh \"" << mesh.getSymName()
+ << "\" is of rank " << rank << ".";
+ }
+ }
+
+ return success();
+}
+
bool mesh::isReductionLoop(IteratorType iType) {
return iType != IteratorType::Parallel && iType != IteratorType::Invalid;
}
@@ -173,7 +223,45 @@ SmallVector<int64_t> ClusterOp::canonicalDimSizes() {
}
//===----------------------------------------------------------------------===//
-// mesh.shard op
+// mesh.cluster_shape op
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+ClusterShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+ auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
+ if (failed(mesh)) {
+ return failure();
+ }
+ if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
+ return failure();
+ }
+
+ size_t expectedResultsCount =
+ getAxes().empty() ? mesh->getRank() : getAxes().size();
+ if (getResult().size() != expectedResultsCount) {
+ return emitError() << "Unexpected number of results " << getResult().size()
+ << ". Expected " << expectedResultsCount << ".";
+ }
+
+ return success();
+}
+
+void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ ClusterOp mesh) {
+ build(odsBuilder, odsState,
+ SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
+ mesh.getSymName(), MeshAxesAttr());
+}
+
+void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ StringRef mesh, ArrayRef<MeshAxis> axes) {
+ build(odsBuilder, odsState,
+ SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
+ MeshAxesAttr::get(odsBuilder.getContext(), axes));
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.shard attr
//===----------------------------------------------------------------------===//
LogicalResult
@@ -205,6 +293,75 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
+bool MeshShardingAttr::operator==(Attribute rhs) const {
+ MeshShardingAttr rhsAsMeshShardingAttr = rhs.dyn_cast<MeshShardingAttr>();
+ return rhsAsMeshShardingAttr && *this == rhsAsMeshShardingAttr;
+}
+
+bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
+ if (getCluster() != rhs.getCluster() ||
+ getPartialAxes() != rhs.getPartialAxes()) {
+ return false;
+ }
+
+ if (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) {
+ return false;
+ }
+
+ auto minSize = std::min(getSplitAxes().size(), rhs.getSplitAxes().size());
+ if (!llvm::equal(llvm::make_range(getSplitAxes().begin(),
+ getSplitAxes().begin() + minSize),
+ llvm::make_range(rhs.getSplitAxes().begin(),
+ rhs.getSplitAxes().begin() + minSize))) {
+ return false;
+ }
+
+ return llvm::all_of(llvm::make_range(getSplitAxes().begin() + minSize,
+ getSplitAxes().end()),
+ std::mem_fn(&DenseI32ArrayAttr::empty)) &&
+ llvm::all_of(llvm::make_range(rhs.getSplitAxes().begin() + minSize,
+ rhs.getSplitAxes().end()),
+ std::mem_fn(&DenseI32ArrayAttr::empty));
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.process_index op
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+ProcessIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+ auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
+ if (failed(mesh)) {
+ return failure();
+ }
+ if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
+ return failure();
+ }
+
+ size_t expectedResultsCount =
+ getAxes().empty() ? mesh->getRank() : getAxes().size();
+ if (getResult().size() != expectedResultsCount) {
+ return emitError() << "Unexpected number of results " << getResult().size()
+ << ". Expected " << expectedResultsCount << ".";
+ }
+
+ return success();
+}
+
+void ProcessIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ ClusterOp mesh) {
+ build(odsBuilder, odsState,
+ SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
+ mesh.getSymName(), MeshAxesAttr());
+}
+
+void ProcessIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ StringRef mesh, ArrayRef<MeshAxis> axes) {
+ build(odsBuilder, odsState,
+ SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
+ MeshAxesAttr::get(odsBuilder.getContext(), axes));
+}
+
//===----------------------------------------------------------------------===//
// collective communication ops
//===----------------------------------------------------------------------===//
@@ -258,56 +415,6 @@ static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
return success();
}
-static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
- SymbolTableCollection &symbolTable) {
- mesh::ClusterOp mesh =
- symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(op, meshSymbol);
- if (!mesh) {
- return op->emitError() << "Undefined required mesh symbol \""
- << meshSymbol.getValue() << "\".";
- }
-
- return mesh;
-}
-
-template <typename It>
-bool isUnique(It begin, It end) {
- if (begin == end) {
- return true;
- }
- It next = std::next(begin);
- if (next == end) {
- return true;
- }
- for (; next != end; ++next, ++begin) {
- if (*begin == *next) {
- return false;
- }
- }
- return true;
-}
-
-static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
- ClusterOp mesh) {
- SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
- llvm::sort(sorted);
- if (!isUnique(sorted.begin(), sorted.end())) {
- return emitError(loc) << "Mesh axes contains duplicate elements.";
- }
-
- MeshAxis rank = mesh.getRank();
- for (auto axis : axes) {
- if (axis >= rank || axis < 0) {
- return emitError(loc)
- << "0-based mesh axis index " << axis
- << " is out of bounds. The referenced mesh \"" << mesh.getSymName()
- << "\" is of rank " << rank << ".";
- }
- }
-
- return success();
-}
-
template <typename Op>
static FailureOr<ClusterOp>
getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {