Loadax
Loadax is a data loading library designed for the JAX ecosystem. It provides utilities for feeding data into your training loop without having to worry about batching, shuffling, and other preprocessing steps. Loadax also handles background prefetching to improve performance, and distributed data loading so you can train on multiple devices and even multiple hosts.
Loadax Example
from loadax import Dataloader, SimpleDataset
dataset = SimpleDataset([1, 2, 3, 4, 5])
dataloader = Dataloader(dataset, batch_size=2)
for batch in dataloader:
    print(batch)
#> [1, 2]
#> [3, 4]
#> [5]
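To make the batching and background prefetching mentioned above more concrete, here is a minimal standard-library sketch of the general idea. It does not use Loadax and is not a description of its internals: the helper names `batches` and `prefetch` and the `buffer_size` parameter are hypothetical, chosen only for illustration, and error handling in the worker thread is omitted.

```python
import queue
import threading
from typing import Iterable, Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def batches(items: Sequence[T], batch_size: int) -> Iterator[List[T]]:
    """Yield consecutive slices of at most `batch_size` items (hypothetical helper)."""
    for start in range(0, len(items), batch_size):
        yield list(items[start:start + batch_size])


def prefetch(source: Iterable[T], buffer_size: int = 2) -> Iterator[T]:
    """Read `source` on a background thread, keeping up to `buffer_size` items ready."""
    sentinel = object()  # marks the end of the stream
    buffer: queue.Queue = queue.Queue(maxsize=buffer_size)

    def worker() -> None:
        # put() blocks while the buffer is full, so the producer stays
        # at most `buffer_size` items ahead of the consumer.
        for item in source:
            buffer.put(item)
        buffer.put(sentinel)

    threading.Thread(target=worker, daemon=True).start()

    while True:
        item = buffer.get()
        if item is sentinel:
            break
        yield item


result = list(prefetch(batches([1, 2, 3, 4, 5], batch_size=2)))
print(result)
#> [[1, 2], [3, 4], [5]]
```

The bounded queue is the key design choice in this sketch: the next batch is prepared while the training step consumes the current one, and memory use stays capped.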
Installation
uv add loadax