# Placement Utilities
Loadax provides a few utilities to help with sharding and placement of data and models.
First we should cover why these utilities are necessary.
In a distributed setting, jax needs to know the placement of regions of data called "shards": ranges within an array that fit on a single device. When training on multiple devices, and especially across multiple nodes, you want to simplify synchronization and let jax handle as much of it as possible. A good way to do this is to have each node load a unique subset of your dataset, so that every node's batch is unique. In a single-node setup this is a no-op and you do not need to shard your dataset. With multiple nodes, however, you have the challenge of coordinating training so that each node trains on a unique subset of the data while the nodes still synchronize backpropagation and gradient updates.
This is where Loadax's `host_to_global_device_array` comes in. This function takes each host's local array and coordinates with the other nodes to form a "global" array whose logical view is the same across all devices. This is different from sharding the array, because the data never actually moves: the function annotates the array with the placement of the data so that jax can treat each node's batch as a single larger global batch, with the local batches stitched together.
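Below is a minimal sketch of that pattern. The import path is assumed from the source file referenced later on this page, and `load_local_batch` is a hypothetical stand-in for whatever loader produces this host's unique slice of the dataset.

```python
import jax
import numpy as np
from jax.sharding import Mesh

from loadax.sharding.placement import host_to_global_device_array  # assumed import path


def load_local_batch(step: int) -> np.ndarray:
    """Hypothetical loader: each process returns a different [per_host_batch, ...] slice."""
    rng = np.random.default_rng(jax.process_index() * 10_000 + step)
    return rng.standard_normal((8, 128), dtype=np.float32)


# The placement utilities must be called within the context of a Mesh.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

with mesh:
    local_batch = load_local_batch(step=0)
    # Annotate the local batch so jax sees one logical global batch; no data is moved.
    global_batch = host_to_global_device_array(local_batch)
```

Every process runs the same code; the per-host batches are stitched together along the leading (batch) dimension.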
## host_to_global_device_array
This function takes an array and annotates it with the placement of the data so that jax can treat each node's batch as a single larger global batch with the local batches stitched together.
Converts the given host device arrays to global device arrays. Must be called within the context of a Mesh.

We cannot use `multihost_utils.host_local_array_to_global_array` since the local mesh may not be contiguous. According to yashkatariya@google.com, "using `jax.make_array_from_single_device_arrays` is the right solution."
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `host_arrays` | `Nested[Array]` | A nested tree of device arrays in host memory. Usually these represent the per-host portion of the global input batch. | required |
| `partition` | `DataPartitionType` | How the global array should be partitioned. | `FULL` |
Returns:

| Type | Description |
|---|---|
| `Nested[Array]` | A nested tree with the same structure as `host_arrays`, but with global device arrays at the leaves. Each global device array is partitioned according to `partition`. |
Raises:

| Type | Description |
|---|---|
| `NotImplementedError` | If the given `partition` type is not supported. |
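As a sketch of the nested-tree behaviour described above (same assumed import path), leaves can be grouped in any pytree and the tree structure is preserved in the result:

```python
import jax
import numpy as np
from jax.sharding import Mesh

from loadax.sharding.placement import host_to_global_device_array  # assumed import path

per_host = 4
host_batch = {
    "tokens": np.zeros((per_host, 512), dtype=np.int32),
    "labels": np.zeros((per_host,), dtype=np.int32),
}

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
with mesh:
    # partition defaults to FULL: each host contributes a distinct slice of the batch.
    global_batch = host_to_global_device_array(host_batch)

# Under FULL partitioning the leading (batch) dimension of every leaf grows to the
# per-host size times the number of participating processes.
assert global_batch["tokens"].shape[0] == per_host * jax.process_count()
assert global_batch["labels"].shape[0] == per_host * jax.process_count()
```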
Source code in src/loadax/sharding/placement.py
## global_to_host_array

The inverse of `host_to_global_device_array`. This function takes a global array and splits it back into the local arrays for each node.

Extracts the host-addressable rows from each Array in `global_arrays`.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `global_arrays` | `Nested[Array]` | A `Nested[jax.Array]`. Each leaf Array must have shape `[global_batch_size, ...]` with identical `global_batch_size` across arrays. The arrays must be partitioned in the same way and can be partitioned only along the batch axis. | required |
| `partition` | `DataPartitionType` | How the global array should be partitioned. | `FULL` |
Returns:

| Type | Description |
|---|---|
| `Nested[Array]` | A `Nested[jax.Array]` with the same structure as `global_arrays`. Each leaf Array will have shape `[host_batch_size, ...]`, where `host_batch_size` is equal to `global_batch_size` if the arrays are replicated, or to this host's share of `global_batch_size` if the arrays are partitioned across hosts. |
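A minimal round-trip sketch, again assuming both helpers are importable from `loadax.sharding.placement`:

```python
import jax
import numpy as np
from jax.sharding import Mesh

from loadax.sharding.placement import (  # assumed import path
    global_to_host_array,
    host_to_global_device_array,
)

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
with mesh:
    local = np.arange(8 * 16, dtype=np.float32).reshape(8, 16)
    global_batch = host_to_global_device_array(local)

    # Pull back only the rows addressable by this host, e.g. to log or inspect
    # this process's slice of a globally partitioned batch.
    host_rows = global_to_host_array(global_batch)

# For a fully partitioned batch each host recovers its own [host_batch_size, ...] slice.
assert host_rows.shape[0] == 8
```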
Source code in src/loadax/sharding/placement.py
## with_sharding_constraint

Syntactic sugar for `jax.lax.with_sharding_constraint` that ensures the constraint is only applied from within a Mesh context.

Used from within the context of a Mesh, this is a no-op if the Mesh is empty or has only one device.
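A sketch of how this might be used, assuming the wrapper accepts the same arguments as `jax.lax.with_sharding_constraint` (here a `PartitionSpec`) and is importable from `loadax.sharding.placement`:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec

from loadax.sharding.placement import with_sharding_constraint  # assumed import path

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))


@jax.jit
def train_step(batch):
    # Inside a multi-device Mesh context this pins `batch` to be sharded along the
    # "data" axis; with an empty or single-device Mesh it is simply a no-op.
    batch = with_sharding_constraint(batch, PartitionSpec("data"))
    return batch * 2.0


with mesh:
    out = train_step(jnp.ones((8, 128)))
```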
Source code in src/loadax/sharding/placement.py