Mesh

Loadax has a specialized mesh definition called a HybridMeshShape. This enables you to define your mesh with two contexts: the global topology (inter-node) and the local topology (intra-node). This makes it easier to reason about the placement of data across the network.
To use Loadax's mesh abstractions, you need to define a MeshConfig. This config tells Loadax how to split up the work across the mesh. A MeshConfig must specify a mesh_shape, which is a HybridMeshShape, along with some annotations about the mesh axes.

Typically you should not need to define a MeshConfig yourself. Instead, you can rely on Loadax's automatic mesh discovery to find a good mesh shape for your cluster for a given parallelization strategy. See the Presets section for more details.
HybridMeshShape

A mesh shape for hybrid (i.e., ICI and DCN) parallelism.

For example, with mesh axes (data, model):

- Pure FSDP on a v4-8: HybridMeshShape(ici_mesh_shape=(1, 4), dcn_mesh_shape=(1, 1))
- 2-way data parallelism across two nodes, with 8-way sharding within each node: HybridMeshShape(ici_mesh_shape=(1, 8), dcn_mesh_shape=(2, 1))
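As a quick sanity check on shapes like these, the total device count is the product of the two shapes: the ICI shape gives the devices per node, and the DCN shape gives the node count. A minimal sketch in plain Python (a hypothetical helper, not part of Loadax):

```python
from math import prod

def total_devices(ici_mesh_shape, dcn_mesh_shape):
    """Total number of devices a hybrid mesh spans:
    devices per node (ICI) times number of nodes (DCN)."""
    return prod(ici_mesh_shape) * prod(dcn_mesh_shape)

# Pure FSDP on a v4-8: 4 chips on a single slice.
assert total_devices((1, 4), (1, 1)) == 4
# 2 nodes of 8 devices each: 16 devices total.
assert total_devices((1, 8), (2, 1)) == 16
```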
ici_mesh_shape instance-attribute

ici_mesh_shape: MeshShape

The mesh shape for the ICI (inter-chip interconnect) parallelism. This represents the host-local (intra-node) sharding.
dcn_mesh_shape instance-attribute

dcn_mesh_shape: MeshShape

The mesh shape for the DCN (data center network) parallelism. This represents the inter-host sharding.
MeshConfig

Sharding mesh configuration.
mesh_shape instance-attribute

mesh_shape: MeshShape | HybridMeshShape
If specified as a MeshShape, must have the same length as mesh_axis_names. Implicitly, this treats the mesh shape as the ICI mesh shape; we default to a DCN mesh shape that partitions the first non-singleton axis across granules (e.g. TPU slices or GPU nodes). If all axes are singletons, this implies a single-granule environment and therefore an all-1's DCN mesh shape.
As an example on 2 H100 nodes, for mesh axes (pipeline, data, model) and a MeshShape of (1, 2, 8), we break the "data" axis across DCN -- this produces a DCN mesh shape (1, 2, 1) and an ICI mesh shape (1, 1, 8), i.e. 2-way data-parallelism across DCN, and 8-way model parallelism within-node (e.g. NVLink). If instead the MeshShape is provided as (2, 1, 8), we break along the "pipeline" axis, producing a DCN mesh shape of (2, 1, 1) and ICI mesh shape (1, 1, 8) for 2-way pipeline-parallelism across DCN and 8-way model parallelism within-node.
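The default axis-splitting rule described above can be sketched in plain Python. This is a hypothetical helper for illustration, not Loadax's actual implementation: given an ICI-style MeshShape and a number of granules (TPU slices or GPU nodes), it partitions the first non-singleton axis across DCN.

```python
def infer_dcn_ici(mesh_shape, num_granules):
    """Sketch of the default MeshShape -> (dcn, ici) decomposition:
    break the first non-singleton axis across granules; if all axes
    are singletons (or there is one granule), the DCN shape is all ones."""
    dcn = [1] * len(mesh_shape)
    ici = list(mesh_shape)
    if num_granules > 1:
        for i, dim in enumerate(mesh_shape):
            if dim > 1:
                # The chosen axis must divide evenly across granules.
                assert dim % num_granules == 0, "axis not divisible by granules"
                dcn[i] = num_granules
                ici[i] = dim // num_granules
                break
    return tuple(dcn), tuple(ici)

# 2 H100 nodes, mesh axes (pipeline, data, model):
assert infer_dcn_ici((1, 2, 8), 2) == ((1, 2, 1), (1, 1, 8))  # split "data"
assert infer_dcn_ici((2, 1, 8), 2) == ((2, 1, 1), (1, 1, 8))  # split "pipeline"
```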
If specified as a HybridMeshShape, each member must have the same length as mesh_axis_names.
Use mesh_rules to set different mesh shapes depending on the hardware platform.
mesh_axis_names instance-attribute

The mesh axis names. The names can be referenced in ParameterSpec.mesh_axes.
batch_axis_names class-attribute instance-attribute

Subset of mesh axis names over which leaves of the input batch are sharded.
mesh_rules class-attribute instance-attribute

An optional list of (regex, MeshShape) pairs to override the default mesh configuration.

This is useful when we want to use different mesh shapes depending on the device types (e.g., 'tpu-v4-128' vs. 'gpu-p4de.24xlarge-32').

Given a mesh_selector string (usually representing the device type and set by the user's launch script), the first rule whose regex matches the selector determines the mesh shape. If no rule matches, the default mesh configuration is used.
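The first-match-wins lookup can be sketched in a few lines of plain Python. The rule list and selector strings below are hypothetical examples, not Loadax's actual rule format:

```python
import re

# Hypothetical (regex, mesh_shape) rules, mirroring mesh_rules above.
MESH_RULES = [
    (r"tpu-v4-.*", (1, 4)),
    (r"gpu-p4de\..*", (2, 8)),
]
DEFAULT_MESH_SHAPE = (1, 1)

def select_mesh_shape(mesh_selector: str):
    """Return the mesh shape of the first rule whose regex matches the
    selector; fall back to the default when no rule matches."""
    for pattern, shape in MESH_RULES:
        if re.match(pattern, mesh_selector):
            return shape
    return DEFAULT_MESH_SHAPE

assert select_mesh_shape("tpu-v4-128") == (1, 4)
assert select_mesh_shape("unknown-device") == (1, 1)
```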
hosts property

hosts: int

Returns the number of hosts in the mesh. This is a simple wrapper around jax.process_count().
host_id property

host_id: int

Returns the ID of the current host. This is a simple wrapper around jax.process_index().
create_device_mesh

create_device_mesh(devices: list[Device] | None = None) -> Mesh

Creates a Mesh object for the given devices.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| devices | list[Device] \| None | A list of devices to create a Mesh object for. If None, the default devices will be used. | None |

Returns:

| Type | Description |
|---|---|
| Mesh | jax.sharding.Mesh: A Mesh object for the given devices. |
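To illustrate what building a mesh from a device list involves, here is a plain-Python sketch. It is a hypothetical helper, not Loadax's implementation (which returns a jax.sharding.Mesh): it lays a flat device list out over the mesh coordinates in row-major order, the same ordering a reshaped device array would use.

```python
from itertools import product

def arrange_devices(device_ids, mesh_shape):
    """Map each mesh coordinate to a device id in row-major order
    (illustrative stand-in for arranging devices into a Mesh)."""
    total = 1
    for dim in mesh_shape:
        total *= dim
    assert len(device_ids) == total, "device count must match mesh size"
    coords = product(*(range(dim) for dim in mesh_shape))
    return dict(zip(coords, device_ids))

# Four devices on a 2x2 (data, model) mesh:
layout = arrange_devices([0, 1, 2, 3], (2, 2))
assert layout[(0, 1)] == 1
assert layout[(1, 0)] == 2
```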
Source code in src/loadax/sharding/mesh_shape.py