
ShardedDataset

A dataset that shards an underlying dataset implementing the Shardable protocol, allowing you to partition your data across multiple hosts.

Creating a sharded dataset
from loadax import SimpleDataset, ShardedDataset

dataset = SimpleDataset([1, 2, 3, 4, 5])
sharded_dataset = ShardedDataset(dataset, num_shards=2, shard_id=0)

The Shardable protocol requires you to implement the split_dataset_by_node method, which takes the world_size and rank of the current host and returns that host's shard of the dataset.

Loadax's provided datasets will implement this method for you.
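If you bring your own dataset type, the protocol boils down to returning a contiguous, non-overlapping slice per node. A minimal standalone sketch (ListDataset is illustrative, not part of loadax):

```python
# Standalone sketch of the Shardable protocol: a list-backed dataset
# that hands each node a contiguous, non-overlapping slice.
class ListDataset:
    def __init__(self, items):
        self.items = list(items)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        return self.items[index]

    def split_dataset_by_node(self, world_size: int, rank: int) -> "ListDataset":
        # Equal-size contiguous slices; any remainder is dropped so
        # every node sees the same number of examples.
        shard_size = len(self.items) // world_size
        start = rank * shard_size
        return ListDataset(self.items[start : start + shard_size])

dataset = ListDataset(range(10))
shard = dataset.split_dataset_by_node(world_size=4, rank=1)
print(shard.items)  # [2, 3] -- the second contiguous shard
```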

Bases: Dataset[Example], Generic[Example]

Divides the dataset into non-overlapping contiguous shards.

Parameters:

    dataset (Dataset[Example]): The underlying dataset to shard. Required.
    num_shards (int): Total number of shards. Required.
    shard_id (int): The ID of the current shard (0-based). Required.
    drop_remainder (bool): Whether to drop the last incomplete shard. Defaults to True.
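The effect of drop_remainder on shard sizes can be seen with a small boundary computation. This is a hedged reimplementation of the arithmetic; the real logic lives in loadax's compute_shard_boundaries and may differ in detail:

```python
# Illustrative contiguous shard boundaries (not the loadax implementation).
def boundaries(num_shards, shard_id, dataset_size, drop_remainder=True):
    if drop_remainder:
        # Every shard gets exactly dataset_size // num_shards examples;
        # trailing examples that don't fill a full shard are dropped.
        shard_size = dataset_size // num_shards
        return shard_id * shard_size, (shard_id + 1) * shard_size
    # Without dropping, the first (dataset_size % num_shards) shards
    # take one extra example so the whole dataset is covered.
    small = dataset_size // num_shards
    extra = dataset_size % num_shards
    start = shard_id * small + min(shard_id, extra)
    end = start + small + (1 if shard_id < extra else 0)
    return start, end

print(boundaries(2, 0, 5))                        # (0, 2): element 4 is dropped
print(boundaries(2, 1, 5, drop_remainder=False))  # (3, 5): all 5 elements covered
```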

Raises:

    TypeError: If dataset does not implement the Shardable protocol.
    ValueError: If num_shards is not a positive integer.
    ValueError: If shard_id is not in the range [0, num_shards).
    ValueError: If drop_remainder is True and dataset_size < num_shards.

Source code in src/loadax/dataset/sharded_dataset.py
def __init__(
    self,
    dataset: Dataset[Example],
    num_shards: int,
    shard_id: int,
    *,
    drop_remainder: bool = True,
):
    """Initialize a ShardedDataset to shard the given dataset.

    Args:
        dataset (Dataset[Example]): The underlying dataset to shard.
        num_shards (int): Total number of shards.
        shard_id (int): The ID of the current shard (0-based).
        drop_remainder (bool, optional): Whether to drop the last incomplete shard.
            Defaults to True.

    Raises:
        TypeError: If `dataset` does not implement the Shardable protocol.
        ValueError: If `num_shards` is not a positive integer.
        ValueError: If `shard_id` is not in the range [0, num_shards).
        ValueError: If `drop_remainder` is True and `dataset_size` < `num_shards`.
    """
    if not isinstance(dataset, Shardable):
        raise TypeError("dataset must implement the Shardable protocol.")
    if not isinstance(num_shards, int) or num_shards <= 0:
        raise ValueError("num_shards must be a positive integer.")
    if not isinstance(shard_id, int) or not (0 <= shard_id < num_shards):
        raise ValueError(f"shard_id must be an integer in [0, {num_shards}).")

    self.dataset = dataset
    self.num_shards = num_shards
    self.shard_id = shard_id
    self.drop_remainder = drop_remainder
    self.dataset_size = len(self.dataset)

    if self.drop_remainder and self.dataset_size < self.num_shards:
        raise ValueError(
            f"dataset_size ({self.dataset_size}) must be >= num_shards "
            f"({self.num_shards}) when drop_remainder is True."
        )

    self.start, self.end = compute_shard_boundaries(
        num_shards=self.num_shards,
        shard_id=self.shard_id,
        dataset_size=self.dataset_size,
        drop_remainder=self.drop_remainder,
    )

    self._length = max(0, self.end - self.start)

shard_boundaries

shard_boundaries() -> tuple[int, int]

Return the start and end boundaries of the shard.

Returns:

    tuple[int, int]: The (start, end) indices of the shard.

Source code in src/loadax/dataset/sharded_dataset.py
def shard_boundaries(self) -> tuple[int, int]:
    """Return the start and end boundaries of the shard.

    Returns:
        Tuple[int, int]: The (start, end) indices of the shard.
    """
    return self.start, self.end
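The (start, end) pair is all a sharded view needs in order to serve items: indexing into the shard offsets by start into the underlying data. A standalone sketch of that pattern (ShardView is illustrative, not loadax API):

```python
# Illustrative sharded view: a contiguous window [start, end) over data.
class ShardView:
    def __init__(self, data, start, end):
        self.data, self.start, self.end = data, start, end

    def __len__(self):
        return max(0, self.end - self.start)

    def __getitem__(self, index):
        # Shard-local index 0 maps to global index `start`.
        if not 0 <= index < len(self):
            raise IndexError(index)
        return self.data[self.start + index]

    def shard_boundaries(self):
        return self.start, self.end

view = ShardView(list(range(10)), start=5, end=10)
print(view.shard_boundaries(), view[0], len(view))  # (5, 10) 5 5
```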

Bases: Protocol, Generic[Example]

A shardable dataset must implement the Shardable protocol.

Each dataset has to implement sharding by itself because the underlying storage may have unique constraints to consider when creating the sharding boundaries.

split_dataset_by_node abstractmethod

split_dataset_by_node(world_size: int, rank: int) -> Dataset[Example]

Split the dataset into shards.

If possible, the shards should be equal-sized, non-overlapping, and contiguous.

Parameters:

    world_size (int): The number of nodes. Required.
    rank (int): The rank of the current node. Required.

Returns:

    Dataset[Example]: The shard of the dataset for the current node.

Source code in src/loadax/dataset/sharded_dataset.py
@abstractmethod
def split_dataset_by_node(self, world_size: int, rank: int) -> Dataset[Example]:
    """Split the dataset into shards.

    If possible, the shards should be equal-sized, non-overlapping,
    and contiguous.

    Args:
        world_size (int): The number of nodes.
        rank (int): The rank of the current node.

    Returns:
        Dataset[Example]: The shard of the dataset for the current node.
    """
    pass