Placement Utilities

Loadax provides a few utilities to help with sharding and placement of data and models.

First we should cover why these utilities are necessary.

In a distributed setting, JAX needs to know how regions of data, called "shards", are placed across devices. A shard is a range within an array that fits on a single device. When training on multiple devices, and especially across multiple nodes, you want to keep synchronization simple and let JAX handle as much of it as possible. A good way to do this is to have each node load a unique subset of your dataset, so every node contributes a distinct local batch. In a single-node setup this is a no-op and you do not need to shard your dataset. With multiple nodes, however, you have to coordinate training so that each node works on a unique subset of the data while the nodes jointly perform backpropagation and gradient updates.
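As a rough sketch of that pattern (plain JAX, not a Loadax-specific API; the dataset and sizes below are placeholders), each process can select its own contiguous slice of the data:

import jax
import numpy as np

# Placeholder dataset: 10,000 examples of 32 features.
dataset = np.random.rand(10_000, 32).astype(np.float32)

# Give each node (process) a distinct contiguous slice so that every
# local batch is unique across the cluster. On a single node this is
# a no-op: process_count() is 1 and the slice is the whole dataset.
per_host = dataset.shape[0] // jax.process_count()
start = jax.process_index() * per_host
local_data = dataset[start : start + per_host]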

This is where Loadax's host_to_global_device_array comes in. This function takes each host's local arrays and assembles them, together with the other hosts' shards, into a single "global" array that is consistent across all processes. This is different from resharding the array because the data never actually moves between hosts: the function annotates the local arrays with their placement in the global batch, so JAX can treat each node's batch as one slice of a single larger global batch stitched together from the local batches.

host_to_global_device_array

This function annotates each host's local arrays with their global placement so that JAX can treat the per-node batches as a single larger global batch, with the local batches stitched together.

Converts the given host device arrays to global device arrays.

Must be called within the context of a Mesh.

We cannot use multihost_utils.host_local_array_to_global_array since the local mesh may not be contiguous. According to yashkatariya@google.com, "using jax.make_array_from_single_device_arrays is the right solution."

Parameters:

    host_arrays (Nested[Array], required):
        A nested tree of device arrays in host memory. Usually these represent
        the per-host portion of the global input batch.
    partition (DataPartitionType, default: FULL):
        How the global array should be partitioned.

Returns:

    Nested[Array]:
        A nested tree with the same structure as host_arrays, but with global
        device arrays at the leaves. Each global device array is partitioned
        according to partition.

Raises:

    NotImplementedError:
        If the given partition type is not supported.

Source code in src/loadax/sharding/placement.py
def host_to_global_device_array(
    host_arrays: Nested[jax.Array],
    *,
    partition: DataPartitionType = DataPartitionType.FULL,
) -> Nested[jax.Array]:
    """Converts the given host device arrays to global device arrays.

    Must be called within the context of a Mesh.

    We cannot use `multihost_utils.host_local_array_to_global_array` since the local
    mesh may not be contiguous. According to yashkatariya@google.com,
    "using `jax.make_array_from_single_device_arrays` is the right solution."

    Args:
        host_arrays: a nested tree of device arrays in host memory. Usually these
            represent the per-host portion of the global input batch.
        partition: how the global array should be partitioned.

    Returns:
        A nested tree with the same structure as `host_arrays`, but global device
        arrays at the leaves. Each global device array is partitioned
        according to `partition`.

    Raises:
        NotImplementedError: if the given `partition` type is not supported.
    """
    mesh = thread_resources.env.physical_mesh
    partition_spec = data_partition_type_to_spec(partition)

    local_devices = mesh.local_devices

    def put_to_devices_fully_partitioned(x: jax.Array) -> list[jax.Array]:
        len_local_devices = len(local_devices)
        if x.shape[0] % len_local_devices != 0:
            raise ValueError(
                f"({x.shape}) cannot be sharded across {len_local_devices} devices."
            )
        # np.reshape is faster than np.split, jnp.reshape, and jnp.split.
        xs = np.reshape(
            x, (len_local_devices, x.shape[0] // len_local_devices, *x.shape[1:])
        )
        return [
            jax.device_put(x_i, device)
            for x_i, device in zip(xs, local_devices, strict=False)
        ]

    def put_to_devices_replicated(x: jax.Array) -> list[jax.Array]:
        # Replicate `x` to every local device.
        return [jax.device_put(x, device) for device in local_devices]

    if partition == DataPartitionType.FULL:
        put_to_devices = put_to_devices_fully_partitioned
    elif partition == DataPartitionType.REPLICATED:
        put_to_devices = put_to_devices_replicated
    else:
        raise NotImplementedError(f"Unsupported partition: {partition}")

    device_arrays = jax.tree_util.tree_map(put_to_devices, host_arrays)
    partition_specs = complete_partition_spec_tree(
        jax.tree_util.tree_structure(host_arrays),
        partition_spec,
    )

    def make_gda(
        x: jax.Array, device_buffers: list[jax.Array], partition_spec: PartitionSpec
    ) -> jax.Array:
        if partition == DataPartitionType.FULL:
            global_batch_size = x.shape[0] * jax.process_count()
        elif partition == DataPartitionType.REPLICATED:
            global_batch_size = x.shape[0]
        else:
            raise NotImplementedError(f"Unsupported partition: {partition}")
        global_shape = (global_batch_size, *list(x.shape[1:]))
        return jax.make_array_from_single_device_arrays(
            shape=global_shape,
            sharding=jax.sharding.NamedSharding(mesh, partition_spec),
            arrays=device_buffers,
        )

    return jax.tree_util.tree_map(make_gda, host_arrays, device_arrays, partition_specs)  # type: ignore

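A minimal usage sketch for host_to_global_device_array (the import path is assumed to mirror the source layout above; the batch contents, sizes, and mesh layout are placeholders):

import jax
import numpy as np
from jax.sharding import Mesh

# Assumed import path, mirroring src/loadax/sharding/placement.py.
from loadax.sharding.placement import host_to_global_device_array

# Hypothetical per-host batch: 8 examples of 128 features loaded on this host.
# With the default FULL partition, the local batch size must be divisible by
# the number of local (per-host) devices.
local_batch = {"inputs": np.ones((8, 128), dtype=np.float32)}

# A simple 1D mesh over all devices; the axis name "data" is a placeholder.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

with mesh:
    # Each host passes only its local batch. The result is a tree of global
    # arrays whose leading dimension is 8 * jax.process_count(), partitioned
    # along the batch axis (partition defaults to DataPartitionType.FULL).
    global_batch = host_to_global_device_array(local_batch)
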
global_to_host_array

The inverse of host_to_global_device_array. This function takes a global array and extracts the portion that is addressable by the current node.

Extracts host addressable rows from each Array in global_arrays.

Parameters:

    global_arrays (Nested[Array], required):
        A Nested[jax.Array]. Each leaf Array must have shape
        [global_batch_size, ...] with identical global_batch_size across
        arrays. The arrays must be partitioned in the same way and can be
        partitioned only along the batch axis.
    partition (DataPartitionType, default: FULL):
        How the global array should be partitioned.

Returns:

    Nested[Array]:
        A Nested[jax.Array] with the same structure as global_arrays. Each leaf
        Array will have shape [host_batch_size, ...] where host_batch_size will
        be equal to global_batch_size if the global Arrays are replicated, or
        global_batch_size // process_count if the global Arrays are partitioned
        across hosts.

Source code in src/loadax/sharding/placement.py
def global_to_host_array(
    global_arrays: Nested[jax.Array],
    *,
    partition: DataPartitionType = DataPartitionType.FULL,
) -> Nested[jax.Array]:
    """Extracts host addressable rows from each Array in `global_arrays`.

    Args:
        global_arrays: A Nested[jax.Array].
            Each leaf Array must have shape [global_batch_size, ...] with identical
            global_batch_size across arrays.
            The arrays must be partitioned in the same way and can be partitioned
            only along the batch axis.
        partition: How the global array should be partitioned.

    Returns:
        A Nested[jax.Array] with the same structure as `global_arrays`. Each leaf
        Array will have shape [host_batch_size, ...] where `host_batch_size` will be
        equal to `global_batch_size` if the global Arrays are replicated or
        `global_batch_size // process_count` if the global Arrays are partitioned
        across hosts.
    """

    def sort_global_shards(global_shards: list[jax.Shard]) -> list[jax.Shard]:
        # We should sort jax.Array.global_shards by using this function to guarantee
        # round-trip equality of host_to_global_device_array and global_to_host_array.
        # Shards are sorted in-place.
        global_shards.sort(key=lambda shard: shard.index)
        return global_shards

    global_array_items = flatten_items(global_arrays)
    if not global_array_items:
        return global_arrays  # no leaf jax.Array.
    first_path, first_value = global_array_items[0]
    sorted_first_value_shards = sort_global_shards(list(first_value.global_shards))
    first_value_shard_is_local = [
        shard.data is not None for shard in sorted_first_value_shards
    ]
    batch_size = first_value.shape[0]

    def get_local_array(path: str, value: jax.Array) -> jax.Array:
        if value.shape[0] != batch_size:
            raise ValueError(
                f"Value batch size mismatch: {batch_size} @ {first_path} vs. "
                f"{value.shape[0]} @ {path} of {shapes(global_arrays)}"
            )
        sorted_value_shards = sort_global_shards(list(value.global_shards))
        value_shard_is_local = [shard.data is not None for shard in sorted_value_shards]
        if value_shard_is_local != first_value_shard_is_local:
            raise ValueError(
                f"Value shard mismatch: {first_value_shard_is_local} @ {first_path} "
                f"vs. {value_shard_is_local} @ {path}"
            )
        local_data = [
            shard.data for shard in sorted_value_shards if shard.data is not None
        ]
        if not local_data:
            raise ValueError(f"No local shard found: {sorted_value_shards}.")
        if partition == DataPartitionType.FULL:
            # Return an ndarray; it's faster than jnp.concatenate.
            return np.concatenate(local_data, axis=0)  # type: ignore
        elif partition == DataPartitionType.REPLICATED:
            return local_data[0]  # type: ignore
        else:
            raise NotImplementedError(f"Unsupported partition: {partition}")

    # TODO: jtu types are bad
    return jax.tree_util.tree_map(  # type: ignore
        get_local_array, tree_paths(global_arrays), global_arrays
    )

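A hedged round-trip sketch, reusing the mesh and local_batch from the earlier example (the import path is again assumed from the source layout):

from loadax.sharding.placement import (  # assumed import path
    global_to_host_array,
    host_to_global_device_array,
)

with mesh:
    global_batch = host_to_global_device_array(local_batch)
    # Pull back only the rows addressable by this host. With the default FULL
    # partition this recovers the same local rows that were passed in, which is
    # the round-trip equality the shard sorting above is designed to preserve.
    local_again = global_to_host_array(global_batch)
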
with_sharding_constraint

This is syntactic sugar around jax.lax.with_sharding_constraint that only applies the constraint when running inside a Mesh context; with an empty or single-device Mesh it is a no-op.

Syntax sugar for jax.lax.with_sharding_constraint.

Used from within the context of a Mesh, this will produce a no-op if the Mesh is empty or has only one device.

Source code in src/loadax/sharding/placement.py
def with_sharding_constraint(x: jax.Array, shardings: Any) -> jax.Array:
    """Syntax sugar for `jax.lax.with_sharding_constraint`.

    Used from within the context of a Mesh, this will produce a no-op if the Mesh
    is empty or has only one device.
    """
    mesh = thread_resources.env.physical_mesh
    if mesh.empty or mesh.size == 1:
        return x
    # TODO: jax types are bad
    return jax.lax.with_sharding_constraint(x, shardings)  # type: ignore
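
A short usage sketch for with_sharding_constraint (the mesh, axis name, and jitted function are placeholders, and the import path is assumed from the source layout):

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Assumed import path, mirroring src/loadax/sharding/placement.py.
from loadax.sharding.placement import with_sharding_constraint

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

@jax.jit
def step(batch):
    # Hint to the compiler that the batch should stay sharded along the "data"
    # axis. Under an empty or single-device mesh the wrapper is a no-op. The
    # batch dimension must be divisible by the number of devices on that axis.
    batch = with_sharding_constraint(
        batch, NamedSharding(mesh, PartitionSpec("data"))
    )
    return batch * 2

with mesh:
    out = step(np.ones((8, 128), dtype=np.float32))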