aboutsummaryrefslogtreecommitdiff
path: root/mlir/docs/Dialects/Shard.md
blob: eb6ff6150e474512c902003db6c17d41c0ce05f2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# 'shard' Dialect

The 'shard' dialect defines a set of attributes, operations, and interfaces for
working with tensor sharding and device communication.

It’s inspired by [GSPMD](*General and Scalable Parallelization for ML Computation Graphs*).

Originally, the dialect was called `mesh`, but it was renamed to better reflect
what it actually does.

[TOC]

## Collective Communication Operations

The 'shard' dialect includes several collective operations that help coordinate
communication between devices arranged in a grid.

If you’re not already familiar with collective operations, [this Wikipedia
article](https://en.wikipedia.org/wiki/Collective_operation) is a good starting
point.

Unlike traditional collectives that are defined in terms of message-passing
between explicit buffers on each process, the collectives in this dialect work
at a higher level. They’re defined in terms of how data moves across the
dimensions of a tensor, and the participating processes are inferred from how
the tensor is sharded - not specified manually.

### Device Groups

Each collective operation runs within a group of devices. You define groups
using the `grid` and `grid_axes` attributes, which describe how to slice the
full device grid into smaller groups.

Devices that have the same coordinates *outside* the listed `grid_axes` belong
to the same group.

Example: Say your device grid is shaped `2×3×4×5`, and you set
`grid_axes = [0, 1]`. This splits the grid into groups by fixing axes 2 and 3. You’d get groups like:

```
{ { (i, j, k, m) | 0 ≤ i < 2, 0 ≤ j < 3 } | 0 ≤ k < 4, 0 ≤ m < 5 }
```

So the groups are identified by the coordinates `(k, m)`, and devices like
`(1, 0, 2, 3)` and `(1, 1, 2, 3)` are in the same group. But `(1, 0, 2, 4)`
is in a different group.

For some collectives (like `all-to-all`), the order of devices in the group
matters. The device order is based on the order of axes in `grid_axes`, from
outermost to innermost.

Example: If `grid_axes = [3, 1]`, then device `(i, 1, k, 0)` comes before
`(i, 0, k, 1)` and `(i, 2, k, 0)`.

### In-group Devices

Some operations (like `broadcast`, `scatter`, and `send`) refer to a specific
device within each group. These in-group devices are identified using their
coordinates over the axes listed in `grid_axes`.

Example: In a 3D grid with `grid_axes = [0, 2]`, an in-group device is specified
as `(i, j)`. If a group is fixed at coordinate `g` on axis 1, then the full
device index would be `(i, g, j)`.

### Purity and Execution Model

Collective operations involve all devices in a group (e.g. `all-gather`,
`all-to-all`) and are considered pure. Operations like `send` and `recv` are not
collective and are not pure.

The execution model assumes SPMD (Single Program, Multiple Data):

* Every process runs the same program.
* At any collective operation, all processes are in sync.

This means compiler optimizations must treat collective ops carefully. For
example, if a collective is removed during optimization, it must be removed from
*every* path and *every* process that would have participated - otherwise, you’ll
get undefined behavior at runtime.

Marking these ops as pure also helps with standard compiler passes like dead
code elimination and common subexpression elimination. It ensures that when the
program is executed, all devices hit the same line of code at the same time
during collectives and so avoid dead-locks.

## Operations

[include "Dialects/ShardOps.md"]

## Attributes

[include "Dialects/ShardAttrs.md"]