# Sharding Presets
Loadax provides a few common sharding configurations that work out of the box. These presets are strong starting points, but you can also build your own configuration using JAX's `Mesh` and `NamedSharding` primitives.
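If the presets don't match your topology, you can construct a mesh directly with plain JAX. The sketch below is not Loadax API; it assumes a single host with eight devices arranged as a 4x2 data-by-model grid, so adjust the shape to your hardware.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Arrange the local devices into a 4x2 grid: 4-way data parallel, 2-way model parallel.
devices = np.array(jax.devices()).reshape(4, 2)
mesh = Mesh(devices, axis_names=("data", "model"))

# Describe how an array should be laid out on that mesh: batch dimension over
# "data", feature dimension over "model".
batch_sharding = NamedSharding(mesh, PartitionSpec("data", "model"))
```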
## FSDP Sharding
The FSDP sharding preset is a simple configuration that shards the model across multiple hosts and devices, and the data across multiple hosts. In FSDP training, the model parameters are split across multiple devices and each model shard receives unique data. This configuration is useful for training models that are too large to fit on a single device.
```python
import jax
import jax.numpy as jnp

from loadax import Dataloader, SimpleDataset
from loadax.sharding.placement import host_to_global_device_array
from loadax.sharding.presets import make_fsdp_mesh_config

dataset = SimpleDataset([jnp.array([i]) for i in range(100)])

mesh_config = make_fsdp_mesh_config(
    mesh_axis_names=("data", "model"), batch_axis_names="data"
)
mesh = mesh_config.create_device_mesh()

dataloader = Dataloader(dataset, batch_size=8)

# Create your model, optimizer, metrics, and a train_step function
# sharding your model parameters. See your framework's documentation
# for how to configure FSDP or see examples/fsdp.py for an example
# using flax's NNX api!
...

with mesh:
    for local_batch in dataloader:
        # Stack the batch of arrays into a single array
        local_batch = jnp.stack(local_batch)

        # Convert the local batch to a global device array
        global_batch = host_to_global_device_array(local_batch)

        # Use jax.lax.with_sharding_constraint to specify the sharding of the input
        sharded_batch = jax.lax.with_sharding_constraint(
            global_batch, jax.sharding.PartitionSpec("data")
        )

        # let jax.jit handle the movement of data across devices
        loss = train_step(model, optimizer, metrics, sharded_batch)
```
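Loadax leaves parameter sharding to your modeling framework. As a rough, framework-agnostic illustration (plain JAX, not Loadax API), placing parameters on the `("data", "model")` mesh from the example above so that each device holds only a slice might look like the sketch below; real FSDP setups typically derive these shardings from the model definition.

```python
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec

# Hypothetical parameters, purely for illustration.
params = {"w": jnp.ones((1024, 1024)), "b": jnp.zeros((1024,))}

# Split the second dimension of "w" across the "model" axis; keep "b" replicated.
# `mesh` is the mesh created in the example above.
shard_w = NamedSharding(mesh, PartitionSpec(None, "model"))
replicated = NamedSharding(mesh, PartitionSpec())

params = {
    "w": jax.device_put(params["w"], shard_w),
    "b": jax.device_put(params["b"], replicated),
}
```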
`make_fsdp_mesh_config` creates a `MeshConfig` configured for Fully Sharded Data Parallel (FSDP) training. It:
- Detects whether the execution is on a single node or multiple nodes.
- Determines the number of devices per node.
- Configures mesh shapes accordingly.
- Applies mesh rules based on the mesh_selector.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `mesh_axis_names` | `Sequence[str]` | The names of the mesh axes. | *required* |
| `batch_axis_names` | `str \| Sequence[str]` | Subset of mesh axis names over which leaves of the input batch are sharded. Defaults to `"data"`. | `'data'` |
| `mesh_rules` | `List[Tuple[str, MeshShape \| HybridMeshShape]] \| None` | Optional list of `(regex, MeshShape)` pairs to override the default mesh configuration based on the `mesh_selector`. Defaults to `None`. | `None` |
| `mesh_selector` | `str` | A string representing the hardware type or configuration, used to select the appropriate mesh rule. If `None`, no rules are applied. | `None` |
Returns:

| Name | Type | Description |
| --- | --- | --- |
| `MeshConfig` | `MeshConfig` | The configured mesh configuration for FSDP. |
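For hardware-specific overrides, `mesh_rules` pairs a regex over the `mesh_selector` string with a mesh shape. The sketch below is a hedged illustration: the selector strings and shapes are hypothetical, and it assumes `MeshShape` is a per-axis size tuple matching `mesh_axis_names` (check `src/loadax/sharding/presets/fsdp.py` for the exact types).

```python
from loadax.sharding.presets import make_fsdp_mesh_config

# Hypothetical rules: pick a mesh shape based on the hardware identifier.
mesh_config = make_fsdp_mesh_config(
    mesh_axis_names=("data", "model"),
    batch_axis_names="data",
    mesh_rules=[
        ("tpu-v4-.*", (4, 4)),  # hypothetical: 16 devices, 4-way data x 4-way model
        ("gpu-.*-8", (2, 4)),   # hypothetical: 8 devices, 2-way data x 4-way model
    ],
    mesh_selector="tpu-v4-16",  # hypothetical hardware string
)
mesh = mesh_config.create_device_mesh()
```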
Source code in src/loadax/sharding/presets/fsdp.py
## DDP Sharding
The DDP sharding preset is a simple configuration that replicates the model across multiple devices, where each replica receives unique data. This configuration is ideal for training smaller models when you have multiple devices available.
```python
import jax
import jax.numpy as jnp

from loadax import Dataloader, SimpleDataset
from loadax.sharding.placement import host_to_global_device_array
from loadax.sharding.presets import make_ddp_mesh_config

dataset = SimpleDataset([jnp.array([i]) for i in range(100)])

mesh_config = make_ddp_mesh_config(
    mesh_axis_names=("data",), batch_axis_names="data"
)
mesh = mesh_config.create_device_mesh()

dataloader = Dataloader(dataset, batch_size=8)

# Create your model, optimizer, metrics, and a train_step function,
# letting jax.pmap handle replicating the model and sharding the data.
...

with mesh:
    for local_batch in dataloader:
        # Stack the batch of arrays into a single array
        local_batch = jnp.stack(local_batch)

        # Convert the local batch to a global device array
        global_batch = host_to_global_device_array(local_batch)

        # Use jax.lax.with_sharding_constraint to specify the sharding of the input
        sharded_batch = jax.lax.with_sharding_constraint(
            global_batch, jax.sharding.PartitionSpec("data")
        )

        # Use pmap to replicate the computation across all devices
        loss = pmap_train_step(model, optimizer, metrics, sharded_batch)
```
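The `pmap_train_step` above is whatever your framework provides. A bare-bones, framework-agnostic version (plain JAX, not Loadax API) might average gradients across the `"data"` axis as sketched below; the model, loss, learning rate, and signature are placeholders for illustration.

```python
import functools

import jax
import jax.numpy as jnp


def loss_fn(params, batch):
    # Placeholder linear model and loss, purely for illustration.
    preds = batch @ params["w"] + params["b"]
    return jnp.mean(preds**2)


@functools.partial(jax.pmap, axis_name="data")
def pmap_train_step(params, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # All-reduce the gradients so every replica applies the same update.
    grads = jax.lax.pmean(grads, axis_name="data")
    params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
    return params, jax.lax.pmean(loss, axis_name="data")
```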
`make_ddp_mesh_config` creates a `MeshConfig` configured for Data Parallel (DP) training. It:
- Detects whether the execution is on a single node or multiple nodes.
- Determines the total number of devices across all nodes.
- Configures mesh shapes for data parallelism.
- Applies mesh rules based on the mesh_selector.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `mesh_axis_names` | `Sequence[str]` | The names of the mesh axes. Defaults to `("data",)`. | `('data',)` |
| `batch_axis_names` | `str \| Sequence[str]` | Subset of mesh axis names over which leaves of the input batch are sharded. Defaults to `"data"`. | `'data'` |
| `mesh_rules` | `List[Tuple[str, MeshShape \| HybridMeshShape]] \| None` | Optional list of `(regex, MeshShape)` pairs to override the default mesh configuration based on the `mesh_selector`. Defaults to `None`. | `None` |
| `mesh_selector` | `str` | A string representing the hardware type or configuration, used to select the appropriate mesh rule. If `None`, no rules are applied. | `None` |
Returns:

| Name | Type | Description |
| --- | --- | --- |
| `MeshConfig` | `MeshConfig` | The configured mesh configuration for DP. |
Source code in src/loadax/sharding/presets/ddp.py