aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp')
-rw-r--r--mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp13
1 files changed, 13 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
index d59b911..cb13ee4 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
@@ -208,4 +208,17 @@ createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
.cast<TypedValue<IndexType>>();
}
+TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
+ ArrayRef<MeshAxis> meshAxes,
+ ImplicitLocOpBuilder &builder) {
+ ResultRange processInGroupMultiIndex =
+ builder.create<ProcessMultiIndexOp>(mesh, meshAxes).getResults();
+ Operation::result_range processGroupShape =
+ builder.create<MeshShapeOp>(mesh, meshAxes).getResult();
+ OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
+ llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
+ llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
+ return cast<TypedValue<IndexType>>(processInGroupLinearIndex.get<Value>());
+}
+
} // namespace mlir::mesh