Loadax

Loadax is a data loading library designed for the JAX ecosystem. It provides utilities for feeding data into your training loop, so you do not have to write batching, shuffling, and other preprocessing steps yourself. Loadax also handles background prefetching to improve performance, and distributed data loading to train on multiple devices and even multiple hosts.

Loadax Example
from loadax import Dataloader, SimpleDataset

dataset = SimpleDataset([1, 2, 3, 4, 5])
dataloader = Dataloader(dataset, batch_size=2)

for batch in dataloader:
    print(batch)

#> [1, 2]
#> [3, 4]
#> [5]
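
To make the batching, shuffling, and background prefetching described above more concrete, here is a minimal plain-Python sketch of the general idea. It is not Loadax's implementation and does not use the Loadax API: the names iter_batches and prefetch are hypothetical helpers invented for this illustration, and the thread-plus-queue approach is just one common way to overlap data preparation with consumption.

```python
import queue
import random
import threading
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")

_DONE = object()  # sentinel that marks the end of the batch stream


def iter_batches(
    items: Sequence[T], batch_size: int, shuffle: bool = False, seed: int = 0
) -> Iterator[List[T]]:
    """Yield lists of up to batch_size items; the last batch may be shorter.

    Hypothetical helper for illustration, not part of Loadax.
    """
    order = list(range(len(items)))
    if shuffle:
        # Shuffle indices with a seeded RNG so the order is reproducible.
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        yield [items[i] for i in order[start:start + batch_size]]


def prefetch(batches: Iterator[List[T]], buffer_size: int = 2) -> Iterator[List[T]]:
    """Produce batches on a background thread and hand them over via a queue.

    Hypothetical helper for illustration, not part of Loadax.
    """
    buffer: queue.Queue = queue.Queue(maxsize=buffer_size)

    def worker() -> None:
        try:
            for batch in batches:
                buffer.put(batch)  # blocks while the buffer is full
        finally:
            # Always signal completion so the consumer does not wait forever.
            buffer.put(_DONE)

    threading.Thread(target=worker, daemon=True).start()
    while True:
        batch = buffer.get()
        if batch is _DONE:
            return
        yield batch


if __name__ == "__main__":
    for batch in prefetch(iter_batches([1, 2, 3, 4, 5], batch_size=2)):
        print(batch)
```

Without shuffling, this prints the same three batches as the Loadax example above. The bounded queue keeps at most buffer_size batches in memory while the consumer is busy, which is the basic trade-off behind prefetching.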

Installation

uv add loadax