
MappedDataset

A mapped dataset applies a functional transformation to each element of an underlying dataset. The transformation is applied lazily, so the underlying dataset is preserved. Because loadax performs dataloading in the background, it is acceptable to perform lightweight data augmentation or transformations in the map function.

If a transformation is expensive, you may still want to perform it ahead of time.

Creating a mapped dataset
from loadax import MappedDataset, SimpleDataset

def transform(x):
    return x * 2

dataset = MappedDataset(SimpleDataset([1, 2, 3, 4, 5]), transform)

for i in range(len(dataset)):
    print(dataset.get(i))

#> 2
#> 4
#> 6
#> 8
#> 10
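The key property of lazy mapping is that the source data is never mutated. The loadax-free sketch below illustrates the same idea; the `LazyMapped` class is a toy stand-in for illustration, not loadax's actual implementation:

```python
from typing import Callable, Generic, TypeVar

Example = TypeVar("Example")
Transformed = TypeVar("Transformed")

class LazyMapped(Generic[Example, Transformed]):
    """Toy stand-in for a lazily mapped dataset (not loadax's internals)."""

    def __init__(
        self, data: list[Example], transform: Callable[[Example], Transformed]
    ):
        self.data = data  # underlying data, never mutated
        self.transform = transform

    def get(self, index: int) -> Transformed:
        # The transform runs only when an element is requested.
        return self.transform(self.data[index])

raw = [1, 2, 3]
mapped = LazyMapped(raw, lambda x: x * 2)
print(mapped.get(1))  # → 4, transform applied on access
print(raw)            # → [1, 2, 3], underlying data unchanged
```

Nothing is transformed at construction time; work happens only on `get`, which is what makes it cheap to stack lightweight augmentations.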

Bases: Dataset[Transformed], Generic[Example, Transformed]

A dataset that applies a transformation to each element in the dataset.

The transformation is applied lazily: the underlying data is not altered, and the transform runs only when elements are accessed.

Parameters:

    dataset (Dataset[Example], required): The underlying dataset to apply the transformation to.
    transform (Callable[[Example], Transformed], required): The transformation to apply to each element in the dataset.
Source code in src/loadax/dataset/dataset.py
def __init__(
    self, dataset: Dataset[Example], transform: Callable[[Example], Transformed]
):
    """Initializes the MappedDataset.

    Args:
        dataset: The underlying dataset to apply the transformation to.
        transform: The transformation to apply to each element in the dataset.
    """
    self.dataset = dataset
    self.transform = transform

MappedBatchDataset

A mapped batch dataset applies a functional transformation to batches of elements from the underlying dataset. The transformation is applied lazily, so the underlying dataset is preserved. Because loadax performs dataloading in the background, it is acceptable to perform lightweight data augmentation or transformations in the map function.

Similar to MappedDataset, but the transformation is applied to batches of items instead of individual items. This is useful for batch-level transformations such as data augmentation, or for more expensive transformations that can be vectorized.

Creating a mapped batch dataset
from loadax import MappedBatchDataset, SimpleDataset

def transform(batch):
    return [item * 2 for item in batch]

dataset = MappedBatchDataset(SimpleDataset([1, 2, 3, 4, 5]), transform)

for i in range(len(dataset)):
    print(dataset.get(i))

#> 2
#> 4
#> 6
#> 8
#> 10

Bases: Dataset[Transformed], Generic[Example, Transformed]

Performs element transformations in batches.

Just as with MappedDataset, the transformation is applied lazily: the underlying data is not altered, and the transform runs only when elements are accessed.

Batched mapping is useful when a transformation is expensive to apply element by element. For example, if a dataset needs to be tokenized, tokenizing each batch of elements avoids the overhead of tokenizing each element individually.
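The amortization idea can be sketched without loadax. The helper and the toy batch tokenizer below are illustrative names, not part of loadax's API; the point is that the transform is invoked once per batch rather than once per element:

```python
from typing import Callable, TypeVar

Example = TypeVar("Example")
Transformed = TypeVar("Transformed")

def map_in_batches(
    data: list[Example],
    transform: Callable[[list[Example]], list[Transformed]],
    batch_size: int = 32,
) -> list[Transformed]:
    """Apply `transform` to successive slices of `data` (toy sketch)."""
    out: list[Transformed] = []
    for start in range(0, len(data), batch_size):
        # One transform call covers the whole batch.
        out.extend(transform(data[start:start + batch_size]))
    return out

def tokenize_batch(batch: list[str]) -> list[list[str]]:
    """Toy batch tokenizer: whitespace-splits every string in the batch."""
    return [text.split() for text in batch]

corpus = ["hello world", "lazy data loading", "batched maps"]
tokens = map_in_batches(corpus, tokenize_batch, batch_size=2)
print(tokens)
```

With a real tokenizer the per-call setup cost (or a vectorized kernel launch) is paid once per batch of 32 instead of once per string.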

Parameters:

    dataset (Dataset[Example], required): The underlying dataset to apply the transformation to.
    transform (Callable[[list[Example]], Transformed], required): The transformation to apply to each batch of elements in the dataset.
    batch_size (int, default 32): The size of each batch.
Source code in src/loadax/dataset/dataset.py
def __init__(
    self,
    dataset: Dataset[Example],
    transform: Callable[[list[Example]], Transformed],
    batch_size: int = 32,
):
    """Initializes the MappedBatchDataset.

    Args:
        dataset: The underlying dataset to apply the transformation to.
        transform: The transformation to apply to each batch of elements in the
            dataset.
        batch_size: The size of each batch.
    """
    self.dataset = dataset
    self.transform = transform
    self.batch_size = batch_size
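One way such a dataset can still serve single-element `get(i)` calls is to locate the batch containing index `i`, transform that whole batch, and return the requested element. This is a sketch of the indexing arithmetic only; loadax's actual lookup strategy may differ:

```python
from typing import Callable, TypeVar

Example = TypeVar("Example")
Transformed = TypeVar("Transformed")

def get_via_batch(
    data: list[Example],
    transform: Callable[[list[Example]], list[Transformed]],
    batch_size: int,
    index: int,
) -> Transformed:
    """Fetch one transformed element by transforming its whole batch (sketch)."""
    start = (index // batch_size) * batch_size  # first index of the batch
    batch = data[start:start + batch_size]      # final batch may be short
    return transform(batch)[index - start]      # position within the batch

doubled = lambda batch: [x * 2 for x in batch]
print(get_via_batch([1, 2, 3, 4, 5], doubled, batch_size=2, index=4))  # → 10
```

Note that the final batch may be smaller than `batch_size`, so a batch transform should not assume a fixed batch length.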