
SampledDataset

A sampled dataset is a lightweight dataset that exposes a random subset of an underlying dataset. Sampling is lazy: it only selects which indices to expose and does not copy or pre-read the underlying storage. If your underlying storage performs poorly under random access, consider performing the sampling in advance. For almost all use cases, this should not be necessary.

The sampling procedure is deterministic: it uses JAX's random number generation, so the same key, dataset length, and sample size always select the same elements.
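To illustrate the determinism, here is a minimal sketch that calls `jax.random.choice` directly, in the same way the constructor source on this page does. The helper name `sample_indices` is hypothetical and is not part of loadax.

```python
import jax
import jax.numpy as jnp


def sample_indices(key, dataset_len, sample_size):
    # Hypothetical helper (not a loadax function): draw `sample_size`
    # distinct indices from range(dataset_len), without replacement.
    return jax.random.choice(
        key, jnp.arange(dataset_len), shape=(sample_size,), replace=False
    )


key = jax.random.PRNGKey(0)
first = sample_indices(key, 5, 3)
second = sample_indices(key, 5, 3)

# Reusing the same key reproduces the same indices.
print(first.tolist() == second.tolist())
```

Passing a different key (for example `jax.random.PRNGKey(1)`) generally yields a different selection, so the key can be treated as the seed of the sample.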

Creating a sampled dataset
from loadax import SampledDataset, SimpleDataset
import jax

dataset = SimpleDataset([1, 2, 3, 4, 5])
key = jax.random.PRNGKey(0)
sampled_dataset = SampledDataset(dataset, 3, key)

Bases: Dataset[Example], Generic[Example]

A dataset that represents a random sample of another dataset.

This dataset type creates a new dataset containing a random subset of the elements of an existing dataset, drawn without replacement and determined by a sample size and a random key.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dataset | Dataset[Example] | The original dataset to create a sampled view of. | required |
| sample_size | int | The number of samples to include in the new dataset. | required |
| key | Array | The random key to use for sampling. | required |

Source code in src/loadax/dataset/sampled_dataset.py
def __init__(self, dataset: Dataset[Example], sample_size: int, key: jax.Array):
    """Initialize the SampledDataset.

    Args:
        dataset: The original dataset to create a sampled view of.
        sample_size: The number of samples to include in the new dataset.
        key: The random key to use for sampling.
    """
    self.dataset = dataset
    self.sample_size = sample_size
    self.key = key

    if sample_size < 0 or sample_size > len(dataset):
        raise ValueError("Invalid sample size")

    self.indices = jax.random.choice(
        key, jnp.arange(len(dataset)), shape=(sample_size,), replace=False
    )
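
The constructor above only stores the sampled indices. How elements are looked up is not shown in this excerpt, so the following is a sketch under assumptions rather than loadax's implementation: a hypothetical `LazySampleView` class that maps each position to a stored index and reads the underlying data only on access.

```python
import jax
import jax.numpy as jnp


class LazySampleView:
    """Hypothetical stand-in for illustration; not the loadax class."""

    def __init__(self, data, sample_size, key):
        # Same validation as the constructor shown above.
        if sample_size < 0 or sample_size > len(data):
            raise ValueError("Invalid sample size")
        self.data = data
        # Only the indices are chosen here; no element of `data` is read.
        self.indices = jax.random.choice(
            key, jnp.arange(len(data)), shape=(sample_size,), replace=False
        )

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        # JAX clamps out-of-range indices instead of raising, so check bounds.
        if not 0 <= i < len(self):
            raise IndexError(i)
        # Assumed lookup: translate the position into an underlying index.
        return self.data[int(self.indices[i])]


view = LazySampleView([10, 20, 30, 40, 50], 3, jax.random.PRNGKey(0))
items = [view[i] for i in range(len(view))]
print(len(view), items)
```

Because the view holds only an index array, its memory cost scales with the sample size rather than with the size of the underlying dataset.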