diff options
Diffstat (limited to 'mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp')
-rw-r--r-- | mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp | 13 |
1 files changed, 13 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp index d59b911..cb13ee4 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp @@ -208,4 +208,17 @@ createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes, .cast<TypedValue<IndexType>>(); } +TypedValue<IndexType> createProcessLinearIndex(StringRef mesh, + ArrayRef<MeshAxis> meshAxes, + ImplicitLocOpBuilder &builder) { + ResultRange processInGroupMultiIndex = + builder.create<ProcessMultiIndexOp>(mesh, meshAxes).getResults(); + Operation::result_range processGroupShape = + builder.create<MeshShapeOp>(mesh, meshAxes).getResult(); + OpFoldResult processInGroupLinearIndex = affine::linearizeIndex( + llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex), + llvm::to_vector_of<OpFoldResult>(processGroupShape), builder); + return cast<TypedValue<IndexType>>(processInGroupLinearIndex.get<Value>()); +} + } // namespace mlir::mesh |