
Simple

Simple is a dataset that stores all of its items in a Python list held in memory. Because the whole dataset must fit in memory, it is only useful for small datasets and for debugging.

Creating an in-memory dataset
from loadax import Simple

dataset = Simple([1, 2, 3, 4, 5])

for i in range(len(dataset)):
    print(dataset.get(i))

#> 1
#> 2
#> 3
#> 4
#> 5

Bases: Shardable[Example], Shuffleable[Example], Dataset[Example], Generic[Example]

A dataset that wraps a list of examples.

Parameters:

data (List[Example]): The list of data examples. Required.

Source code in src/loadax/dataset/simple.py
def __init__(self, data: list[Example]):
    """Initialize a simple dataset in-memory from a list.

    Args:
        data (List[Example]): The list of data examples.
    """
    self.data = data
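The constructor above only stores the list it is given; the excerpt does not show how reads are served. The following is a minimal sketch of a list-backed dataset, assuming random access works like the dataset.get(i) calls in the earlier example. ListBackedDataset is a hypothetical name used for illustration, not part of loadax.

```python
from typing import Generic, TypeVar

Example = TypeVar("Example")


class ListBackedDataset(Generic[Example]):
    """Hypothetical stand-in for Simple: every example lives in one Python list."""

    def __init__(self, data: list[Example]):
        # Keep a reference to the caller's list, mirroring `self.data = data` above.
        # Nothing is copied, so the whole dataset must already fit in memory.
        self.data = data

    def __len__(self) -> int:
        return len(self.data)

    def get(self, index: int) -> Example:
        # Random access is a plain list lookup.
        return self.data[index]


dataset = ListBackedDataset(["a", "b", "c"])
print([dataset.get(i) for i in range(len(dataset))])  # ['a', 'b', 'c']

items = [1, 2, 3]
shared = ListBackedDataset(items)
items.append(4)  # mutating the original list is visible through the dataset
print(len(shared))  # 4
```

Because the quoted __init__ also assigns the list rather than copying it, changes made to the original list after construction are likely to show up in the dataset as well; pass a copy (for example list(data)) if that is not wanted.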

split_dataset_by_node

split_dataset_by_node(world_size: int, rank: int) -> Dataset[Example]

Split the dataset into world_size contiguous shards and return the shard for the node with the given rank.

Parameters:

world_size (int): The number of nodes. Required.
rank (int): The rank of the current node. Required.

Returns:

Dataset[Example]: The shard of the dataset for the current node.

Source code in src/loadax/dataset/simple.py
def split_dataset_by_node(self, world_size: int, rank: int) -> Dataset[Example]:
    """Split the dataset into shards.

    Args:
        world_size (int): The number of nodes.
        rank (int): The rank of the current node.

    Returns:
        Dataset[Example]: The shard of the dataset for the current node.
    """
    start, end = compute_shard_boundaries(
        num_shards=world_size,
        shard_id=rank,
        dataset_size=len(self),
        drop_remainder=False,
    )
    return SimpleDataset(self.data[start:end])
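The method delegates the index arithmetic to compute_shard_boundaries, whose body is not shown on this page, and then slices self.data[start:end]. The sketch below is one plausible contiguous-sharding scheme, assuming that with drop_remainder=False every example lands in exactly one shard and leftover examples go to the lowest ranks. shard_boundaries is a hypothetical helper; loadax's own function may distribute the remainder differently.

```python
def shard_boundaries(
    num_shards: int, shard_id: int, dataset_size: int, drop_remainder: bool = False
) -> tuple[int, int]:
    """Hypothetical sketch, not loadax's compute_shard_boundaries."""
    if not 0 <= shard_id < num_shards:
        raise ValueError("shard_id must be in [0, num_shards)")
    base, remainder = divmod(dataset_size, num_shards)
    if drop_remainder:
        # Every shard gets exactly `base` items; the trailing `remainder` items are skipped.
        start = shard_id * base
        return start, start + base
    # Keep everything: the first `remainder` shards each take one extra item.
    start = shard_id * base + min(shard_id, remainder)
    end = start + base + (1 if shard_id < remainder else 0)
    return start, end


data = list(range(10))
for rank in range(3):
    start, end = shard_boundaries(num_shards=3, shard_id=rank, dataset_size=len(data))
    print(rank, data[start:end])
# 0 [0, 1, 2, 3]
# 1 [4, 5, 6]
# 2 [7, 8, 9]
```

Contiguous slices keep each node's shard cheap to build (one list slice), at the cost of shards that differ in size by at most one item when the dataset does not divide evenly.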

shuffle

shuffle(seed: Array) -> Dataset[Example]

Shuffle the dataset, returning a new dataset that holds the same examples in a random order determined by seed.

Parameters:

seed (Array): The seed to use for the shuffle. This is a JAX PRNG key, since all randomization in loadax is implemented with jax.random. Required.

Returns:

Dataset[Example]: The shuffled dataset.

Source code in src/loadax/dataset/simple.py
def shuffle(self, seed: jax.Array) -> "Dataset[Example]":
    """Shuffle the dataset.

    Args:
        seed: The seed to use for the shuffle. This is a jax
            PRNGKey as all randomization in loadax is implemented using jax.random.

    Returns:
        The shuffled dataset.
    """
    indices = jax.random.permutation(seed, len(self))
    return SimpleDataset([self.data[i] for i in indices])
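The method draws a permutation of indices from the seed and gathers the examples in that order into a new dataset, leaving the original list untouched. The sketch below reproduces that idea with Python's standard random.Random in place of jax.random so that it runs without JAX; shuffled is a hypothetical helper, its integer seed stands in for the JAX PRNG key (such as jax.random.PRNGKey(0)) that the real method expects, and the resulting order will not match loadax's. It also illustrates why, when shuffling before split_dataset_by_node, every node should pass the same seed: only then do all nodes agree on the order, so their contiguous shards do not overlap.

```python
import random
from typing import TypeVar

Example = TypeVar("Example")


def shuffled(data: list[Example], seed: int) -> list[Example]:
    """Return a new list with the examples of `data` in a seeded random order."""
    # Build a permutation of 0..len(data)-1; this plays the role of
    # jax.random.permutation(seed, len(self)) in the source above.
    indices = list(range(len(data)))
    random.Random(seed).shuffle(indices)
    # Gather examples in permuted order into a new list; `data` is not modified.
    return [data[i] for i in indices]


data = list(range(8))
world_size = 2
shard_size = len(data) // world_size  # assumes len(data) divides evenly

shards = []
for rank in range(world_size):
    # Every rank shuffles with the SAME seed, so all ranks see the same order
    # and their contiguous slices are disjoint.
    order = shuffled(data, seed=42)
    shards.append(order[rank * shard_size:(rank + 1) * shard_size])

print(sorted(shards[0] + shards[1]) == data)  # True
```

If each node used a different seed, each would slice a different permutation, and the union of the shards could repeat some examples and miss others.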