SampledDataset
A sampled dataset is a simple dataset that exposes a subset of an underlying dataset. Sampling is performed lazily, so the underlying storage itself is never sampled. If your underlying storage performs poorly with random access, consider performing the sampling in advance. For almost all use cases, this should not be necessary.
The sampling procedure is deterministic: it uses JAX's random number generation, driven by the key you pass in.
```python
import jax

from loadax import SampledDataset, SimpleDataset

dataset = SimpleDataset([1, 2, 3, 4, 5])
key = jax.random.PRNGKey(0)  # random key used for sampling
# Sampled view containing 3 elements of `dataset`.
sampled_dataset = SampledDataset(dataset, 3, key)
```
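To illustrate what lazy, key-driven sampling can look like, here is a minimal sketch. It is not loadax's implementation: the `LazySample` class is hypothetical, and the sketch assumes sampling without replacement by taking a prefix of `jax.random.permutation`, which this page does not confirm. Only the chosen indices are stored up front; elements are read from the underlying data when requested.

```python
import jax


class LazySample:
    """Hypothetical stand-in for a sampled view (not loadax's code)."""

    def __init__(self, data, sample_size, key):
        if sample_size > len(data):
            raise ValueError("sample_size cannot exceed the dataset length")
        self.data = data
        # Assumption: sample without replacement by taking the first
        # `sample_size` entries of a random permutation of the indices.
        perm = jax.random.permutation(key, len(data))
        self.indices = perm[:sample_size].tolist()

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        # The underlying data is only touched when an element is requested.
        return self.data[self.indices[i]]


data = [1, 2, 3, 4, 5]
key = jax.random.PRNGKey(0)

# Building two views from the same key selects the same indices.
first = LazySample(data, 3, key)
second = LazySample(data, 3, key)
first_items = [first[i] for i in range(len(first))]
second_items = [second[i] for i in range(len(second))]
```

Because the same key always produces the same permutation, `first_items` and `second_items` are identical. Passing a different key would generally select a different subset.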
Bases: Dataset[Example], Generic[Example]
A dataset that represents a random sample of another dataset.
This dataset type allows you to create a new dataset that contains a random subset of elements from an existing dataset, specified by a sample size and a random key.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset | Dataset[Example] | The original dataset to create a sampled view of. | required |
sample_size | int | The number of samples to include in the new dataset. | required |
key | Array | The random key to use for sampling. | required |
Source code in src/loadax/dataset/sampled_dataset.py