Adding a dataset

To add your own dataset, or to integrate datasets from other libraries, create a new component that extends the Dataset class.

The Dataset interface declares the following methods. Those marked @abstractmethod must be implemented by every subclass:

from abc import ABC, abstractmethod

import torch
from torch.utils.data import DataLoader

# `frozen`, `Component` and `Split` are imported from elsewhere in the codebase (not shown here).


@frozen
class Dataset(ABC, Component):
    """Abstract base class for datasets."""

    @abstractmethod
    def get_dataloader(self, b: int, split: Split, num_workers: int) -> DataLoader:
        """Return a dataloader with batch size ``b`` for the training or validation dataset."""
        raise NotImplementedError()

    @abstractmethod
    def data_input_shape(self, batch_size: int) -> torch.Size:
        """The expected input shape of the data, including the batch size."""
        raise NotImplementedError()

    @abstractmethod
    def model_output_shape(self, batch_size: int) -> torch.Size:
        """The expected output shape, i.e. the shape of the targets."""
        raise NotImplementedError()

    @abstractmethod
    def get_num_samples(self, split: Split) -> int:
        """The number of samples in the training or validation set."""
        raise NotImplementedError()

    def has_validation_set(self) -> bool:
        """Whether the dataset has a validation set; defaults to True."""
        return True

    @abstractmethod
    def has_test_set(self) -> bool:
        """Whether the dataset has a test set."""
        raise NotImplementedError()

As an example, we’ll follow the current implementation for MNIST.

Basic information about the dataset

The following 3 methods just return basic statistics about the dataset:

  • input_shape

  • output_shape

  • get_num_samples

        @abstractmethod
        def get_num_samples(self, split: Split) -> int:
            """The number of samples in the training or validation sets."""
            raise NotImplementedError()
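For MNIST these three methods reduce to constants. The sketch below is a standalone class, not the project's MNIST code: Split is a hypothetical stand-in enum, the model output is assumed to be one logit per class, and the split sizes assume the official 60,000 training images are used for training and the 10,000 official test images serve as the validation set.

```python
from enum import Enum

import torch


class Split(Enum):
    # Hypothetical stand-in for the project's Split type.
    TRAIN = "train"
    VAL = "val"


class MNISTShapes:
    """Sketch of the three 'basic information' methods for MNIST."""

    # Assumed split sizes: official training set for TRAIN,
    # official test set reused as the validation set.
    NUM_SAMPLES = {Split.TRAIN: 60_000, Split.VAL: 10_000}
    NUM_CLASSES = 10

    def data_input_shape(self, batch_size: int) -> torch.Size:
        # Grayscale 28x28 images: (batch, channels, height, width).
        return torch.Size([batch_size, 1, 28, 28])

    def model_output_shape(self, batch_size: int) -> torch.Size:
        # Assumed: one logit per class for each sample.
        return torch.Size([batch_size, self.NUM_CLASSES])

    def get_num_samples(self, split: Split) -> int:
        return self.NUM_SAMPLES[split]


info = MNISTShapes()
print(info.data_input_shape(32))  # torch.Size([32, 1, 28, 28])
```

Returning torch.Size rather than a plain tuple matches the signatures declared on Dataset; torch.Size is itself a tuple subclass, so callers can still unpack or compare it like a tuple.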
    

Making dataloaders

The remaining method, get_dataloader, returns the torch.utils.data.DataLoader used to iterate over the dataset.

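    @abstractmethod
    def get_dataloader(self, b: int, split: Split, num_workers: int) -> DataLoader:
        """Return a dataloader with batch size ``b`` for the training or validation dataset."""
        raise NotImplementedError()

The dataloader itself is created by a helper, make_dataloader, that wraps a torch dataset:

import torch


def make_dataloader(dataset, b, num_workers):
    # Wrap a torch Dataset in a DataLoader with batch size b.
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=b,
        shuffle=True,  # note: this shuffles every split, including validation
        drop_last=False,  # keep the final, possibly smaller, batch
        num_workers=num_workers,
        pin_memory=True,  # page-locked memory speeds up host-to-GPU copies
    )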
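As a sketch of a split-aware get_dataloader, the toy class below keeps a TensorDataset of random MNIST-shaped tensors for each split. Split and RandomImages are hypothetical stand-ins. Unlike make_dataloader, it shuffles only the training split and enables pin_memory only when CUDA is available.

```python
from enum import Enum

import torch
from torch.utils.data import DataLoader, TensorDataset


class Split(Enum):
    # Hypothetical stand-in for the project's Split type.
    TRAIN = "train"
    VAL = "val"


class RandomImages:
    """Toy dataset holding random image/label tensors per split (illustration only)."""

    def __init__(self, sizes: dict):
        # sizes maps each Split to a hypothetical number of samples.
        self.data = {
            split: TensorDataset(torch.randn(n, 1, 28, 28), torch.randint(0, 10, (n,)))
            for split, n in sizes.items()
        }

    def get_dataloader(self, b: int, split: Split, num_workers: int) -> DataLoader:
        return DataLoader(
            self.data[split],
            batch_size=b,
            shuffle=split is Split.TRAIN,  # deterministic order for evaluation
            drop_last=False,
            num_workers=num_workers,
            pin_memory=torch.cuda.is_available(),  # pinning only helps GPU transfers
        )


ds = RandomImages({Split.TRAIN: 100, Split.VAL: 30})
print(len(ds.get_dataloader(b=8, split=Split.VAL, num_workers=0)))  # 4
```

With drop_last=False the number of batches is the sample count divided by b, rounded up, so 30 validation samples with b=8 give 4 batches, the last one holding 6 samples.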

Adding further integrations

Additional features can be added by also inheriting from the extension classes below. For example, to define a dataset that can be downloaded from the internet with the prepare CLI command, extend both Dataset and Downloadable and implement the abstract methods of each:

class MyDataset(Dataset, Downloadable):
    ...

Downloadable is the most useful of these extensions; the others are optional.
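Keeping each feature in its own mixin lets a command check for it with isinstance. The sketch below shows how a command like prepare could do that; prepare, ToyDataset and ToyDownloadable are hypothetical stand-ins, not the project's code. Downloadable mirrors the interface on this page but inherits ABC here so that its abstract methods are enforced on subclasses: Python only collects abstract methods from the class body and from bases that are themselves abstract classes.

```python
from abc import ABC, abstractmethod


class Downloadable(ABC):
    # Mirrors the interface on this page; inheriting ABC makes both methods mandatory.
    @abstractmethod
    def is_downloaded(self) -> bool:
        raise NotImplementedError()

    @abstractmethod
    def download(self):
        raise NotImplementedError()


class ToyDataset:
    """Hypothetical stand-in for a Dataset subclass."""


class ToyDownloadable(ToyDataset, Downloadable):
    def __init__(self):
        self._downloaded = False

    def is_downloaded(self) -> bool:
        return self._downloaded

    def download(self):
        self._downloaded = True  # a real dataset would fetch its files here


def prepare(datasets):
    """Hypothetical prepare-style step: download what supports it and is missing."""
    fetched = []
    for ds in datasets:
        if isinstance(ds, Downloadable) and not ds.is_downloaded():
            ds.download()
            fetched.append(ds)
    return fetched


a = ToyDownloadable()
print(len(prepare([a, ToyDataset()])))  # 1
print(prepare([a]))  # []
```

Datasets that do not mix in Downloadable are simply skipped, and a second run does nothing because is_downloaded now returns True.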

Downloadable datasets

class Downloadable:
    """Extension for datasets that can be downloaded.

    For interaction with the ``prepare`` CLI command.
    """

    @abstractmethod
    def is_downloaded(self) -> bool:
        """Checks whether the dataset is already in the workspace."""
        raise NotImplementedError()

    @abstractmethod
    def download(self):
        """Downloads the dataset to the workspace."""
        raise NotImplementedError()
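A minimal sketch of the two Downloadable methods for a dataset made of a fixed set of files. The class name FileDownloadable and its root and urls parameters are hypothetical; a real dataset would know its own workspace folder and sources, and would also extend Dataset and Downloadable (omitted so the sketch runs on its own). Each file is fetched into a temporary .part file and renamed only when complete, so an interrupted download is not mistaken for a finished one.

```python
import urllib.request
from pathlib import Path


class FileDownloadable:
    """Hypothetical Downloadable implementation for a fixed set of files."""

    def __init__(self, root: Path, urls: dict):
        self.root = Path(root)  # dataset folder inside the workspace (assumed)
        self.urls = urls  # maps file name -> source URL (assumed)

    def is_downloaded(self) -> bool:
        # The dataset counts as present only if every expected file exists.
        return all((self.root / name).is_file() for name in self.urls)

    def download(self):
        self.root.mkdir(parents=True, exist_ok=True)
        for name, url in self.urls.items():
            target = self.root / name
            if target.is_file():
                continue  # already fetched on a previous run
            partial = target.with_name(target.name + ".part")
            urllib.request.urlretrieve(url, partial)
            partial.replace(target)  # rename only after the download finished
```

Because download skips files that already exist, re-running prepare after a partial failure only fetches what is missing.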

Datasets with class counts

class HasClassCounts:
    """Extension for datasets that provide class frequencies."""

    @abstractmethod
    def class_counts(self, split: Split) -> torch.Tensor:
        """Return the number of samples in each class for the given split."""
        raise NotImplementedError()
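class_counts can usually be computed from the split's label tensor with torch.bincount (for torchvision's MNIST dataset, the labels are available as its targets attribute). The helper name below is hypothetical; passing minlength guarantees one entry per class even when a class never appears in the split.

```python
import torch


def class_counts_from_targets(targets: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: count how many samples belong to each class."""
    # bincount expects a 1-D tensor of non-negative integers.
    return torch.bincount(targets.long(), minlength=num_classes)


counts = class_counts_from_targets(torch.tensor([0, 1, 1, 3]), num_classes=4)
print(counts)  # tensor([1, 2, 0, 1])
```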

Datasets that can be put in RAM

class InMemory:
    """Extension for small datasets that can be loaded directly into RAM."""

    @abstractmethod
    def get_in_memory_dataloader(
        self,
        b: int,
        split: Split,
        num_workers: int,
        to_device: Optional[Device] = None,
    ) -> torch.utils.data.DataLoader:
        """Returns a DataLoader with the dataset already loaded into RAM on the device."""
        raise NotImplementedError()
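A possible body for get_in_memory_dataloader, written as a standalone function over tensors that already hold a whole split. The function name, the shuffle parameter and the str/torch.device type for to_device are assumptions (the page's Device type is not shown). Moving the data once avoids a host-to-device copy for every batch. Since indexing in-memory tensors is cheap and PyTorch advises against using CUDA tensors in worker processes, the sketch loads in the main process whenever a device is given.

```python
from typing import Optional, Union

import torch
from torch.utils.data import DataLoader, TensorDataset


def make_in_memory_dataloader(
    x: torch.Tensor,
    y: torch.Tensor,
    b: int,
    num_workers: int,
    to_device: Optional[Union[str, torch.device]] = None,
    shuffle: bool = True,
) -> DataLoader:
    """Hypothetical helper: wrap tensors holding a whole split in a DataLoader."""
    if to_device is not None:
        # Move the entire split once instead of copying every batch.
        x, y = x.to(to_device), y.to(to_device)
        # Load in the main process: indexing is cheap, and CUDA tensors
        # should not be handled by worker processes.
        num_workers = 0
    return DataLoader(
        TensorDataset(x, y),
        batch_size=b,
        shuffle=shuffle,
        num_workers=num_workers,
    )


device = "cuda" if torch.cuda.is_available() else "cpu"
loader = make_in_memory_dataloader(
    torch.randn(1000, 1, 28, 28), torch.randint(0, 10, (1000,)), b=128, num_workers=0, to_device=device
)
```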

Datasets that need to be moved to local storage

class MovableToLocal:
    """Extension for large datasets that need to be moved to local storage on SLURM nodes."""

    @abstractmethod
    def is_on_local(self) -> bool:
        """Returns True if the dataset is already in local storage."""
        raise NotImplementedError()

    @abstractmethod
    def move_to_local(self):
        """Moves the dataset from the workspace to local storage."""
        raise NotImplementedError()

MNIST does not implement this extension because the dataset is small.
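For a large dataset, a MovableToLocal implementation might copy the dataset folder from the shared workspace to node-local disk once per job. In this sketch the class name, the LOCAL_SCRATCH environment variable (clusters expose node-local storage in different ways) and the .complete marker file are all hypothetical. It copies rather than deletes the workspace version so other jobs can still use it, and writes the marker last so a partial copy is never reported as present.

```python
import os
import shutil
import tempfile
from pathlib import Path


class CopyToLocal:
    """Hypothetical MovableToLocal implementation for a dataset directory."""

    MARKER = ".complete"

    def __init__(self, workspace_dir: Path, name: str):
        self.workspace_dir = Path(workspace_dir)
        # Assumed: LOCAL_SCRATCH points at node-local disk; fall back to the temp dir.
        local_root = Path(os.environ.get("LOCAL_SCRATCH", tempfile.gettempdir()))
        self.local_dir = local_root / name

    def is_on_local(self) -> bool:
        # The marker is written last, so its presence implies a complete copy.
        return (self.local_dir / self.MARKER).is_file()

    def move_to_local(self):
        # dirs_exist_ok (Python 3.8+) lets a retried copy overwrite a partial one.
        shutil.copytree(self.workspace_dir, self.local_dir, dirs_exist_ok=True)
        (self.local_dir / self.MARKER).touch()
```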

Full MNIST Implementation