diff --git a/mlir/docs/Dialects/Shard.md b/mlir/docs/Dialects/Shard.md
new file mode 100644
index 0000000..eb6ff61
--- /dev/null
+++ b/mlir/docs/Dialects/Shard.md
@@ -0,0 +1,92 @@
+# 'shard' Dialect
+
+The 'shard' dialect defines a set of attributes, operations, and interfaces for
+working with tensor sharding and device communication.
+
+It’s inspired by [GSPMD](https://arxiv.org/abs/2105.04663)
+(*General and Scalable Parallelization for ML Computation Graphs*).
+
+Originally, the dialect was called `mesh`, but it was renamed to better reflect
+what it actually does.
+
+[TOC]
+
+## Collective Communication Operations
+
+The 'shard' dialect includes several collective operations that help coordinate
+communication between devices arranged in a grid.
+
+If you’re not already familiar with collective operations, [this Wikipedia
+article](https://en.wikipedia.org/wiki/Collective_operation) is a good starting
+point.
+
+Unlike traditional collectives that are defined in terms of message-passing
+between explicit buffers on each process, the collectives in this dialect work
+at a higher level. They’re defined in terms of how data moves across the
+dimensions of a tensor, and the participating processes are inferred from how
+the tensor is sharded - not specified manually.
+
+### Device Groups
+
+Each collective operation runs within a group of devices. You define groups
+using the `grid` and `grid_axes` attributes, which describe how to slice the
+full device grid into smaller groups.
+
+Devices that have the same coordinates *outside* the listed `grid_axes` belong
+to the same group.
+
+Example: Say your device grid is shaped `2×3×4×5`, and you set
+`grid_axes = [0, 1]`. This splits the grid into groups by fixing axes 2 and 3.
+You’d get groups like:
+
+```
+{ { (i, j, k, m) | 0 ≤ i < 2, 0 ≤ j < 3 } | 0 ≤ k < 4, 0 ≤ m < 5 }
+```
+
+So the groups are identified by the coordinates `(k, m)`, and devices like
+`(1, 0, 2, 3)` and `(1, 1, 2, 3)` are in the same group. But `(1, 0, 2, 4)`
+is in a different group.
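+
+To make the grouping concrete, here is a small Python sketch (purely
+illustrative, not part of the dialect; the helper name `device_groups` is made
+up) that enumerates the groups for the `2×3×4×5` grid with `grid_axes = [0, 1]`
+from the example above:
+
+```python
+from itertools import product
+
+def device_groups(grid_shape, grid_axes):
+    """Devices sharing the same coordinates on the axes *outside* grid_axes
+    form one group."""
+    fixed_axes = [a for a in range(len(grid_shape)) if a not in grid_axes]
+    groups = {}
+    for device in product(*(range(n) for n in grid_shape)):
+        key = tuple(device[a] for a in fixed_axes)  # coordinates (k, m)
+        groups.setdefault(key, []).append(device)
+    return groups
+
+groups = device_groups((2, 3, 4, 5), grid_axes=[0, 1])
+print(len(groups))                        # 20 groups, one per (k, m)
+print((1, 0, 2, 3) in groups[(2, 3)])     # True
+print((1, 1, 2, 3) in groups[(2, 3)])     # True  - same group
+print((1, 0, 2, 4) in groups[(2, 3)])     # False - different group
+```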
+
+For some collectives (like `all-to-all`), the order of devices in the group
+matters. The device order is based on the order of axes in `grid_axes`, from
+outermost to innermost.
+
+Example: If `grid_axes = [3, 1]`, then device `(i, 1, k, 0)` comes before
+`(i, 0, k, 1)` and `(i, 2, k, 0)`.
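+
+This ordering can be expressed as a lexicographic sort key over the
+`grid_axes` coordinates, first listed axis first. A short Python sketch
+(the helper name `group_order_key` is made up for illustration):
+
+```python
+def group_order_key(device, grid_axes):
+    """In-group order: lexicographic on the device's coordinates over
+    grid_axes, with the first listed axis as the most significant."""
+    return tuple(device[a] for a in grid_axes)
+
+# With grid_axes = [3, 1]:
+#   (i, 1, k, 0) -> key (0, 1)
+#   (i, 0, k, 1) -> key (1, 0)
+#   (i, 2, k, 0) -> key (0, 2)
+# so (i, 1, k, 0) sorts before both of the others, as in the example above.
+i, k = 0, 0
+devices = [(i, 0, k, 1), (i, 2, k, 0), (i, 1, k, 0)]
+print(sorted(devices, key=lambda d: group_order_key(d, [3, 1])))
+# [(0, 1, 0, 0), (0, 2, 0, 0), (0, 0, 0, 1)]
+```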
+
+### In-group Devices
+
+Some operations (like `broadcast`, `scatter`, and `send`) refer to a specific
+device within each group. These in-group devices are identified using their
+coordinates over the axes listed in `grid_axes`.
+
+Example: In a 3D grid with `grid_axes = [0, 2]`, an in-group device is specified
+as `(i, j)`. If a group is fixed at coordinate `g` on axis 1, then the full
+device index would be `(i, g, j)`.
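+
+In other words, the full device index interleaves the in-group coordinates
+(over `grid_axes`, in listed order) with the group's fixed coordinates (over
+the remaining axes). A small illustrative Python sketch of that recombination
+(the helper name `full_device_index` is hypothetical):
+
+```python
+def full_device_index(in_group, group, grid_axes, rank):
+    """Combine in-group coordinates (over grid_axes) with the group's fixed
+    coordinates (over the remaining axes) into a full device index."""
+    fixed_axes = [a for a in range(rank) if a not in grid_axes]
+    index = [None] * rank
+    for axis, c in zip(grid_axes, in_group):
+        index[axis] = c
+    for axis, c in zip(fixed_axes, group):
+        index[axis] = c
+    return tuple(index)
+
+# 3D grid, grid_axes = [0, 2]: in-group device (i, j) = (1, 2) in the group
+# fixed at g = 4 on axis 1 is the full device (i, g, j) = (1, 4, 2).
+print(full_device_index(in_group=(1, 2), group=(4,), grid_axes=[0, 2], rank=3))
+```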
+
+### Purity and Execution Model
+
+Collective operations involve all devices in a group (e.g. `all-gather`,
+`all-to-all`) and are considered pure. Operations like `send` and `recv` are not
+collective and are not pure.
+
+The execution model assumes SPMD (Single Program, Multiple Data):
+
+* Every process runs the same program.
+* At any collective operation, all processes are in sync.
+
+This means compiler optimizations must treat collective ops carefully. For
+example, if a collective is removed during optimization, it must be removed from
+*every* path and *every* process that would have participated - otherwise, you’ll
+get undefined behavior at runtime.
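+
+As a toy analogy (plain Python threads standing in for processes, nothing
+dialect-specific): a collective behaves like a barrier, so dropping it on only
+some participants leaves the rest blocked forever. The sketch keeps the
+"collective" on every thread, which is the only safe form:
+
+```python
+import threading
+
+NUM_PROCESSES = 4
+barrier = threading.Barrier(NUM_PROCESSES)  # stand-in for a collective op
+
+def spmd_program(rank):
+    # ... local computation ...
+    # Every process must reach this point. If an optimizer removed it for
+    # rank 0 only, ranks 1-3 would wait here forever (a deadlock).
+    barrier.wait()
+    # ... computation that consumes the collective's result ...
+
+threads = [threading.Thread(target=spmd_program, args=(r,))
+           for r in range(NUM_PROCESSES)]
+for t in threads: t.start()
+for t in threads: t.join()
+```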
+
+Marking these ops as pure also lets standard passes such as dead code
+elimination and common subexpression elimination apply to them. Because every
+process runs the same, identically optimized program, all devices still reach
+the same collectives in the same order at runtime, which avoids deadlocks.
+
+## Operations
+
+[include "Dialects/ShardOps.md"]
+
+## Attributes
+
+[include "Dialects/ShardAttrs.md"]