ShardedDataset
A dataset that implements the Shardable protocol. This allows you to partition your dataset across multiple hosts.
```python
from loadax import SimpleDataset, ShardedDataset

dataset = SimpleDataset([1, 2, 3, 4, 5])
sharded_dataset = ShardedDataset(dataset, num_shards=2, shard_id=0)
```
The Shardable protocol requires you to implement the split_dataset_by_node method. This method takes the world_size and the rank of the current host and returns the shard of the dataset for that host. Loadax's provided datasets implement this method for you.
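As a rough illustration of what the protocol asks for, the sketch below implements `split_dataset_by_node` on a small in-memory dataset. `ListDataset` and its even-split rule are hypothetical and written only for this example; they are not Loadax classes, and Loadax's own datasets may shard differently.

```python
from typing import Generic, Sequence, TypeVar

Example = TypeVar("Example")


class ListDataset(Generic[Example]):
    """Hypothetical in-memory dataset used only for illustration (not part of Loadax)."""

    def __init__(self, items: Sequence[Example]) -> None:
        self._items = list(items)

    def __len__(self) -> int:
        return len(self._items)

    def __getitem__(self, index: int) -> Example:
        return self._items[index]

    def split_dataset_by_node(self, world_size: int, rank: int) -> "ListDataset[Example]":
        """Return the contiguous slice of items assigned to host `rank`."""
        if world_size <= 0:
            raise ValueError("world_size must be positive")
        if not 0 <= rank < world_size:
            raise ValueError("rank must satisfy 0 <= rank < world_size")
        # Assumption: every host gets the same number of items and any
        # leftover items at the end are dropped.
        per_node = len(self._items) // world_size
        start = rank * per_node
        return ListDataset(self._items[start:start + per_node])


dataset = ListDataset([1, 2, 3, 4, 5])
shard_0 = dataset.split_dataset_by_node(world_size=2, rank=0)
shard_1 = dataset.split_dataset_by_node(world_size=2, rank=1)
```

With five items and two hosts, this sketch gives each host two items and leaves the fifth unassigned, so all hosts do the same amount of work per epoch.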
Bases: Dataset[Example], Generic[Example]
Divides the dataset into non-overlapping contiguous shards.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset | Dataset[E] | The underlying dataset to shard. | required |
num_shards | int | Total number of shards. | required |
shard_id | int | The ID of the current shard (0-based). | required |
drop_remainder | bool | Whether to drop the last incomplete shard. Defaults to True. | True |
Raises:
Type | Description |
---|---|
TypeError | If |
ValueError | If |
ValueError | If |
ValueError | If |
Source code in src/loadax/dataset/sharded_dataset.py
shard_boundaries
Return the start and end boundaries of the shard.
Returns:
Type | Description |
---|---|
tuple[int, int] | The (start, end) indices of the shard. |
Source code in src/loadax/dataset/sharded_dataset.py
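The following sketch shows one way contiguous (start, end) boundaries could be derived from num_shards, shard_id, and drop_remainder. The helper `contiguous_shard_bounds` and its rounding rules are assumptions made for illustration, not Loadax's actual implementation. In particular, treating drop_remainder=False as "round the shard size up and let the trailing shard run short" is only one possible interpretation.

```python
def contiguous_shard_bounds(
    dataset_size: int,
    num_shards: int,
    shard_id: int,
    drop_remainder: bool = True,
) -> tuple[int, int]:
    """Hypothetical helper: return half-open (start, end) indices for one shard."""
    if num_shards <= 0:
        raise ValueError("num_shards must be positive")
    if not 0 <= shard_id < num_shards:
        raise ValueError("shard_id must satisfy 0 <= shard_id < num_shards")

    if drop_remainder:
        # Every shard has the same size; leftover trailing items are dropped.
        shard_size = dataset_size // num_shards
    else:
        # Ceiling division: every item is kept, later shards may be shorter.
        shard_size = -(-dataset_size // num_shards)

    # Clamp to the dataset size so a trailing shard never reads past the end.
    start = min(shard_id * shard_size, dataset_size)
    end = min(start + shard_size, dataset_size)
    return start, end


bounds_dropped = [contiguous_shard_bounds(5, 2, i) for i in range(2)]
bounds_kept = [contiguous_shard_bounds(5, 2, i, drop_remainder=False) for i in range(2)]
```

For a five-item dataset split two ways, the first call yields equal shards of two items each and the second keeps all five items at the cost of uneven shard sizes.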
Bases: Protocol, Generic[Example]
A shardable dataset must implement the Shardable protocol.
Each dataset has to implement sharding by itself because the underlying storage may have unique constraints to consider when creating the sharding boundaries.
split_dataset_by_node (abstractmethod)
Split the dataset into shards.
If possible, the shards should be equal in size, non-overlapping, and contiguous.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
world_size | int | The number of nodes. | required |
rank | int | The rank of the current node. | required |
Returns:
Type | Description |
---|---|
Dataset[Example] | The shard of the dataset for the current node. |
Source code in src/loadax/dataset/sharded_dataset.py
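When writing a custom split_dataset_by_node, it can help to check the three properties named above (equal size, non-overlapping, contiguous) across all ranks. The checker below is a minimal sketch with a hypothetical name, `shards_are_well_formed`. It assumes the items are unique and that shards are listed in rank order starting from the beginning of the dataset.

```python
from typing import Sequence


def shards_are_well_formed(full: Sequence[int], shards: Sequence[Sequence[int]]) -> bool:
    """Hypothetical check: equal-sized, non-overlapping, contiguous shards.

    Assumes `full` holds unique items and `shards` is ordered by rank.
    """
    # Equal size: all shards must have the same length.
    if len({len(shard) for shard in shards}) > 1:
        return False
    # Contiguous and non-overlapping: concatenating the shards in rank order
    # must reproduce a prefix of the full dataset exactly.
    flattened = [item for shard in shards for item in shard]
    return flattened == list(full[: len(flattened)])


indices = list(range(5))
good = shards_are_well_formed(indices, [indices[0:2], indices[2:4]])
overlapping = shards_are_well_formed(indices, [indices[0:3], indices[2:5]])
uneven = shards_are_well_formed(indices, [indices[0:3], indices[3:5]])
```

Here `good` is True, while `overlapping` and `uneven` are False because index 2 appears twice in one case and the shard lengths differ in the other.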