
PartialDataset

A partial dataset is a view over a contiguous range of indices of an underlying dataset. This is useful for testing and debugging on a small slice of the data.

Creating a partial dataset
from loadax import PartialDataset, SimpleDataset

dataset = SimpleDataset([1, 2, 3, 4, 5])
partial_dataset = PartialDataset(dataset, 2, 4)  # indices 2 and 3; end is exclusive

for i in range(len(partial_dataset)):
    print(partial_dataset.get(i))

#> 3
#> 4

Bases: Dataset[Example], Generic[Example]

A dataset that represents a range of another dataset.

This dataset type allows you to create a new dataset that contains a subset of elements from an existing dataset, specified by a start and end index.

Parameters:

Name    | Type             | Description                                        | Default
------- | ---------------- | -------------------------------------------------- | --------
dataset | Dataset[Example] | The original dataset to create a partial view of.  | required
start   | int              | The starting index of the range (inclusive).       | required
end     | int              | The ending index of the range (exclusive).         | required
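Because `start` is inclusive and `end` is exclusive, the range follows the same half-open convention as Python slicing, so a partial dataset over `[start, end)` holds `end - start` elements. The snippet below illustrates the convention with a plain list standing in for the underlying dataset; no loadax objects are involved.

```python
data = [1, 2, 3, 4, 5]  # plain list standing in for the underlying dataset
start, end = 2, 4       # start is inclusive, end is exclusive

# Indices 2 and 3 are selected; index 4 is not.
subset = [data[i] for i in range(start, end)]

print(subset)       # [3, 4]
print(len(subset))  # 2, i.e. end - start
```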
Source code in src/loadax/dataset/partial_dataset.py
def __init__(self, dataset: Dataset[Example], start: int, end: int):
    """Initialize the PartialDataset.

    Args:
        dataset: The original dataset to create a partial view of.
        start: The starting index of the range (inclusive).
        end: The ending index of the range (exclusive).
    """
    self.dataset = dataset
    self.start = start
    self.end = end

    if start < 0 or end > len(dataset) or start >= end:
        raise ValueError("Invalid start or end index")
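The constructor above only stores and validates the bounds; the element lookup is not part of this excerpt. The following is a minimal sketch, assuming a partial dataset reports `end - start` as its length and maps a local index `i` to `start + i` in the underlying dataset. `ListDataset` and `RangeView` are hypothetical stand-ins, not loadax classes, and raising `IndexError` on an out-of-range local index is a choice made for this sketch only.

```python
from collections.abc import Sequence
from typing import Generic, TypeVar

T = TypeVar("T")


class ListDataset(Generic[T]):
    """Hypothetical stand-in for a dataset backed by a Python list."""

    def __init__(self, items: Sequence[T]):
        self.items = list(items)

    def __len__(self) -> int:
        return len(self.items)

    def get(self, index: int) -> T:
        return self.items[index]


class RangeView(Generic[T]):
    """Hypothetical view over [start, end) of another dataset (assumed semantics)."""

    def __init__(self, dataset: ListDataset[T], start: int, end: int):
        # Same check as the PartialDataset constructor: in-bounds and non-empty.
        if start < 0 or end > len(dataset) or start >= end:
            raise ValueError("Invalid start or end index")
        self.dataset = dataset
        self.start = start
        self.end = end

    def __len__(self) -> int:
        return self.end - self.start

    def get(self, index: int) -> T:
        # Assumption: local index 0 corresponds to self.start in the parent.
        if not 0 <= index < len(self):
            raise IndexError(index)
        return self.dataset.get(self.start + index)


view = RangeView(ListDataset([1, 2, 3, 4, 5]), 2, 4)
print([view.get(i) for i in range(len(view))])  # [3, 4]

# Negative start, end past the data, and an empty range are all rejected.
for bad in [(-1, 3), (0, 6), (3, 3)]:
    try:
        RangeView(ListDataset([1, 2, 3, 4, 5]), *bad)
    except ValueError as err:
        print(bad, "->", err)
```

Note that `start >= end` is rejected, so a valid range always contains at least one element.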

split_dataset (staticmethod)

split_dataset(dataset: Dataset[Example], num_partitions: int) -> list[PartialDataset[Example]]

Split a dataset into a number of partial datasets.

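Parameters:

Name           | Type             | Description                          | Default
-------------- | ---------------- | ------------------------------------ | --------
dataset        | Dataset[Example] | The original dataset to split.       | required
num_partitions | int              | The number of partitions to create.  | required

Returns:

Type                          | Description
----------------------------- | ----------------------------------
list[PartialDataset[Example]] | A list of PartialDataset objects.

Raises:

Type       | Description
---------- | ------------------------------------------------------------------
ValueError | If num_partitions is less than 1 or greater than the dataset size.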

Source code in src/loadax/dataset/partial_dataset.py
@staticmethod
def split_dataset(
    dataset: Dataset[Example], num_partitions: int
) -> list["PartialDataset[Example]"]:
    """Split a dataset into a number of partial datasets.

    Args:
        dataset: The original dataset to split.
        num_partitions: The number of partitions to create.

    Returns:
        A list of PartialDataset objects.

    Raises:
        ValueError: If num_partitions is less than 1 or greater than the
            dataset size.
    """
    if num_partitions < 1:
        raise ValueError("Number of partitions must be at least 1")
    if num_partitions > len(dataset):
        raise ValueError("Number of partitions cannot exceed dataset size")

    partition_size = len(dataset) // num_partitions
    remainder = len(dataset) % num_partitions

    partials = []
    start = 0
    for i in range(num_partitions):
        end = start + partition_size + (1 if i < remainder else 0)
        partials.append(PartialDataset(dataset, start, end))
        start = end

    return partials
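`split_dataset` gives every partition `len(dataset) // num_partitions` elements and hands one extra element to each of the first `len(dataset) % num_partitions` partitions, so partition sizes differ by at most one and the ranges tile the dataset without gaps or overlap. The boundary arithmetic can be checked in isolation; `partition_bounds` is a hypothetical helper that reproduces the loop above without constructing any datasets.

```python
def partition_bounds(length: int, num_partitions: int) -> list[tuple[int, int]]:
    """Return the half-open (start, end) pairs that split_dataset's loop computes."""
    if num_partitions < 1:
        raise ValueError("Number of partitions must be at least 1")
    if num_partitions > length:
        raise ValueError("Number of partitions cannot exceed dataset size")

    partition_size, remainder = divmod(length, num_partitions)
    bounds = []
    start = 0
    for i in range(num_partitions):
        # The first `remainder` partitions each absorb one leftover element.
        end = start + partition_size + (1 if i < remainder else 0)
        bounds.append((start, end))
        start = end
    return bounds


print(partition_bounds(10, 3))  # [(0, 4), (4, 7), (7, 10)]
```

Because `num_partitions` may not exceed the dataset size, `partition_size` is at least 1 and every range is non-empty, which is what lets each `PartialDataset` pass the `start >= end` check in the constructor.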