
Mesh

Loadax provides a specialized mesh definition called HybridMeshShape. It lets you define a mesh in two contexts: the global topology (inter-node) and the local topology (intra-node). This makes it easier to reason about how data is placed across the network.

To use loadax's mesh abstractions, you need to define a MeshConfig, which tells loadax how to split the work across the mesh. A MeshConfig must specify a mesh_shape (either a MeshShape or a HybridMeshShape) together with mesh_axis_names, which name each axis of the mesh.

Typically you should not need to define a MeshConfig yourself. Instead you can rely on Loadax's automatic mesh discovery to find a good mesh shape for your cluster for a given parallelization strategy. See the Presets section for more details.

HybridMeshShape

A mesh shape for hybrid (i.e., ICI and DCN) parallelism.

For example, with mesh axes (data, model):
- Pure FSDP on a v4-8: HybridMeshShape(ici_mesh_shape=(1, 4), dcn_mesh_shape=(1, 1))
- Two-way data parallelism over 2 H100 nodes, with FSDP within each node: HybridMeshShape(ici_mesh_shape=(1, 8), dcn_mesh_shape=(2, 1))
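To see how the two shapes combine, it helps to compute the global mesh they describe. The sketch below assumes that each axis of the global mesh has size ICI size times DCN size, and it uses a hypothetical stand-in dataclass with the two documented fields rather than loadax's own class. It checks the two examples above.

```python
from dataclasses import dataclass
from math import prod

MeshShape = tuple[int, ...]


@dataclass(frozen=True)
class HybridMeshShapeSketch:
    """Hypothetical stand-in mirroring the documented fields; not loadax's class."""

    ici_mesh_shape: MeshShape
    dcn_mesh_shape: MeshShape

    def global_shape(self) -> MeshShape:
        # Assumption: every global mesh axis is (ICI size) * (DCN size).
        if len(self.ici_mesh_shape) != len(self.dcn_mesh_shape):
            raise ValueError("ICI and DCN shapes must have the same length")
        return tuple(i * d for i, d in zip(self.ici_mesh_shape, self.dcn_mesh_shape))

    def num_devices(self) -> int:
        # Total number of devices the mesh expects.
        return prod(self.global_shape())


# Mesh axes (data, model).
v4_8 = HybridMeshShapeSketch(ici_mesh_shape=(1, 4), dcn_mesh_shape=(1, 1))
two_h100_nodes = HybridMeshShapeSketch(ici_mesh_shape=(1, 8), dcn_mesh_shape=(2, 1))
print(v4_8.global_shape(), v4_8.num_devices())  # (1, 4) 4
print(two_h100_nodes.global_shape(), two_h100_nodes.num_devices())  # (2, 8) 16
```

The second example needs 16 devices in total: two nodes (DCN) of eight GPUs each (ICI).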

ici_mesh_shape instance-attribute

ici_mesh_shape: MeshShape

The mesh shape for ICI (inter-chip interconnect) parallelism. This represents the host-local sharding.

dcn_mesh_shape instance-attribute

dcn_mesh_shape: MeshShape

The mesh shape for DCN (data-center network) parallelism. This represents the inter-host sharding.

MeshConfig

Sharding mesh configuration.

mesh_shape instance-attribute

mesh_shape: MeshShape | HybridMeshShape

If specified as a MeshShape, it must have the same length as mesh_axis_names. The shape then describes the whole mesh, and loadax derives the hybrid split by default: the first non-singleton axis is partitioned across granules (e.g. TPU slices or GPU nodes), and the remaining factor of that axis, along with all other axes, forms the ICI mesh shape. If all axes are singletons, this implies a single-granule environment and therefore an all-1s DCN mesh shape.

For example, on 2 H100 nodes with mesh axes (pipeline, data, model) and a MeshShape of (1, 2, 8), the "data" axis is split across DCN. This produces a DCN mesh shape of (1, 2, 1) and an ICI mesh shape of (1, 1, 8), i.e. 2-way data parallelism across DCN and 8-way model parallelism within each node (e.g. over NVLink). If the MeshShape is instead (2, 1, 8), the "pipeline" axis is split, producing a DCN mesh shape of (2, 1, 1) and an ICI mesh shape of (1, 1, 8): 2-way pipeline parallelism across DCN and 8-way model parallelism within each node.
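The default split described above can be sketched as a small function. This is a hypothetical helper written from the description, not loadax's code. It assumes that the number of granules is known and that it evenly divides the first non-singleton axis.

```python
MeshShape = tuple[int, ...]


def infer_hybrid_shape(
    mesh_shape: MeshShape, num_granules: int
) -> tuple[MeshShape, MeshShape]:
    """Returns (ici_mesh_shape, dcn_mesh_shape). Hypothetical helper, not loadax API."""
    dcn = [1] * len(mesh_shape)
    ici = list(mesh_shape)
    # Find the first axis whose size is larger than 1.
    axis = next((i for i, size in enumerate(mesh_shape) if size > 1), None)
    if axis is None:
        # All singletons: single-granule environment, all-1s DCN shape.
        return tuple(ici), tuple(dcn)
    if mesh_shape[axis] % num_granules != 0:
        raise ValueError(
            f"axis {axis} of size {mesh_shape[axis]} cannot be split across "
            f"{num_granules} granules"
        )
    # Move the granule factor of that axis from ICI to DCN.
    dcn[axis] = num_granules
    ici[axis] = mesh_shape[axis] // num_granules
    return tuple(ici), tuple(dcn)


# Mesh axes (pipeline, data, model) on 2 H100 nodes.
print(infer_hybrid_shape((1, 2, 8), num_granules=2))  # ((1, 1, 8), (1, 2, 1))
print(infer_hybrid_shape((2, 1, 8), num_granules=2))  # ((1, 1, 8), (2, 1, 1))
```

The elementwise product of the two returned shapes always equals the input MeshShape. The split only decides which part of each axis crosses the network.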

If specified as a HybridMeshShape, each member must have the same length as mesh_axis_names.

Use mesh_rules to set different mesh shapes depending on the hardware platform.

mesh_axis_names instance-attribute

mesh_axis_names: Sequence[str]

The mesh axis names. The names can be referenced in ParameterSpec.mesh_axes.

batch_axis_names class-attribute instance-attribute

batch_axis_names: str | Sequence[str] = 'data'

Subset of mesh axis names over which leaves of the input batch are sharded.
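Assuming the batch is sharded along its leading dimension, that dimension is split into as many pieces as the product of the named axes' sizes. In JAX terms, this corresponds to a PartitionSpec whose first entry is the batch axis names. The hypothetical helper below, which is not part of loadax, computes that count so you can check that a global batch size divides evenly.

```python
from __future__ import annotations

from math import prod
from typing import Sequence


def num_batch_shards(
    mesh_shape: Sequence[int],
    mesh_axis_names: Sequence[str],
    batch_axis_names: str | Sequence[str] = "data",
) -> int:
    """How many ways the leading (batch) dimension is split. Hypothetical helper."""
    if len(mesh_shape) != len(mesh_axis_names):
        raise ValueError("mesh_shape and mesh_axis_names must have the same length")
    # A single string names one axis; normalize to a tuple of names.
    if isinstance(batch_axis_names, str):
        batch_axis_names = (batch_axis_names,)
    sizes = dict(zip(mesh_axis_names, mesh_shape))
    # Devices that differ only along non-batch axes receive the same batch slice.
    return prod(sizes[name] for name in batch_axis_names)


axes = ("pipeline", "data", "model")
print(num_batch_shards((1, 2, 8), axes))  # 2
print(num_batch_shards((1, 2, 8), axes, ("data", "model")))  # 16
```

With the default 'data', the batch is split 2 ways and each half is replicated across the 8 model-parallel devices.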

mesh_rules class-attribute instance-attribute

mesh_rules: Sequence[tuple[str, MeshShape | None]] | None = None

An optional list of (regex, MeshShape) pairs to override the default mesh configuration.

This is useful when we want to use different mesh shapes depending on the device types (e.g., 'tpu-v4-128' vs. 'gpu-p4de.24xlarge-32').

Given a mesh_selector string (usually representing the device type and set by the user's launch script), the first rule whose regex matches the selector determines the mesh shape.

If no rule matches, the default mesh configuration will be used.
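The first-match lookup can be sketched as follows. The helper name is hypothetical. It also assumes full-string matching with re.fullmatch, which the text does not specify. The rule shapes are illustrative values only.

```python
from __future__ import annotations

import re
from typing import Sequence

MeshShape = tuple[int, ...]


def select_mesh_shape(
    mesh_selector: str,
    mesh_rules: Sequence[tuple[str, MeshShape | None]] | None,
    default: MeshShape,
) -> MeshShape | None:
    """Returns the shape of the first matching rule. Hypothetical helper.

    Assumes full-string matching via re.fullmatch.
    """
    for pattern, shape in mesh_rules or ():
        if re.fullmatch(pattern, mesh_selector):
            # First match wins, even if later rules also match. The value is
            # returned as-is (the type allows None).
            return shape
    # No rule matched: fall back to the default mesh configuration.
    return default


# Illustrative shapes only.
rules = [
    (r"tpu-v4-128", (1, 64)),
    (r"gpu-p4de\.24xlarge-.*", (4, 8)),
]
print(select_mesh_shape("tpu-v4-128", rules, default=(1, 8)))  # (1, 64)
print(select_mesh_shape("gpu-p4de.24xlarge-32", rules, default=(1, 8)))  # (4, 8)
print(select_mesh_shape("cpu", rules, default=(1, 8)))  # (1, 8)
```

Because the first match wins, more specific patterns should be listed before broader ones.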

hosts property

hosts: int

Returns the number of hosts in the mesh.

This is just a simple wrapper around jax.process_count().

host_id property

host_id: int

Returns the ID of the current host.

This is just a simple wrapper around jax.process_index().
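In a data loader, a common use of these two values is to give each host its own block of the global batch. The sketch below is a hypothetical helper, not necessarily how loadax assigns data. It assumes a global batch that divides evenly by the host count and a simple contiguous split. In a real run, hosts and host_id would come from jax.process_count() and jax.process_index().

```python
def host_batch_slice(global_batch_size: int, hosts: int, host_id: int) -> slice:
    """Rows of the global batch that one host loads. Hypothetical helper.

    hosts / host_id would normally be jax.process_count() / jax.process_index().
    """
    if not 0 <= host_id < hosts:
        raise ValueError(f"host_id {host_id} out of range for {hosts} hosts")
    if global_batch_size % hosts != 0:
        raise ValueError("global batch size must be divisible by the host count")
    per_host = global_batch_size // hosts
    # Contiguous, non-overlapping blocks: host 0 gets the first block, and so on.
    return slice(host_id * per_host, (host_id + 1) * per_host)


print(host_batch_slice(64, hosts=2, host_id=1))  # slice(32, 64, None)
```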

create_device_mesh

create_device_mesh(devices: list[Device] | None = None) -> Mesh

Creates a Mesh object for the given devices.

Parameters:

- devices (list[Device] | None, default None): A list of devices to create a Mesh object for. If None, the default devices are used.

Returns:

- jax.sharding.Mesh: A Mesh object for the given devices.

Source code in src/loadax/sharding/mesh_shape.py
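def create_device_mesh(
    self, devices: list[jax.Device] | None = None
) -> jax.sharding.Mesh:
    """Creates a Mesh object for the given devices.

    Args:
        devices (list[jax.Device] | None, optional): A list of devices to create a
            Mesh object for. If None, the default devices will be used. Defaults to
            None.

    Returns:
        jax.sharding.Mesh: A Mesh object for the given devices.
    """
    import numpy as np

    from loadax.sharding.mesh_utils import create_device_mesh

    if devices is None:
        # Let loadax lay out the default devices for the configured mesh shape.
        device_array = create_device_mesh(mesh_shape=self.mesh_shape)
    else:
        # jax.sharding.Mesh needs one array dimension per axis name, so reshape
        # the flat device list to the global shape (ICI size * DCN size per axis).
        shape = self.mesh_shape
        if isinstance(shape, HybridMeshShape):
            shape = tuple(
                i * d for i, d in zip(shape.ici_mesh_shape, shape.dcn_mesh_shape)
            )
        device_array = np.asarray(devices).reshape(shape)
    return jax.sharding.Mesh(device_array, tuple(self.mesh_axis_names))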
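With a HybridMeshShape, each granule contributes one ICI-shaped block of devices, and the blocks are tiled according to the DCN shape. The result has ICI size times DCN size along every axis. The numpy sketch below illustrates that layout, assuming loadax follows this usual JAX hybrid arrangement. Integer IDs stand in for jax.Device objects. It is a hypothetical helper in the spirit of jax.experimental.mesh_utils.create_hybrid_device_mesh, not loadax's implementation, and it assumes the device list is already grouped by granule.

```python
import numpy as np


def arrange_hybrid_devices(
    device_ids: list[int],
    ici_mesh_shape: tuple[int, ...],
    dcn_mesh_shape: tuple[int, ...],
) -> np.ndarray:
    """Lays out devices block by block into one global mesh array.

    Hypothetical helper: assumes device_ids are ordered granule-major, i.e. all
    devices of granule 0 first, then all devices of granule 1, and so on.
    """
    ndim = len(ici_mesh_shape)
    if len(dcn_mesh_shape) != ndim:
        raise ValueError("ICI and DCN shapes must have the same length")
    # Shape (dcn axes..., ici axes...): one ICI-shaped block per granule.
    blocks = np.asarray(device_ids).reshape(
        tuple(dcn_mesh_shape) + tuple(ici_mesh_shape)
    )
    # Interleave axes as (d0, i0, d1, i1, ...) so that merging each pair gives
    # global index dcn_index * ici_size + ici_index along every mesh axis.
    order = [ax for k in range(ndim) for ax in (k, ndim + k)]
    global_shape = tuple(d * i for d, i in zip(dcn_mesh_shape, ici_mesh_shape))
    return blocks.transpose(order).reshape(global_shape)


# Mesh axes (data, model): 2 nodes x 8 GPUs, with the data axis split over DCN.
mesh = arrange_hybrid_devices(list(range(16)), ici_mesh_shape=(1, 8), dcn_mesh_shape=(2, 1))
print(mesh.shape)  # (2, 8)
print(mesh[1].tolist())  # [8, 9, 10, 11, 12, 13, 14, 15]
```

Each row of the data axis stays on a single node, so the model-parallel collectives run over the fast intra-node links. Only the data-parallel traffic crosses DCN.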