
Sharding Presets

Loadax provides a few common sharding configurations that can be used out of the box. These presets are strong starting points for your sharding configuration, but you can also create your own sharding configurations using JAX's Mesh and NamedSharding primitives.
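
If the presets do not fit your topology, you can build the same pieces directly from JAX's primitives. Below is a minimal sketch, assuming a single host with 8 devices arranged as 2-way data parallelism by 4-way model parallelism; the axis names and device layout are only illustrative.

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Assumes exactly 8 devices are visible; arrange them as a 2 x 4 grid.
devices = np.array(jax.devices()).reshape(2, 4)
mesh = Mesh(devices, axis_names=("data", "model"))

# Shard batches along the "data" axis and replicate them along "model".
batch_sharding = NamedSharding(mesh, PartitionSpec("data"))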

FSDP Sharding

The FSDP sharding preset is a simple configuration that shards the model across multiple hosts and devices, and the data across multiple hosts. In FSDP training, the model parameters are split across multiple devices, and each model shard receives unique data. This configuration is useful for training models that are too large to fit on a single device.

Creating an FSDP sharding preset
import jax
import jax.numpy as jnp

from loadax import Dataloader, SimpleDataset
from loadax.sharding.placement import host_to_global_device_array
from loadax.sharding.presets import make_fsdp_mesh_config

dataset = SimpleDataset([jnp.array([i]) for i in range(100)])

mesh_config = make_fsdp_mesh_config(
    mesh_axis_names=("data", "model"), batch_axis_names="data"
)
mesh = mesh_config.create_device_mesh()

dataloader = Dataloader(dataset, batch_size=8)

# Create your model, optimizer, metrics, and a train_step function
# sharding your model parameters. See your framework's documentation
# for how to configure FSDP or see examples/fsdp.py for an example 
# using flax's NNX api!
...

with mesh:
    for local_batch in dataloader:
        # Stack the batch of arrays into a single array
        local_batch = jnp.stack(local_batch)

        # Convert the local batch to a global device array
        global_batch = host_to_global_device_array(local_batch)

        # Use jax.lax.with_sharding_constraint to specify the sharding of the input
        sharded_batch = jax.lax.with_sharding_constraint(
            global_batch, jax.sharding.PartitionSpec("data")
        )

        # let jax.jit handle the movement of data across devices
        loss = train_step(model, optimizer, metrics, sharded_batch)
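
The train_step itself is left to your modeling framework. As a rough illustration only (not Loadax API), here is a minimal sketch of a jitted step for a toy linear model with hand-written SGD; the parameter tree, loss, and learning rate are placeholders, and the function assumes it is first called inside the with mesh: block above so the axis names resolve. Note that its signature differs from the train_step(model, optimizer, metrics, sharded_batch) call above; see examples/fsdp.py for a full version using Flax's NNX API.

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

def loss_fn(params, batch):
    # Toy regression loss; substitute your model's forward pass.
    preds = batch @ params["w"] + params["b"]
    return jnp.mean(preds**2)

@jax.jit
def train_step(params, batch):
    # Keep parameters sharded along the "model" axis and the batch along "data";
    # jax.jit inserts the collectives needed to satisfy these constraints.
    params = jax.lax.with_sharding_constraint(params, {"w": P("model"), "b": P()})
    batch = jax.lax.with_sharding_constraint(batch, P("data"))
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
    return params, loss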

make_fsdp_mesh_config creates a MeshConfig configured for Fully Sharded Data Parallel (FSDP) training. It:

  • Detects whether the execution is on a single node or multiple nodes.
  • Determines the number of devices per node.
  • Configures mesh shapes accordingly.
  • Applies mesh rules based on the mesh_selector (an example follows the parameter list below).

Parameters:

    mesh_axis_names (Sequence[str]): The names of the mesh axes. Required.
    batch_axis_names (str | Sequence[str]): Subset of mesh axis names over which
        leaves of the input batch are sharded. Defaults to "data".
    mesh_rules (list[tuple[str, MeshShape | HybridMeshShape]] | None): Optional list
        of (regex, MeshShape) pairs to override the default mesh configuration based
        on the mesh_selector. Defaults to None.
    mesh_selector (str | None): A string representing the hardware type or
        configuration, used to select the appropriate mesh rule. If None, no rules
        are applied. Defaults to None.

Returns:

    MeshConfig: The configured mesh configuration for FSDP.
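
For example, a mesh rule can pin a specific topology to a hardware identifier. The snippet below is a sketch only: the regex, the selector string, and the HybridMeshShape import path are assumptions and should be adapted to your setup.

from loadax.sharding.presets import make_fsdp_mesh_config
from loadax.sharding.mesh_shape import HybridMeshShape  # import path is an assumption

mesh_config = make_fsdp_mesh_config(
    mesh_axis_names=("data", "model"),
    batch_axis_names="data",
    mesh_rules=[
        # Hypothetical rule: on a 2-node, 8-GPU-per-node A100 cluster, use data
        # parallelism across nodes and 8-way model sharding within each node.
        ("gpu-a100-.*", HybridMeshShape(ici_mesh_shape=(1, 8), dcn_mesh_shape=(2, 1))),
    ],
    mesh_selector="gpu-a100-16",  # hypothetical hardware identifier
)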

Source code in src/loadax/sharding/presets/fsdp.py
def make_fsdp_mesh_config(
    mesh_axis_names: Sequence[str],
    batch_axis_names: str | Sequence[str] = "data",
    mesh_rules: list[tuple[str, MeshShape | HybridMeshShape | None]] | None = None,
    mesh_selector: str | None = None,
) -> MeshConfig:
    """Creates a MeshConfig configured for Fully Sharded Data Parallel (FSDP) training.

    - Detects whether the execution is on a single node or multiple nodes.
    - Determines the number of devices per node.
    - Configures mesh shapes accordingly.
    - Applies mesh rules based on the mesh_selector.

    Args:
        mesh_axis_names (Sequence[str]): The names of the mesh axes.
        batch_axis_names (str | Sequence[str], optional): Subset of mesh axis names
            over which leaves of the input batch are sharded. Defaults to "data".
        mesh_rules (List[Tuple[str, MeshShape | HybridMeshShape]] | None, optional):
            Optional list of (regex, MeshShape) pairs to override the default mesh
            configuration based on the mesh_selector. Defaults to None.
        mesh_selector (str, optional): A string representing the hardware type or
            configuration, used to select the appropriate mesh rule. If None, no rules
            are applied.

    Returns:
        MeshConfig: The configured mesh configuration for FSDP.
    """
    # Initialize default mesh_shape as None
    default_mesh_shape: MeshShape | HybridMeshShape | None = None

    # Apply mesh rules if provided
    if mesh_rules and mesh_selector:
        for pattern, shape in mesh_rules:
            if re.match(pattern, mesh_selector):
                if shape is None:
                    raise ValueError(
                        f"Mesh shape for pattern '{pattern}' cannot be None."
                    )
                default_mesh_shape = shape
                print(f"Mesh rule matched: pattern='{pattern}', applying shape={shape}")
                break

    # If no mesh_rule matched or no rules provided, infer mesh_shape
    if default_mesh_shape is None:
        # Total number of nodes participating in the computation
        num_nodes = jax.process_count()

        # Number of devices (e.g., GPUs) available on the current node
        # we assume all nodes are homogeneous (jax assumes this as well)
        devices_per_node = len(jax.local_devices())

        if num_nodes < 1:
            raise ValueError(f"Invalid number of nodes: {num_nodes}. Must be >= 1.")

        if devices_per_node < 1:
            raise ValueError(f"""Invalid number of devices per node: {devices_per_node}.
            Must be >= 1.""")

        # Configure DCN mesh shape
        if num_nodes == 1:
            # Single-node setup
            dcn_mesh_shape = tuple([1] * len(mesh_axis_names))
        else:
            # Multi-node setup: assuming data or pipeline parallelism across nodes
            # Identify the first axis to partition (first non-singleton)
            # Here, we assume that the first axis is the one to be partitioned
            # Modify this logic based on your specific parallelism strategy
            dcn_mesh_shape = tuple([1] * len(mesh_axis_names))

            # For simplicity, let's partition along the first axis
            dcn_mesh_shape = (num_nodes,) + dcn_mesh_shape[1:]

        # Configure ICI (Intra-Component Interconnect) mesh shape
        ici_mesh_shape = list([1] * len(mesh_axis_names))
        # Assume model parallelism is on the last axis
        ici_mesh_shape[-1] = devices_per_node

        hybrid_mesh_shape = HybridMeshShape(
            ici_mesh_shape=tuple(ici_mesh_shape), dcn_mesh_shape=tuple(dcn_mesh_shape)
        )

        default_mesh_shape = hybrid_mesh_shape

    # Instantiate MeshConfig
    mesh_config = MeshConfig(
        mesh_shape=default_mesh_shape,
        mesh_axis_names=mesh_axis_names,
        batch_axis_names=batch_axis_names,
        mesh_rules=mesh_rules,  # type: ignore
    )

    return mesh_config
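
When no mesh rule matches, the inferred shape keeps intra-host devices on the last mesh axis and spreads hosts across the first. A sketch of what that means in practice, assuming a single host with 8 local devices:

mesh_config = make_fsdp_mesh_config(mesh_axis_names=("data", "model"))
# Inferred: HybridMeshShape(ici_mesh_shape=(1, 8), dcn_mesh_shape=(1, 1)),
# i.e. all 8 local devices shard the "model" axis.
mesh = mesh_config.create_device_mesh()
print(mesh.shape)  # expected axis sizes: data=1, model=8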

DDP Sharding

The DDP sharding preset is a simple configuration that replicates the model across multiple devices, with each replica receiving unique data. This configuration is ideal for training smaller models when you have multiple devices available.

Creating a DDP sharding preset
import jax
import jax.numpy as jnp

from loadax import Dataloader, SimpleDataset
from loadax.sharding.placement import host_to_global_device_array
from loadax.sharding.presets import make_ddp_mesh_config

dataset = SimpleDataset([jnp.array([i]) for i in range(100)])

mesh_config = make_ddp_mesh_config(
    mesh_axis_names=("data",), batch_axis_names="data"
)
mesh = mesh_config.create_device_mesh()

dataloader = Dataloader(dataset, batch_size=8)

# Create your model, optimizer, metrics, and a train_step function
# letting jax.pmap handle replicating the model and sharding the data.
...

with mesh:
    for local_batch in dataloader:
        # Stack the batch of arrays into a single array
        local_batch = jnp.stack(local_batch)

        # Convert the local batch to a global device array
        global_batch = host_to_global_device_array(local_batch)

        # Use jax.lax.with_sharding_constraint to specify the sharding of the input
        sharded_batch = jax.lax.with_sharding_constraint(
            global_batch, jax.sharding.PartitionSpec("data")
        )

        # Use pmap to replicate the computation across all devices
        loss = pmap_train_step(model, optimizer, metrics, sharded_batch)
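
The pmap_train_step is again framework-specific. As an illustrative sketch only (not Loadax API), a minimal pmapped step for a toy model might look like the following; its signature is simplified relative to the call above, and the batch is assumed to be reshaped so its leading axis equals the number of local devices.

import functools

import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    # Toy regression loss; substitute your model's forward pass.
    preds = batch @ params["w"] + params["b"]
    return jnp.mean(preds**2)

@functools.partial(jax.pmap, axis_name="data")
def pmap_train_step(params, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # Average gradients across replicas so every copy applies the same update.
    grads = jax.lax.pmean(grads, axis_name="data")
    params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
    return params, loss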

make_ddp_mesh_config creates a MeshConfig configured for Data Parallel (DP) training. It:

  • Detects whether the execution is on a single node or multiple nodes.
  • Determines the total number of devices across all nodes.
  • Configures mesh shapes for data parallelism.
  • Applies mesh rules based on the mesh_selector.

Parameters:

    mesh_axis_names (Sequence[str]): The names of the mesh axes. Defaults to ("data",).
    batch_axis_names (str | Sequence[str]): Subset of mesh axis names over which
        leaves of the input batch are sharded. Defaults to "data".
    mesh_rules (list[tuple[str, MeshShape | HybridMeshShape]] | None): Optional list
        of (regex, MeshShape) pairs to override the default mesh configuration based
        on the mesh_selector. Defaults to None.
    mesh_selector (str | None): A string representing the hardware type or
        configuration, used to select the appropriate mesh rule. If None, no rules
        are applied. Defaults to None.

Returns:

    MeshConfig: The configured mesh configuration for DP.

Source code in src/loadax/sharding/presets/ddp.py
def make_ddp_mesh_config(
    mesh_axis_names: Sequence[str] = ("data",),
    batch_axis_names: str | Sequence[str] = "data",
    mesh_rules: list[tuple[str, MeshShape | HybridMeshShape | None]] | None = None,
    mesh_selector: str | None = None,
) -> MeshConfig:
    """Creates a MeshConfig configured for Data Parallel (DP) training.

    - Detects whether the execution is on a single node or multiple nodes.
    - Determines the total number of devices across all nodes.
    - Configures mesh shapes for data parallelism.
    - Applies mesh rules based on the mesh_selector.

    Args:
        mesh_axis_names (Sequence[str], optional): The names of the mesh axes.
            Defaults to ("data",).
        batch_axis_names (str | Sequence[str], optional): Subset of mesh axis names
            over which leaves of the input batch are sharded. Defaults to "data".
        mesh_rules (List[Tuple[str, MeshShape | HybridMeshShape]] | None, optional):
            Optional list of (regex, MeshShape) pairs to override the default mesh
            configuration based on the mesh_selector. Defaults to None.
        mesh_selector (str, optional): A string representing the hardware type or
            configuration, used to select the appropriate mesh rule. If None, no rules
            are applied.

    Returns:
        MeshConfig: The configured mesh configuration for DP.
    """
    # Initialize default mesh_shape as None
    default_mesh_shape: MeshShape | HybridMeshShape | None = None

    # Apply mesh rules if provided
    if mesh_rules and mesh_selector:
        for pattern, shape in mesh_rules:
            if re.match(pattern, mesh_selector):
                if shape is None:
                    raise ValueError(
                        f"Mesh shape for pattern '{pattern}' cannot be None."
                    )
                default_mesh_shape = shape
                print(f"Mesh rule matched: pattern='{pattern}', applying shape={shape}")
                break

    # If no mesh_rule matched or no rules provided, infer mesh_shape
    if default_mesh_shape is None:
        # Total number of nodes participating in the computation
        num_nodes = jax.process_count()

        # Number of devices (e.g., GPUs) available on the current node
        # we assume all nodes are homogeneous (jax assumes this as well)
        devices_per_node = len(jax.local_devices())

        if num_nodes < 1:
            raise ValueError(f"Invalid number of nodes: {num_nodes}. Must be >= 1.")

        if devices_per_node < 1:
            raise ValueError(f"""Invalid number of devices per node: {devices_per_node}.
            Must be >= 1.""")

        # For DP, we use a single axis for data parallelism
        ici_mesh_shape = (devices_per_node,)
        dcn_mesh_shape = (num_nodes,)

        hybrid_mesh_shape = HybridMeshShape(
            ici_mesh_shape=ici_mesh_shape, dcn_mesh_shape=dcn_mesh_shape
        )

        default_mesh_shape = hybrid_mesh_shape

    # Instantiate MeshConfig
    mesh_config = MeshConfig(
        mesh_shape=default_mesh_shape,
        mesh_axis_names=mesh_axis_names,
        batch_axis_names=batch_axis_names,
        mesh_rules=mesh_rules,  # type: ignore
    )

    return mesh_config
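
When no mesh rule matches, the inferred DDP shape is a single data axis spanning every device. A sketch, assuming 2 hosts with 4 local devices each:

mesh_config = make_ddp_mesh_config()
# Inferred: HybridMeshShape(ici_mesh_shape=(4,), dcn_mesh_shape=(2,)),
# i.e. one "data" axis spanning all 8 devices across both hosts.
mesh = mesh_config.create_device_mesh()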